Compare commits

...

42 Commits

Author SHA1 Message Date
Emily Rockman
6552b3a715 WIP 2024-01-08 15:00:25 -06:00
Emily Rockman
fb7a04358a cleanup 2024-01-08 14:57:31 -06:00
Emily Rockman
493d14e088 WIP
more WIP

WIP

    version: Optional[UnitTestNodeVersion] = None
2024-01-08 14:57:31 -06:00
Gerda Shank
56dfb34343 Merge branch 'main' into unit_testing_feature_branch 2024-01-02 17:52:54 -05:00
Gerda Shank
a0177e3333 In build command run unit tests before models (#9273) 2023-12-20 16:04:48 -05:00
Gerda Shank
4e87f4697a Merge branch 'main' into unit_testing_feature_branch 2023-12-07 16:25:00 -05:00
Gerda Shank
9a79fba8aa Make fixtures files full-fledged members of manifest and enable partial parsing (#9225) 2023-12-07 10:53:15 -05:00
Emily Rockman
a570a2c530 convert test to data_test (#9201)
* convert test to data_test

* generate proto types

* fixing tests

* add tests

* add more tests

* test cleanup

* WIP

* fix graph

* fix testing manifest

* set resource type back to test and reset unique id

* reset expected run results

* cleanup

* changie

* modify to only look for tests under columns in schema files

* stop using dashes
2023-12-07 08:03:18 -06:00
Gerda Shank
ca82f54808 Enable unit testing in non-root packages (#9184) 2023-11-30 14:42:52 -05:00
Gerda Shank
bf6bffad94 Merge branch 'main' into unit_testing_feature_branch 2023-11-30 12:26:16 -05:00
Gerda Shank
3f1ed23c11 Move unit testing to test and build commands (#9108)
* Switch to using 'test' command instead of 'unit-test'

* Remove old unit test

* Put daff changes into task/test.py

* changie

* test_type:unit

* Add unit test to build and make test

* Select test_type:data

* Add unit tests to test_graph_selector_methods.py

* Fix fqn to include path components

* Update build test

* Remove part of message in test_csv_fixtures.py that's different on Windows

* Rename build test directory
2023-11-27 09:50:39 -05:00
Gerda Shank
e00199186e Merge branch 'main' into unit_testing_feature_branch 2023-11-21 12:31:35 -05:00
Gerda Shank
a559259e32 Merge branch 'unit_testing_feature_branch' of github.com:dbt-labs/dbt-core into unit_testing_feature_branch 2023-11-16 13:45:51 -05:00
Gerda Shank
964a7283cb Merge branch 'main' into unit_testing_feature_branch 2023-11-16 13:45:35 -05:00
Kshitij Aranke
3432436dae Fix #8652: Use seed file from disk for unit testing if rows not specified in YAML config (#9064)
Co-authored-by: Michelle Ark <MichelleArk@users.noreply.github.com>
Fix #8652: Use seed value if rows not specified
2023-11-16 16:24:55 +00:00
Michelle Ark
35f579e3eb Use daff for diff formatting in unit testing (#8984) 2023-11-15 11:27:50 -05:00
Gerda Shank
c6be2d288f Allow use of sources as unit testing inputs (#9059) 2023-11-15 10:31:32 -05:00
Gerda Shank
436dae6bb3 Merge branch 'main' into unit_testing_feature_branch 2023-11-14 17:29:01 -05:00
Jeremy Cohen
ebf48d2b50 Unit test support for state:modified and --defer (#9032)
* Add unit tests to state:modified selection

* Get defer working too yolo

* Refactor per marky suggestion

* Add changelog

* separate out unit test state tests + fix csv fixture tests

* formatting

* detect changes to fixture files with state:modified

---------

Co-authored-by: Michelle Ark <michelle.ark@dbtlabs.com>
2023-11-14 13:42:52 -05:00
Emily Rockman
3b033ac108 csv file fixtures (#9044)
* WIP

* adding tests

* fix tests

* more tests

* fix broken tests

* fix broken tests

* change to csv extension

* fix unit test yaml

* add mixed inline and file csv test

* add mixed inline and file csv test

* additional changes

* read file directly

* some cleanup and some test fixes - wip

* fix test

* use better file searching

* fix final test

* cleanup

* use absolute path and fix tests
2023-11-09 14:35:26 -06:00
Gerda Shank
2792e0c2ce Merge branch 'main' into unit_testing_feature_branch 2023-11-08 11:26:37 -05:00
Emily Rockman
f629baa95d convert to use unit test name at top level key (#8966)
* use unit test name as top level

* remove breakpoints

* finish converting tests

* fix unit test node name

* breakpoints

* fix partial parsing bug

* comment out duplicate test

* fix test and make unique id match other unique id patterns

* clean up

* fix incremental test

* Update tests/functional/unit_testing/test_unit_testing.py
2023-11-03 13:59:02 -05:00
Emily Rockman
02a3dc5be3 update unit test key: unit -> unit-tests (#8988)
* WIP

* remove breakpoint

* fix tests, fix schema
2023-11-03 13:17:18 -05:00
Michelle Ark
aa91ea4c00 Support unit testing incremental models (#8891) 2023-11-01 21:08:20 -04:00
Gerda Shank
f77c2260f2 Merge branch 'main' into unit_testing_feature_branch 2023-11-01 11:26:47 -04:00
Gerda Shank
df4e4ed388 Merge branch 'main' into unit_testing_feature_branch 2023-10-12 13:35:19 -04:00
Gerda Shank
3b6f9bdef4 Enable inline csv format in unit testing (#8743) 2023-10-05 11:17:27 -04:00
Gerda Shank
5cafb96956 Merge branch 'main' into unit_testing_feature_branch 2023-10-05 10:08:09 -04:00
Gerda Shank
bb6fd3029b Merge branch 'main' into unit_testing_feature_branch 2023-10-02 15:59:24 -04:00
Gerda Shank
ac719e441c Merge branch 'main' into unit_testing_feature_branch 2023-09-26 19:57:32 -04:00
Gerda Shank
08ef90aafa Merge branch 'main' into unit_testing_feature_branch 2023-09-22 09:55:36 -04:00
Gerda Shank
3dbf0951b2 Merge branch 'main' into unit_testing_feature_branch 2023-09-13 17:36:15 -04:00
Gerda Shank
c48e34c47a Add additional functional test for unit testing selection, artifacts, etc (#8639) 2023-09-13 10:46:00 -04:00
Michelle Ark
12342ca92b unit test config: tags & meta (#8565) 2023-09-12 10:54:11 +01:00
Gerda Shank
2b376d9dba Merge branch 'main' into unit_testing_feature_branch 2023-09-11 13:01:51 -04:00
Gerda Shank
120b36e2f5 Merge branch 'main' into unit_testing_feature_branch 2023-09-07 11:08:52 -04:00
Gerda Shank
1e64f94bf0 Merge branch 'main' into unit_testing_feature_branch 2023-08-31 09:22:53 -04:00
Gerda Shank
b3bcbd5ea4 Merge branch 'main' into unit_testing_feature_branch 2023-08-30 14:51:30 -04:00
Gerda Shank
42e66fda65 8295 unit testing artifacts (#8477) 2023-08-29 17:54:54 -04:00
Gerda Shank
7ea7069999 Merge branch 'main' into unit_testing_feature_branch 2023-08-28 16:28:32 -04:00
Gerda Shank
24abc3719a Merge branch 'main' into unit_testing_feature_branch 2023-08-23 10:15:43 -04:00
Gerda Shank
181f5209a0 Initial implementation of unit testing (from pr #2911)
Co-authored-by: Michelle Ark <michelle.ark@dbtlabs.com>
2023-08-14 16:04:23 -04:00
170 changed files with 10065 additions and 1532 deletions

View File

@@ -0,0 +1,6 @@
kind: Features
body: Initial implementation of unit testing
time: 2023-08-02T14:50:11.391992-04:00
custom:
Author: gshank
Issue: "8287"

View File

@@ -0,0 +1,6 @@
kind: Features
body: Unit test manifest artifacts and selection
time: 2023-08-28T10:18:25.958929-04:00
custom:
Author: gshank
Issue: "8295"

View File

@@ -0,0 +1,6 @@
kind: Features
body: Support config with tags & meta for unit tests
time: 2023-09-06T23:47:41.059915-04:00
custom:
Author: michelleark
Issue: "8294"

View File

@@ -0,0 +1,6 @@
kind: Features
body: Enable inline csv fixtures in unit tests
time: 2023-09-28T16:32:05.573776-04:00
custom:
Author: gshank
Issue: "8626"

View File

@@ -0,0 +1,6 @@
kind: Features
body: Support unit testing incremental models
time: 2023-11-01T10:18:45.341781-04:00
custom:
Author: michelleark
Issue: "8422"

View File

@@ -0,0 +1,6 @@
kind: Features
body: Add support of csv file fixtures to unit testing
time: 2023-11-06T19:47:52.501495-06:00
custom:
Author: emmyoop
Issue: "8290"

View File

@@ -0,0 +1,6 @@
kind: Features
body: Unit tests support --defer and state:modified
time: 2023-11-07T23:10:06.376588-05:00
custom:
Author: jtcohen6
Issue: "8517"

View File

@@ -0,0 +1,6 @@
kind: Features
body: Support source inputs in unit tests
time: 2023-11-11T19:11:50.870494-05:00
custom:
Author: gshank
Issue: "8507"

View File

@@ -0,0 +1,6 @@
kind: Features
body: Use daff to render diff displayed in stdout when unit test fails
time: 2023-11-14T10:15:55.689307-05:00
custom:
Author: michelleark
Issue: "8558"

View File

@@ -0,0 +1,6 @@
kind: Features
body: Move unit testing to test command
time: 2023-11-16T14:40:06.121336-05:00
custom:
Author: gshank
Issue: "8979"

View File

@@ -0,0 +1,6 @@
kind: Features
body: Support unit tests in non-root packages
time: 2023-11-30T13:09:48.206007-05:00
custom:
Author: gshank
Issue: "8285"

View File

@@ -0,0 +1,7 @@
kind: Features
body: Convert the `tests` config to `data_tests` in both dbt_project.yml and
  in schema files.
time: 2023-12-05T13:17:17.647765-06:00
custom:
Author: emmyoop
Issue: "8699"

View File

@@ -0,0 +1,6 @@
kind: Features
body: Make fixture files full-fledged parts of the manifest and enable partial parsing
time: 2023-12-05T20:04:47.117029-05:00
custom:
Author: gshank
Issue: "9067"

View File

@@ -0,0 +1,6 @@
kind: Features
body: In build command run unit tests before models
time: 2023-12-12T15:05:56.778829-05:00
custom:
Author: gshank
Issue: "9128"

View File

@@ -0,0 +1,6 @@
kind: Fixes
body: Use seed file from disk for unit testing if rows not specified in YAML config
time: 2023-11-13T15:45:35.008565Z
custom:
Author: aranke
Issue: "8652"

View File

@@ -0,0 +1,6 @@
kind: Under the Hood
body: Add unit testing functional tests
time: 2023-09-12T19:05:06.023126-04:00
custom:
Author: gshank
Issue: "8512"

View File

@@ -214,7 +214,7 @@ class BaseRelation(FakeAPIObject, Hashable):
    def create_ephemeral_from(
        cls: Type[Self],
        relation_config: RelationConfig,
-        limit: Optional[int],
+        limit: Optional[int] = None,
    ) -> Self:
        # Note that ephemeral models are based on the name.
        identifier = cls.add_ephemeral_prefix(relation_config.name)

File diff suppressed because one or more lines are too long

View File

@@ -12,3 +12,31 @@
{{ "limit " ~ limit if limit != none }} {{ "limit " ~ limit if limit != none }}
) dbt_internal_test ) dbt_internal_test
{%- endmacro %} {%- endmacro %}
{% macro get_unit_test_sql(main_sql, expected_fixture_sql, expected_column_names) -%}
{{ adapter.dispatch('get_unit_test_sql', 'dbt')(main_sql, expected_fixture_sql, expected_column_names) }}
{%- endmacro %}
{% macro default__get_unit_test_sql(main_sql, expected_fixture_sql, expected_column_names) -%}
-- Build actual result given inputs
with dbt_internal_unit_test_actual AS (
select
{% for expected_column_name in expected_column_names %}{{expected_column_name}}{% if not loop.last -%},{% endif %}{%- endfor -%}, {{ dbt.string_literal("actual") }} as actual_or_expected
from (
{{ main_sql }}
) _dbt_internal_unit_test_actual
),
-- Build expected result
dbt_internal_unit_test_expected AS (
select
{% for expected_column_name in expected_column_names %}{{expected_column_name}}{% if not loop.last -%}, {% endif %}{%- endfor -%}, {{ dbt.string_literal("expected") }} as actual_or_expected
from (
{{ expected_fixture_sql }}
) _dbt_internal_unit_test_expected
)
-- Union actual and expected results
select * from dbt_internal_unit_test_actual
union all
select * from dbt_internal_unit_test_expected
{%- endmacro %}

View File

@@ -0,0 +1,29 @@
{%- materialization unit, default -%}
{% set relations = [] %}
{% set expected_rows = config.get('expected_rows') %}
{% set tested_expected_column_names = expected_rows[0].keys() if (expected_rows | length ) > 0 else get_columns_in_query(sql) %}
{%- set target_relation = this.incorporate(type='table') -%}
{%- set temp_relation = make_temp_relation(target_relation)-%}
{% do run_query(get_create_table_as_sql(True, temp_relation, get_empty_subquery_sql(sql))) %}
{%- set columns_in_relation = adapter.get_columns_in_relation(temp_relation) -%}
{%- set column_name_to_data_types = {} -%}
{%- for column in columns_in_relation -%}
{%- do column_name_to_data_types.update({column.name: column.dtype}) -%}
{%- endfor -%}
{% set unit_test_sql = get_unit_test_sql(sql, get_expected_sql(expected_rows, column_name_to_data_types), tested_expected_column_names) %}
{% call statement('main', fetch_result=True) -%}
{{ unit_test_sql }}
{%- endcall %}
{% do adapter.drop_relation(temp_relation) %}
{{ return({'relations': relations}) }}
{%- endmaterialization -%}

View File

@@ -0,0 +1,76 @@
{% macro get_fixture_sql(rows, column_name_to_data_types) %}
-- Fixture for {{ model.name }}
{% set default_row = {} %}
{%- if not column_name_to_data_types -%}
{%- set columns_in_relation = adapter.get_columns_in_relation(this) -%}
{%- set column_name_to_data_types = {} -%}
{%- for column in columns_in_relation -%}
{%- do column_name_to_data_types.update({column.name: column.dtype}) -%}
{%- endfor -%}
{%- endif -%}
{%- if not column_name_to_data_types -%}
{{ exceptions.raise_compiler_error("Not able to get columns for unit test '" ~ model.name ~ "' from relation " ~ this) }}
{%- endif -%}
{%- for column_name, column_type in column_name_to_data_types.items() -%}
{%- do default_row.update({column_name: (safe_cast("null", column_type) | trim )}) -%}
{%- endfor -%}
{%- for row in rows -%}
{%- do format_row(row, column_name_to_data_types) -%}
{%- set default_row_copy = default_row.copy() -%}
{%- do default_row_copy.update(row) -%}
select
{%- for column_name, column_value in default_row_copy.items() %} {{ column_value }} AS {{ column_name }}{% if not loop.last -%}, {%- endif %}
{%- endfor %}
{%- if not loop.last %}
union all
{% endif %}
{%- endfor -%}
{%- if (rows | length) == 0 -%}
select
{%- for column_name, column_value in default_row.items() %} {{ column_value }} AS {{ column_name }}{% if not loop.last -%},{%- endif %}
{%- endfor %}
limit 0
{%- endif -%}
{% endmacro %}
{% macro get_expected_sql(rows, column_name_to_data_types) %}
{%- if (rows | length) == 0 -%}
select * FROM dbt_internal_unit_test_actual
limit 0
{%- else -%}
{%- for row in rows -%}
{%- do format_row(row, column_name_to_data_types) -%}
select
{%- for column_name, column_value in row.items() %} {{ column_value }} AS {{ column_name }}{% if not loop.last -%}, {%- endif %}
{%- endfor %}
{%- if not loop.last %}
union all
{% endif %}
{%- endfor -%}
{%- endif -%}
{% endmacro %}
{%- macro format_row(row, column_name_to_data_types) -%}
{#-- wrap yaml strings in quotes, apply cast --#}
{%- for column_name, column_value in row.items() -%}
{% set row_update = {column_name: column_value} %}
{%- if column_value is string -%}
{%- set row_update = {column_name: safe_cast(dbt.string_literal(column_value), column_name_to_data_types[column_name]) } -%}
{%- elif column_value is none -%}
{%- set row_update = {column_name: safe_cast('null', column_name_to_data_types[column_name]) } -%}
{%- else -%}
{%- set row_update = {column_name: safe_cast(column_value, column_name_to_data_types[column_name]) } -%}
{%- endif -%}
{%- do row.update(row_update) -%}
{%- endfor -%}
{%- endmacro -%}

View File

@@ -12,7 +12,7 @@ class RelationConfigChangeAction(StrEnum):
drop = "drop" drop = "drop"
@dataclass(frozen=True, eq=True, unsafe_hash=True) @dataclass(frozen=True, eq=True, unsafe_hash=True) # type: ignore
class RelationConfigChange(RelationConfigBase, ABC): class RelationConfigChange(RelationConfigBase, ABC):
action: RelationConfigChangeAction action: RelationConfigChangeAction
context: Hashable # this is usually a RelationConfig, e.g. IndexConfig, but shouldn't be limited context: Hashable # this is usually a RelationConfig, e.g. IndexConfig, but shouldn't be limited

View File

@@ -84,6 +84,26 @@ class MacroGenerator(CallableMacroGenerator):
        return self.call_macro(*args, **kwargs)
class UnitTestMacroGenerator(MacroGenerator):
# this makes UnitTestMacroGenerator objects callable like functions
def __init__(
self,
macro_generator: MacroGenerator,
call_return_value: Any,
) -> None:
super().__init__(
macro_generator.macro,
macro_generator.context,
macro_generator.node,
macro_generator.stack,
)
self.call_return_value = call_return_value
def __call__(self, *args, **kwargs):
with self.track_call():
return self.call_return_value
# performance note: Local benmcharking (so take it with a big grain of salt!)
# on this indicates that it is is on average slightly slower than
# checking two separate patterns, but the standard deviation is smaller with

View File

@@ -3,6 +3,7 @@ from codecs import BOM_UTF8
import agate
import datetime
import isodate
import io
import json
from typing import Iterable, List, Dict, Union, Optional, Any
@@ -137,6 +138,23 @@ def table_from_data_flat(data, column_names: Iterable[str]) -> agate.Table:
    )
def json_rows_from_table(table: agate.Table) -> List[Dict[str, Any]]:
"Convert a table to a list of row dict objects"
output = io.StringIO()
table.to_json(path=output) # type: ignore
return json.loads(output.getvalue())
def list_rows_from_table(table: agate.Table) -> List[Any]:
"Convert a table to a list of lists, where the first element represents the header"
rows = [[col.name for col in table.columns]]
for row in table.rows:
rows.append(list(row.values()))
return rows
def empty_table():
    "Returns an empty Agate table. To be used in place of None"

File diff suppressed because one or more lines are too long

View File

@@ -12,7 +12,10 @@ from dbt.flags import get_flags
from dbt.adapters.factory import get_adapter
from dbt.clients import jinja
from dbt.common.clients.system import make_directory
-from dbt.context.providers import generate_runtime_model_context
+from dbt.context.providers import (
generate_runtime_model_context,
generate_runtime_unit_test_context,
)
from dbt.contracts.graph.manifest import Manifest, UniqueID
from dbt.contracts.graph.nodes import (
    ManifestNode,
@@ -21,6 +24,8 @@ from dbt.contracts.graph.nodes import (
    GraphMemberNode,
    InjectedCTE,
    SeedNode,
UnitTestNode,
UnitTestDefinition,
)
from dbt.exceptions import (
    GraphDependencyNotFoundError,
@@ -42,7 +47,8 @@ graph_file_name = "graph.gpickle"
def print_compile_stats(stats):
    names = {
        NodeType.Model: "model",
-        NodeType.Test: "test",
+        NodeType.Test: "data test",
NodeType.Unit: "unit test",
        NodeType.Snapshot: "snapshot",
        NodeType.Analysis: "analysis",
        NodeType.Macro: "macro",
@@ -90,6 +96,7 @@ def _generate_stats(manifest: Manifest):
    stats[NodeType.Macro] += len(manifest.macros)
    stats[NodeType.Group] += len(manifest.groups)
    stats[NodeType.SemanticModel] += len(manifest.semantic_models)
stats[NodeType.Unit] += len(manifest.unit_tests)
    # TODO: should we be counting dimensions + entities?
@@ -127,7 +134,7 @@ class Linker:
    def __init__(self, data=None) -> None:
        if data is None:
            data = {}
-        self.graph = nx.DiGraph(**data)
+        self.graph: nx.DiGraph = nx.DiGraph(**data)
    def edges(self):
        return self.graph.edges()
@@ -190,6 +197,8 @@ class Linker:
            self.link_node(exposure, manifest)
        for metric in manifest.metrics.values():
            self.link_node(metric, manifest)
for unit_test in manifest.unit_tests.values():
self.link_node(unit_test, manifest)
        for saved_query in manifest.saved_queries.values():
            self.link_node(saved_query, manifest)
@@ -233,6 +242,7 @@ class Linker:
        # Get all tests that depend on any upstream nodes.
        upstream_tests = []
        for upstream_node in upstream_nodes:
# This gets tests with unique_ids starting with "test."
            upstream_tests += _get_tests_for_node(manifest, upstream_node)
        for upstream_test in upstream_tests:
@@ -290,8 +300,10 @@ class Compiler:
        manifest: Manifest,
        extra_context: Dict[str, Any],
    ) -> Dict[str, Any]:
if isinstance(node, UnitTestNode):
-        context = generate_runtime_model_context(node, self.config, manifest)
+            context = generate_runtime_unit_test_context(node, self.config, manifest)
else:
context = generate_runtime_model_context(node, self.config, manifest)
        context.update(extra_context)
        if isinstance(node, GenericTestNode):
@@ -459,6 +471,7 @@ class Compiler:
summaries["_invocation_id"] = get_invocation_id() summaries["_invocation_id"] = get_invocation_id()
summaries["linked"] = linker.get_graph_summary(manifest) summaries["linked"] = linker.get_graph_summary(manifest)
# This is only called for the "build" command
        if add_test_edges:
            manifest.build_parent_and_child_maps()
            linker.add_test_edges(manifest)
@@ -525,6 +538,9 @@ class Compiler:
        the node's raw_code into compiled_code, and then calls the
        recursive method to "prepend" the ctes.
        """
if isinstance(node, UnitTestDefinition):
return node
        # Make sure Lexer for sqlparse 0.4.4 is initialized
        from sqlparse.lexer import Lexer  # type: ignore

View File

@@ -437,7 +437,8 @@ class PartialProject(RenderComponents):
    seeds: Dict[str, Any]
    snapshots: Dict[str, Any]
    sources: Dict[str, Any]
-    tests: Dict[str, Any]
+    data_tests: Dict[str, Any]
unit_tests: Dict[str, Any]
    metrics: Dict[str, Any]
    semantic_models: Dict[str, Any]
    saved_queries: Dict[str, Any]
@@ -450,7 +451,10 @@ class PartialProject(RenderComponents):
        seeds = cfg.seeds
        snapshots = cfg.snapshots
        sources = cfg.sources
-        tests = cfg.tests
+        # the `tests` config is deprecated but still allowed. Copy it into
# `data_tests` to simplify logic throughout the rest of the system.
data_tests = cfg.data_tests if "data_tests" in rendered.project_dict else cfg.tests
unit_tests = cfg.unit_tests
        metrics = cfg.metrics
        semantic_models = cfg.semantic_models
        saved_queries = cfg.saved_queries
@@ -511,7 +515,8 @@ class PartialProject(RenderComponents):
            selectors=selectors,
            query_comment=query_comment,
            sources=sources,
-            tests=tests,
+            data_tests=data_tests,
unit_tests=unit_tests,
            metrics=metrics,
            semantic_models=semantic_models,
            saved_queries=saved_queries,
@@ -621,7 +626,8 @@ class Project:
    seeds: Dict[str, Any]
    snapshots: Dict[str, Any]
    sources: Dict[str, Any]
-    tests: Dict[str, Any]
+    data_tests: Dict[str, Any]
unit_tests: Dict[str, Any]
    metrics: Dict[str, Any]
    semantic_models: Dict[str, Any]
    saved_queries: Dict[str, Any]
@@ -655,6 +661,13 @@ class Project:
            generic_test_paths.append(os.path.join(test_path, "generic"))
        return generic_test_paths
@property
def fixture_paths(self):
fixture_paths = []
for test_path in self.test_paths:
fixture_paths.append(os.path.join(test_path, "fixtures"))
return fixture_paths
    def __str__(self):
        cfg = self.to_project_config(with_packages=True)
        return str(cfg)
@@ -699,7 +712,8 @@ class Project:
"seeds": self.seeds, "seeds": self.seeds,
"snapshots": self.snapshots, "snapshots": self.snapshots,
"sources": self.sources, "sources": self.sources,
"tests": self.tests, "data_tests": self.data_tests,
"unit_tests": self.unit_tests,
"metrics": self.metrics, "metrics": self.metrics,
"semantic-models": self.semantic_models, "semantic-models": self.semantic_models,
"saved-queries": self.saved_queries, "saved-queries": self.saved_queries,

View File

@@ -164,7 +164,7 @@ class DbtProjectYamlRenderer(BaseRenderer):
if first == "vars": if first == "vars":
return False return False
if first in {"seeds", "models", "snapshots", "tests"}: if first in {"seeds", "models", "snapshots", "tests", "data_tests"}:
keypath_parts = {(k.lstrip("+ ") if isinstance(k, str) else k) for k in keypath} keypath_parts = {(k.lstrip("+ ") if isinstance(k, str) else k) for k in keypath}
# model-level hooks # model-level hooks
late_rendered_hooks = {"pre-hook", "post-hook", "pre_hook", "post_hook"} late_rendered_hooks = {"pre-hook", "post-hook", "pre_hook", "post_hook"}

View File

@@ -166,7 +166,8 @@ class RuntimeConfig(Project, Profile, AdapterRequiredConfig):
            selectors=project.selectors,
            query_comment=project.query_comment,
            sources=project.sources,
-            tests=project.tests,
+            data_tests=project.data_tests,
unit_tests=project.unit_tests,
            metrics=project.metrics,
            semantic_models=project.semantic_models,
            saved_queries=project.saved_queries,
@@ -325,7 +326,8 @@ class RuntimeConfig(Project, Profile, AdapterRequiredConfig):
"seeds": self._get_config_paths(self.seeds), "seeds": self._get_config_paths(self.seeds),
"snapshots": self._get_config_paths(self.snapshots), "snapshots": self._get_config_paths(self.snapshots),
"sources": self._get_config_paths(self.sources), "sources": self._get_config_paths(self.sources),
"tests": self._get_config_paths(self.tests), "data_tests": self._get_config_paths(self.data_tests),
"unit_tests": self._get_config_paths(self.unit_tests),
"metrics": self._get_config_paths(self.metrics), "metrics": self._get_config_paths(self.metrics),
"semantic_models": self._get_config_paths(self.semantic_models), "semantic_models": self._get_config_paths(self.semantic_models),
"saved_queries": self._get_config_paths(self.saved_queries), "saved_queries": self._get_config_paths(self.saved_queries),

View File

@@ -43,7 +43,7 @@ class UnrenderedConfig(ConfigSource):
        elif resource_type == NodeType.Source:
            model_configs = unrendered.get("sources")
        elif resource_type == NodeType.Test:
-            model_configs = unrendered.get("tests")
+            model_configs = unrendered.get("data_tests")
        elif resource_type == NodeType.Metric:
            model_configs = unrendered.get("metrics")
        elif resource_type == NodeType.SemanticModel:
@@ -52,6 +52,8 @@ class UnrenderedConfig(ConfigSource):
model_configs = unrendered.get("saved_queries") model_configs = unrendered.get("saved_queries")
elif resource_type == NodeType.Exposure: elif resource_type == NodeType.Exposure:
model_configs = unrendered.get("exposures") model_configs = unrendered.get("exposures")
elif resource_type == NodeType.Unit:
model_configs = unrendered.get("unit_tests")
        else:
            model_configs = unrendered.get("models")
        if model_configs is None:
@@ -72,7 +74,7 @@ class RenderedConfig(ConfigSource):
        elif resource_type == NodeType.Source:
            model_configs = self.project.sources
        elif resource_type == NodeType.Test:
-            model_configs = self.project.tests
+            model_configs = self.project.data_tests
        elif resource_type == NodeType.Metric:
            model_configs = self.project.metrics
        elif resource_type == NodeType.SemanticModel:
@@ -81,6 +83,8 @@ class RenderedConfig(ConfigSource):
            model_configs = self.project.saved_queries
        elif resource_type == NodeType.Exposure:
            model_configs = self.project.exposures
elif resource_type == NodeType.Unit:
model_configs = self.project.unit_tests
        else:
            model_configs = self.project.models
        return model_configs

View File

@@ -1,4 +1,5 @@
import abc
from copy import deepcopy
import os
from typing import (
    Callable,
@@ -18,7 +19,7 @@ from dbt.adapters.base.column import Column
from dbt.common.clients.jinja import MacroProtocol
from dbt.adapters.factory import get_adapter, get_adapter_package_names, get_adapter_type_names
from dbt.common.clients import agate_helper
-from dbt.clients.jinja import get_rendered, MacroGenerator, MacroStack
+from dbt.clients.jinja import get_rendered, MacroGenerator, MacroStack, UnitTestMacroGenerator
from dbt.config import RuntimeConfig, Project
from dbt.constants import SECRET_ENV_PREFIX, DEFAULT_ENV_PLACEHOLDER
from dbt.context.base import contextmember, contextproperty, Var
@@ -40,6 +41,7 @@ from dbt.contracts.graph.nodes import (
    RefArgs,
    AccessType,
    SemanticModel,
UnitTestNode,
)
from dbt.contracts.graph.metrics import MetricReference, ResolvedMetricReference
from dbt.contracts.graph.unparsed import NodeVersion
@@ -568,6 +570,17 @@ class OperationRefResolver(RuntimeRefResolver):
        return super().create_relation(target_model)
class RuntimeUnitTestRefResolver(RuntimeRefResolver):
def resolve(
self,
target_name: str,
target_package: Optional[str] = None,
target_version: Optional[NodeVersion] = None,
) -> RelationProxy:
target_name = f"{self.model.name}__{target_name}"
return super().resolve(target_name, target_package, target_version)
# `source` implementations
class ParseSourceResolver(BaseSourceResolver):
    def resolve(self, source_name: str, table_name: str):
@@ -595,6 +608,29 @@ class RuntimeSourceResolver(BaseSourceResolver):
        return self.Relation.create_from(self.config, target_source, limit=self.resolve_limit)
class RuntimeUnitTestSourceResolver(BaseSourceResolver):
def resolve(self, source_name: str, table_name: str):
target_source = self.manifest.resolve_source(
source_name,
table_name,
self.current_project,
self.model.package_name,
)
if target_source is None or isinstance(target_source, Disabled):
raise TargetNotFoundError(
node=self.model,
target_name=f"{source_name}.{table_name}",
target_kind="source",
disabled=(isinstance(target_source, Disabled)),
)
# For unit tests, this isn't a "real" source, it's a ModelNode taking
# the place of a source. We don't really need to return the relation here,
# we just need to set_cte, but skipping it confuses typing. We *do* need
# the relation in the "this" property.
self.model.set_cte(target_source.unique_id, None)
return self.Relation.create_ephemeral_from(target_source)
# `metric` implementations
class ParseMetricResolver(BaseMetricResolver):
    def resolve(self, name: str, package: Optional[str] = None) -> MetricReference:
@@ -672,6 +708,22 @@ class RuntimeVar(ModelConfiguredVar):
    pass
class UnitTestVar(RuntimeVar):
def __init__(
self,
context: Dict[str, Any],
config: RuntimeConfig,
node: Resource,
) -> None:
config_copy = None
assert isinstance(node, UnitTestNode)
if node.overrides and node.overrides.vars:
config_copy = deepcopy(config)
config_copy.cli_vars.update(node.overrides.vars)
super().__init__(context, config_copy or config, node=node)
# Providers
class Provider(Protocol):
    execute: bool
@@ -713,6 +765,16 @@ class RuntimeProvider(Provider):
    metric = RuntimeMetricResolver
class RuntimeUnitTestProvider(Provider):
execute = True
Config = RuntimeConfigObject
DatabaseWrapper = RuntimeDatabaseWrapper
Var = UnitTestVar
ref = RuntimeUnitTestRefResolver
source = RuntimeUnitTestSourceResolver
metric = RuntimeMetricResolver
class OperationProvider(RuntimeProvider):
    ref = OperationRefResolver
@@ -1384,7 +1446,7 @@ class ModelContext(ProviderContext):
    @contextproperty()
    def pre_hooks(self) -> List[Dict[str, Any]]:
-        if self.model.resource_type in [NodeType.Source, NodeType.Test]:
+        if self.model.resource_type in [NodeType.Source, NodeType.Test, NodeType.Unit]:
            return []
        # TODO CT-211
        return [
@@ -1393,7 +1455,7 @@ class ModelContext(ProviderContext):
    @contextproperty()
    def post_hooks(self) -> List[Dict[str, Any]]:
-        if self.model.resource_type in [NodeType.Source, NodeType.Test]:
+        if self.model.resource_type in [NodeType.Source, NodeType.Test, NodeType.Unit]:
            return []
        # TODO CT-211
        return [
@@ -1486,6 +1548,33 @@ class ModelContext(ProviderContext):
        return None
class UnitTestContext(ModelContext):
model: UnitTestNode
@contextmember()
def env_var(self, var: str, default: Optional[str] = None) -> str:
"""The env_var() function. Return the overriden unit test environment variable named 'var'.
If there is no unit test override, return the environment variable named 'var'.
If there is no such environment variable set, return the default.
If the default is None, raise an exception for an undefined variable.
"""
if self.model.overrides and var in self.model.overrides.env_vars:
return self.model.overrides.env_vars[var]
else:
return super().env_var(var, default)
@contextproperty()
def this(self) -> Optional[str]:
if self.model.this_input_node_unique_id:
this_node = self.manifest.expect(self.model.this_input_node_unique_id)
self.model.set_cte(this_node.unique_id, None) # type: ignore
return self.adapter.Relation.add_ephemeral_prefix(this_node.name)
return None
# This is called by '_context_for', used in 'render_with_context'
def generate_parser_model_context(
    model: ManifestNode,
@@ -1530,6 +1619,24 @@ def generate_runtime_macro_context(
    return ctx.to_dict()
def generate_runtime_unit_test_context(
unit_test: UnitTestNode,
config: RuntimeConfig,
manifest: Manifest,
) -> Dict[str, Any]:
ctx = UnitTestContext(unit_test, config, manifest, RuntimeUnitTestProvider(), None)
ctx_dict = ctx.to_dict()
if unit_test.overrides and unit_test.overrides.macros:
for macro_name, macro_value in unit_test.overrides.macros.items():
context_value = ctx_dict.get(macro_name)
if isinstance(context_value, MacroGenerator):
ctx_dict[macro_name] = UnitTestMacroGenerator(context_value, macro_value)
else:
ctx_dict[macro_name] = macro_value
return ctx_dict
class ExposureRefResolver(BaseResolver):
    def __call__(self, *args, **kwargs) -> str:
        package = None

View File

@@ -22,6 +22,7 @@ class ParseFileType(StrEnum):
Documentation = "docs" Documentation = "docs"
Schema = "schema" Schema = "schema"
Hook = "hook" # not a real filetype, from dbt_project.yml Hook = "hook" # not a real filetype, from dbt_project.yml
Fixture = "fixture"
parse_file_type_to_parser = {
@@ -35,6 +36,7 @@ parse_file_type_to_parser = {
    ParseFileType.Documentation: "DocumentationParser",
    ParseFileType.Schema: "SchemaParser",
    ParseFileType.Hook: "HookParser",
ParseFileType.Fixture: "FixtureParser",
}
@@ -152,7 +154,6 @@ class BaseSourceFile(dbtClassMixin, SerializableType):
    parse_file_type: Optional[ParseFileType] = None
    # we don't want to serialize this
    contents: Optional[str] = None
-    # the unique IDs contained in this file
    @property
    def file_id(self):
@@ -172,6 +173,8 @@ class BaseSourceFile(dbtClassMixin, SerializableType):
    def _deserialize(cls, dct: Dict[str, int]):
        if dct["parse_file_type"] == "schema":
            sf = SchemaSourceFile.from_dict(dct)
elif dct["parse_file_type"] == "fixture":
sf = FixtureSourceFile.from_dict(dct)
        else:
            sf = SourceFile.from_dict(dct)
        return sf
@@ -220,12 +223,11 @@ class SourceFile(BaseSourceFile):
        )
        return self
@dataclass
class SchemaSourceFile(BaseSourceFile):
    dfy: Dict[str, Any] = field(default_factory=dict)
    # these are in the manifest.nodes dictionary
-    tests: Dict[str, Any] = field(default_factory=dict)
+    data_tests: Dict[str, Any] = field(default_factory=dict)
    sources: List[str] = field(default_factory=list)
    exposures: List[str] = field(default_factory=list)
    metrics: List[str] = field(default_factory=list)
@@ -235,6 +237,9 @@ class SchemaSourceFile(BaseSourceFile):
    # node patches contain models, seeds, snapshots, analyses
    ndp: List[str] = field(default_factory=list)
    semantic_models: List[str] = field(default_factory=list)
unit_tests: List[str] = field(default_factory=list)
# any unit_test patches in this file by unit_test unique_id.
utp: List[str] = field(default_factory=list)
    saved_queries: List[str] = field(default_factory=list)
    # any macro patches in this file by macro unique_id.
    mcp: Dict[str, str] = field(default_factory=dict)
@@ -261,6 +266,10 @@ class SchemaSourceFile(BaseSourceFile):
    @property
    def source_patches(self):
        return self.sop
@property
def unit_test_patches(self):
return self.utp
    def __post_serialize__(self, dct):
        dct = super().__post_serialize__(dct)
@@ -276,31 +285,31 @@ class SchemaSourceFile(BaseSourceFile):
    def add_test(self, node_unique_id, test_from):
        name = test_from["name"]
        key = test_from["key"]
-        if key not in self.tests:
-            self.tests[key] = {}
-        if name not in self.tests[key]:
-            self.tests[key][name] = []
-        self.tests[key][name].append(node_unique_id)
+        if key not in self.data_tests:
+            self.data_tests[key] = {}
+        if name not in self.data_tests[key]:
+            self.data_tests[key][name] = []
+        self.data_tests[key][name].append(node_unique_id)
-    # this is only used in unit tests
+    # this is only used in tests/unit
    def remove_tests(self, yaml_key, name):
-        if yaml_key in self.tests:
-            if name in self.tests[yaml_key]:
-                del self.tests[yaml_key][name]
+        if yaml_key in self.data_tests:
+            if name in self.data_tests[yaml_key]:
+                del self.data_tests[yaml_key][name]
-    # this is only used in tests (unit + functional)
+    # this is only used in the tests directory (unit + functional)
    def get_tests(self, yaml_key, name):
-        if yaml_key in self.tests:
-            if name in self.tests[yaml_key]:
-                return self.tests[yaml_key][name]
+        if yaml_key in self.data_tests:
+            if name in self.data_tests[yaml_key]:
+                return self.data_tests[yaml_key][name]
        return []
    def get_key_and_name_for_test(self, test_unique_id):
        yaml_key = None
        block_name = None
-        for key in self.tests.keys():
-            for name in self.tests[key]:
-                for unique_id in self.tests[key][name]:
+        for key in self.data_tests.keys():
+            for name in self.data_tests[key]:
+                for unique_id in self.data_tests[key][name]:
                    if unique_id == test_unique_id:
                        yaml_key = key
                        block_name = name
@@ -309,9 +318,9 @@ class SchemaSourceFile(BaseSourceFile):
    def get_all_test_ids(self):
        test_ids = []
-        for key in self.tests.keys():
-            for name in self.tests[key]:
-                test_ids.extend(self.tests[key][name])
+        for key in self.data_tests.keys():
+            for name in self.data_tests[key]:
+                test_ids.extend(self.data_tests[key][name])
        return test_ids
    def add_env_var(self, var, yaml_key, name):
@@ -331,4 +340,14 @@ class SchemaSourceFile(BaseSourceFile):
                del self.env_vars[yaml_key]
-AnySourceFile = Union[SchemaSourceFile, SourceFile]
+@dataclass
class FixtureSourceFile(BaseSourceFile):
fixture: Optional[str] = None
unit_tests: List[str] = field(default_factory=list)
def add_unit_test(self, value):
if value not in self.unit_tests:
self.unit_tests.append(value)
AnySourceFile = Union[SchemaSourceFile, SourceFile, FixtureSourceFile]

View File

@@ -42,10 +42,19 @@ from dbt.contracts.graph.nodes import (
    SemanticModel,
    SourceDefinition,
    UnpatchedSourceDefinition,
UnitTestDefinition,
UnitTestFileFixture,
UnpatchedUnitTestDefinition,
)
-from dbt.contracts.graph.unparsed import SourcePatch, NodeVersion, UnparsedVersion
+from dbt.contracts.graph.unparsed import SourcePatch, NodeVersion, UnparsedVersion, UnitTestPatch
from dbt.contracts.graph.manifest_upgrade import upgrade_manifest_json
-from dbt.contracts.files import SourceFile, SchemaSourceFile, FileHash, AnySourceFile
+from dbt.contracts.files import (
SourceFile,
SchemaSourceFile,
FileHash,
AnySourceFile,
FixtureSourceFile,
)
from dbt.contracts.util import (
    BaseArtifactMetadata,
    SourceKey,
@@ -160,6 +169,39 @@ class SourceLookup(dbtClassMixin):
        return manifest.sources[unique_id]
class UnitTestLookup(dbtClassMixin):
def __init__(self, manifest: "Manifest") -> None:
self.storage: Dict[str, Dict[PackageName, UniqueID]] = {}
self.populate(manifest)
def get_unique_id(self, search_name, package: Optional[PackageName]):
return find_unique_id_for_package(self.storage, search_name, package)
def find(self, search_name, package: Optional[PackageName], manifest: "Manifest"):
unique_id = self.get_unique_id(search_name, package)
if unique_id is not None:
return self.perform_lookup(unique_id, manifest)
return None
def add_unit_test(self, unit_test: UnitTestDefinition):
if unit_test.search_name not in self.storage:
self.storage[unit_test.search_name] = {}
self.storage[unit_test.search_name][unit_test.package_name] = unit_test.unique_id
def populate(self, manifest):
for unit_test in manifest.unit_tests.values():
if hasattr(unit_test, "unit_test"):
self.add_unit_test(unit_test)
def perform_lookup(self, unique_id: UniqueID, manifest: "Manifest") -> UnitTestDefinition:
if unique_id not in manifest.unit_tests:
raise dbt.exceptions.DbtInternalError(
f"Unit test {unique_id} found in cache but not found in manifest"
)
return manifest.unit_tests[unique_id]
class RefableLookup(dbtClassMixin):
    # model, seed, snapshot
    _lookup_types: ClassVar[set] = set(NodeType.refable())
@@ -666,6 +708,8 @@ MaybeParsedSource = Optional[
    ]
]
MaybeParsedUnitTest = Optional[UnitTestDefinition]
MaybeNonSource = Optional[Union[ManifestNode, Disabled[ManifestNode]]]
@@ -800,7 +844,10 @@ class Manifest(MacroMethods, DataClassMessagePackMixin, dbtClassMixin):
    disabled: MutableMapping[str, List[GraphMemberNode]] = field(default_factory=dict)
    env_vars: MutableMapping[str, str] = field(default_factory=dict)
    semantic_models: MutableMapping[str, SemanticModel] = field(default_factory=dict)
unit_tests: MutableMapping[str, UnitTestDefinition] = field(default_factory=dict)
unit_test_patches: MutableMapping[str, UnitTestPatch] = field(default_factory=dict)
    saved_queries: MutableMapping[str, SavedQuery] = field(default_factory=dict)
fixtures: MutableMapping[str, UnitTestFileFixture] = field(default_factory=dict)
    _doc_lookup: Optional[DocLookup] = field(
        default=None, metadata={"serialize": lambda x: None, "deserialize": lambda x: None}
@@ -808,6 +855,9 @@ class Manifest(MacroMethods, DataClassMessagePackMixin, dbtClassMixin):
    _source_lookup: Optional[SourceLookup] = field(
        default=None, metadata={"serialize": lambda x: None, "deserialize": lambda x: None}
    )
_unit_test_lookup: Optional[UnitTestLookup] = field(
default=None, metadata={"serialize": lambda x: None, "deserialize": lambda x: None}
)
    _ref_lookup: Optional[RefableLookup] = field(
        default=None, metadata={"serialize": lambda x: None, "deserialize": lambda x: None}
    )
@@ -961,6 +1011,7 @@ class Manifest(MacroMethods, DataClassMessagePackMixin, dbtClassMixin):
            files={k: _deepcopy(v) for k, v in self.files.items()},
            state_check=_deepcopy(self.state_check),
            semantic_models={k: _deepcopy(v) for k, v in self.semantic_models.items()},
unit_tests={k: _deepcopy(v) for k, v in self.unit_tests.items()},
            saved_queries={k: _deepcopy(v) for k, v in self.saved_queries.items()},
        )
        copy.build_flat_graph()
@@ -1031,6 +1082,7 @@ class Manifest(MacroMethods, DataClassMessagePackMixin, dbtClassMixin):
            parent_map=self.parent_map,
            group_map=self.group_map,
            semantic_models=self.semantic_models,
unit_tests=self.unit_tests,
            saved_queries=self.saved_queries,
        )
@@ -1050,6 +1102,8 @@ class Manifest(MacroMethods, DataClassMessagePackMixin, dbtClassMixin):
            return self.metrics[unique_id]
        elif unique_id in self.semantic_models:
            return self.semantic_models[unique_id]
elif unique_id in self.unit_tests:
return self.unit_tests[unique_id]
        elif unique_id in self.saved_queries:
            return self.saved_queries[unique_id]
        else:
@@ -1076,6 +1130,15 @@ class Manifest(MacroMethods, DataClassMessagePackMixin, dbtClassMixin):
    def rebuild_source_lookup(self):
        self._source_lookup = SourceLookup(self)
@property
def unit_test_lookup(self) -> UnitTestLookup:
if self._unit_test_lookup is None:
self._unit_test_lookup = UnitTestLookup(self)
return self._unit_test_lookup
def rebuild_unit_test_lookup(self):
self._unit_test_lookup = UnitTestLookup(self)
    @property
    def ref_lookup(self) -> RefableLookup:
        if self._ref_lookup is None:
@@ -1174,6 +1237,22 @@ class Manifest(MacroMethods, DataClassMessagePackMixin, dbtClassMixin):
                return Disabled(disabled[0])
        return None
def resolve_unit_tests(
self,
unit_test_name: str,
current_project: str,
node_package: str,
) -> MaybeParsedUnitTest:
candidates = _packages_to_search(current_project, node_package)
unit_test: Optional[UnitTestDefinition] = None
for pkg in candidates:
unit_test = self.unit_test_lookup.find(unit_test_name, pkg, self)
if unit_test is not None:
return unit_test
return None
    # Called by dbt.parser.manifest._resolve_sources_for_exposure
    # and dbt.parser.manifest._process_source_for_node
    def resolve_source(
@@ -1419,7 +1498,7 @@ class Manifest(MacroMethods, DataClassMessagePackMixin, dbtClassMixin):
    def add_source(self, source_file: SchemaSourceFile, source: UnpatchedSourceDefinition):
        # sources can't be overwritten!
        _check_duplicates(source, self.sources)
-        self.sources[source.unique_id] = source  # type: ignore
+        self.sources[source.unique_id] = source  # type: ignore[assignment]
        source_file.sources.append(source.unique_id)
    def add_node_nofile(self, node: ManifestNode):
@@ -1439,6 +1518,8 @@ class Manifest(MacroMethods, DataClassMessagePackMixin, dbtClassMixin):
                source_file.exposures.append(node.unique_id)
            if isinstance(node, Group):
                source_file.groups.append(node.unique_id)
elif isinstance(source_file, FixtureSourceFile):
pass
        else:
            source_file.nodes.append(node.unique_id)
@@ -1481,6 +1562,8 @@ class Manifest(MacroMethods, DataClassMessagePackMixin, dbtClassMixin):
                source_file.semantic_models.append(node.unique_id)
            if isinstance(node, Exposure):
                source_file.exposures.append(node.unique_id)
elif isinstance(source_file, FixtureSourceFile):
pass
        else:
            source_file.nodes.append(node.unique_id)
@@ -1494,6 +1577,17 @@ class Manifest(MacroMethods, DataClassMessagePackMixin, dbtClassMixin):
        self.semantic_models[semantic_model.unique_id] = semantic_model
        source_file.semantic_models.append(semantic_model.unique_id)
def add_unit_test(self, source_file: SchemaSourceFile, unit_test: UnpatchedUnitTestDefinition):
_check_duplicates(unit_test, self.unit_tests)
self.unit_tests[unit_test.unique_id] = unit_test # type: ignore[assignment]
source_file.unit_tests.append(unit_test.unique_id)
def add_fixture(self, source_file: FixtureSourceFile, fixture: UnitTestFileFixture):
if fixture.unique_id in self.fixtures:
raise DuplicateResourceNameError(fixture, self.fixtures[fixture.unique_id])
self.fixtures[fixture.unique_id] = fixture
source_file.fixture = fixture.unique_id
    def add_saved_query(self, source_file: SchemaSourceFile, saved_query: SavedQuery) -> None:
        _check_duplicates(saved_query, self.saved_queries)
        self.saved_queries[saved_query.unique_id] = saved_query
@@ -1526,6 +1620,8 @@ class Manifest(MacroMethods, DataClassMessagePackMixin, dbtClassMixin):
            self.disabled,
            self.env_vars,
            self.semantic_models,
self.unit_tests,
self.unit_test_patches,
            self.saved_queries,
            self._doc_lookup,
            self._source_lookup,
@@ -1608,6 +1704,11 @@ class WritableManifest(ArtifactMixin):
description="Metadata about the manifest", description="Metadata about the manifest",
) )
) )
unit_tests: Mapping[UniqueID, UnitTestDefinition] = field(
metadata=dict(
description="The unit tests defined in the project",
)
)
    @classmethod
    def compatible_previous_versions(cls) -> Iterable[Tuple[str, int]]:

View File

@@ -145,6 +145,9 @@ def upgrade_manifest_json(manifest: dict, manifest_schema_version: int) -> dict:
manifest["groups"] = {} manifest["groups"] = {}
if "group_map" not in manifest: if "group_map" not in manifest:
manifest["group_map"] = {} manifest["group_map"] = {}
# add unit_tests key
if "unit_tests" not in manifest:
manifest["unit_tests"] = {}
    for metric_content in manifest.get("metrics", {}).values():
        # handle attr renames + value translation ("expression" -> "derived")
        metric_content = upgrade_ref_content(metric_content)

View File

@@ -227,6 +227,11 @@ class ModelConfig(NodeConfig):
    )
@dataclass
class UnitTestNodeConfig(NodeConfig):
expected_rows: List[Dict[str, Any]] = field(default_factory=list)
@dataclass
class SeedConfig(NodeConfig):
    materialized: str = "seed"
@@ -399,6 +404,18 @@ class SnapshotConfig(EmptySnapshotConfig):
        return self.from_dict(data)
@dataclass
class UnitTestConfig(BaseConfig):
tags: Union[str, List[str]] = field(
default_factory=list_str,
metadata=metas(ShowBehavior.Hide, MergeBehavior.Append, CompareBehavior.Exclude),
)
meta: Dict[str, Any] = field(
default_factory=dict,
metadata=MergeBehavior.Update.meta(),
)
RESOURCE_TYPES: Dict[NodeType, Type[BaseConfig]] = {
    NodeType.Metric: MetricConfig,
    NodeType.SemanticModel: SemanticModelConfig,
@@ -409,6 +426,7 @@ RESOURCE_TYPES: Dict[NodeType, Type[BaseConfig]] = {
NodeType.Test: TestConfig,
NodeType.Model: NodeConfig,
NodeType.Snapshot: SnapshotConfig,
NodeType.Unit: UnitTestConfig,
}
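Assuming the module paths stay as they are on this branch, the registry can be exercised like this (illustrative only):

# Hypothetical import paths; resolves the config class registered for
# the new unit test node type.
from dbt.node_types import NodeType
from dbt.contracts.graph.model_config import RESOURCE_TYPES, UnitTestConfig

assert RESOURCE_TYPES[NodeType.Unit] is UnitTestConfig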

View File

@@ -7,6 +7,7 @@ import hashlib
from mashumaro.types import SerializableType
from typing import Optional, Union, List, Dict, Any, Sequence, Tuple, Iterator, Literal
from dbt import deprecations
from dbt.common.contracts.constraints import (
ColumnLevelConstraint,
ConstraintType,
@@ -40,13 +41,17 @@ from dbt.contracts.graph.unparsed import (
UnparsedSourceDefinition,
UnparsedSourceTableDefinition,
UnparsedColumn,
UnitTestOverrides,
UnitTestInputFixture,
UnitTestOutputFixture,
UnitTestNodeVersion,
)
from dbt.contracts.graph.node_args import ModelNodeArgs
from dbt.contracts.graph.semantic_layer_common import WhereFilterIntersection
from dbt.contracts.util import Replaceable
from dbt.common.contracts.config.properties import AdditionalPropertiesMixin
from dbt.common.events.functions import warn_or_error
-from dbt.exceptions import ParsingError, ContractBreakingChangeError
+from dbt.exceptions import ParsingError, ContractBreakingChangeError, ValidationError
from dbt.common.events.types import (
SeedIncreased,
SeedExceedsLimitSamePath,
@@ -82,6 +87,8 @@ from .model_config import (
EmptySnapshotConfig,
SnapshotConfig,
SemanticModelConfig,
UnitTestConfig,
UnitTestNodeConfig,
SavedQueryConfig,
)
@@ -482,6 +489,9 @@ class CompiledNode(ParsedNode):
refs: List[RefArgs] = field(default_factory=list)
sources: List[List[str]] = field(default_factory=list)
metrics: List[List[str]] = field(default_factory=list)
# TODO: when we do it this way we lose the ability to cross-ref the model name without knowing the version yet
unit_tests: List[str] = field(default_factory=list)
# unit_tests: List[UnitTestModelVersion] = field(default_factory=list) # dict of model unique_id to model version
depends_on: DependsOn = field(default_factory=DependsOn)
compiled_path: Optional[str] = None
compiled: bool = False
@@ -1033,6 +1043,110 @@ class GenericTestNode(TestShouldStoreFailures, CompiledNode, HasTestMetadata):
return "generic" return "generic"
# ====================================
# Unit Test node
# ====================================
@dataclass
class UnpatchedUnitTestDefinition(BaseNode):
# unit_test: UnparsedUnitTest
name: str
model: str # name of the model being unit tested
given: Sequence[UnitTestInputFixture]
expect: UnitTestOutputFixture
fqn: List[str]
description: str = ""
config: Dict[str, Any] = field(default_factory=dict)
resource_type: Literal[NodeType.Unit]
versions: Optional[UnitTestNodeVersion] = None
overrides: Optional[UnitTestOverrides] = None
patch_path: Optional[str] = None
@dataclass
class UnitTestSourceDefinition(ModelNode):
source_name: str = "undefined"
quoting: Quoting = field(default_factory=Quoting)
@property
def search_name(self):
return f"{self.source_name}.{self.name}"
@dataclass
class UnitTestNode(CompiledNode):
resource_type: Literal[NodeType.Unit]
tested_node_unique_id: Optional[str] = None
this_input_node_unique_id: Optional[str] = None
overrides: Optional[UnitTestOverrides] = None
config: UnitTestNodeConfig = field(default_factory=UnitTestNodeConfig)
@dataclass
class UnitTestDefinitionMandatory:
model: str
given: Sequence[UnitTestInputFixture]
expect: UnitTestOutputFixture
@dataclass
class UnitTestDefinition(NodeInfoMixin, GraphNode, UnitTestDefinitionMandatory):
description: str = ""
depends_on: DependsOn = field(default_factory=DependsOn)
config: UnitTestConfig = field(default_factory=UnitTestConfig)
version: Optional[NodeVersion] = None
overrides: Optional[UnitTestOverrides] = None
checksum: Optional[str] = None
schema: Optional[str] = None
@property
def unit_tests(self):
return []
@property
def build_path(self):
# TODO: is this actually necessary?
return self.original_file_path
@property
def compiled_path(self):
# TODO: is this actually necessary?
return self.original_file_path
@property
def depends_on_nodes(self):
return self.depends_on.nodes
@property
def tags(self) -> List[str]:
tags = self.config.tags
return [tags] if isinstance(tags, str) else tags
def build_unit_test_checksum(self):
# everything except 'description'
data = f"{self.model}-{self.given}-{self.expect}-{self.overrides}"
# include underlying fixture data
for input in self.given:
if input.fixture:
data += f"-{input.rows}"
self.checksum = hashlib.new("sha256", data.encode("utf-8")).hexdigest()
def same_contents(self, other: Optional["UnitTestDefinition"]) -> bool:
if other is None:
return False
return self.checksum == other.checksum
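The checksum is what lets state:modified compare unit tests across manifests; a minimal standalone sketch of the same idea (hypothetical helper, not dbt's API):

import hashlib

def content_checksum(model, given, expect, overrides):
    # Hash everything except the description, as in build_unit_test_checksum.
    data = f"{model}-{given}-{expect}-{overrides}"
    return hashlib.new("sha256", data.encode("utf-8")).hexdigest()

old = content_checksum("my_model", [{"input": "ref('a')"}], {"rows": []}, None)
new = content_checksum("my_model", [{"input": "ref('a')"}], {"rows": [{"id": 1}]}, None)
assert old != new  # a changed expectation registers as modified content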
@dataclass
class UnitTestFileFixture(BaseNode):
resource_type: Literal[NodeType.Fixture]
rows: Optional[List[Dict[str, Any]]] = None
# ====================================
# Snapshot node
# ====================================
@@ -1136,6 +1250,24 @@ class UnpatchedSourceDefinition(BaseNode):
def get_source_representation(self):
return f'source("{self.source.name}", "{self.table.name}")'
def validate_data_tests(self):
"""
Sources parse tests differently from models, so we do validation here
that is handled in the PatchParser for other node types.
"""
for column in self.columns:
if column.tests and column.data_tests:
raise ValidationError(
"Invalid test config: cannot have both 'tests' and 'data_tests' defined"
)
if column.tests:
deprecations.warn(
"project-test-config",
deprecated_path="tests",
exp_path="data_tests",
)
column.data_tests = column.tests
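A standalone sketch of the conflict this validation rejects (plain dict and ValueError as stand-ins for dbt's column objects and ValidationError):

column = {"name": "id", "tests": ["unique"], "data_tests": ["not_null"]}
try:
    if column.get("tests") and column.get("data_tests"):
        raise ValueError(
            "Invalid test config: cannot have both 'tests' and 'data_tests' defined"
        )
except ValueError as e:
    print(e)  # the parser surfaces this as a validation failure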
@property
def quote_columns(self) -> Optional[bool]:
result = None
@@ -1150,14 +1282,23 @@ class UnpatchedSourceDefinition(BaseNode):
return [] if self.table.columns is None else self.table.columns
def get_tests(self) -> Iterator[Tuple[Dict[str, Any], Optional[UnparsedColumn]]]:
-for test in self.tests:
-yield normalize_test(test), None
+self.validate_data_tests()
+for data_test in self.data_tests:
+yield normalize_test(data_test), None
for column in self.columns:
-if column.tests is not None:
-for test in column.tests:
-yield normalize_test(test), column
+if column.data_tests is not None:
+for data_test in column.data_tests:
+yield normalize_test(data_test), column
@property
def data_tests(self) -> List[TestDef]:
if self.table.data_tests is None:
return []
else:
return self.table.data_tests
# deprecated
@property
def tests(self) -> List[TestDef]:
if self.table.tests is None:
@@ -1289,6 +1430,10 @@ class SourceDefinition(NodeInfoMixin, ParsedSourceMandatory):
def search_name(self):
return f"{self.source_name}.{self.name}"
@property
def group(self):
return None
# ====================================
# Exposure node
@@ -1839,6 +1984,7 @@ ManifestSQLNode = Union[
SqlNode,
GenericTestNode,
SnapshotNode,
UnitTestNode,
]
# All SQL nodes plus SeedNode (csv files)
@@ -1859,6 +2005,7 @@ GraphMemberNode = Union[
Metric,
SavedQuery,
SemanticModel,
UnitTestDefinition,
]
# All "nodes" (or node-like objects) in this file
@@ -1869,7 +2016,4 @@ Resource = Union[
Group,
]
-TestNode = Union[
-SingularTestNode,
-GenericTestNode,
-]
+TestNode = Union[SingularTestNode, GenericTestNode]

View File

@@ -105,6 +105,7 @@ TestDef = Union[Dict[str, Any], str]
@dataclass
class HasColumnAndTestProps(HasColumnProps):
data_tests: List[TestDef] = field(default_factory=list)
tests: List[TestDef] = field(default_factory=list)
@@ -141,6 +142,7 @@ class HasConfig:
NodeVersion = Union[str, float]
UnitTestModelVersion: Dict[str, List[NodeVersion]]
@dataclass
@@ -152,7 +154,7 @@ class UnparsedVersion(dbtClassMixin):
config: Dict[str, Any] = field(default_factory=dict)
constraints: List[Dict[str, Any]] = field(default_factory=list)
docs: Docs = field(default_factory=Docs)
-tests: Optional[List[TestDef]] = None
+data_tests: Optional[List[TestDef]] = None
columns: Sequence[Union[dbt.common.helper_types.IncludeExclude, UnparsedColumn]] = field(
default_factory=list
)
@@ -255,7 +257,11 @@ class UnparsedModelUpdate(UnparsedNodeUpdate):
f"get_tests_for_version called for version '{version}' not in version map" f"get_tests_for_version called for version '{version}' not in version map"
) )
unparsed_version = self._version_map[version] unparsed_version = self._version_map[version]
return unparsed_version.tests if unparsed_version.tests is not None else self.tests return (
unparsed_version.data_tests
if unparsed_version.data_tests is not None
else self.data_tests
)
@dataclass
@@ -403,7 +409,7 @@ class SourceTablePatch(dbtClassMixin):
freshness: Optional[FreshnessThreshold] = field(default_factory=FreshnessThreshold)
external: Optional[ExternalTable] = None
tags: Optional[List[str]] = None
-tests: Optional[List[TestDef]] = None
+data_tests: Optional[List[TestDef]] = None
columns: Optional[Sequence[UnparsedColumn]] = None
def to_patch_dict(self) -> Dict[str, Any]:
@@ -783,3 +789,82 @@ def normalize_date(d: Optional[datetime.date]) -> Optional[datetime.datetime]:
dt = dt.astimezone()
return dt
class UnitTestFormat(StrEnum):
CSV = "csv"
Dict = "dict"
@dataclass
class UnitTestInputFixture(dbtClassMixin):
input: str
rows: Optional[Union[str, List[Dict[str, Any]]]] = None
format: UnitTestFormat = UnitTestFormat.Dict
fixture: Optional[str] = None
@dataclass
class UnitTestOutputFixture(dbtClassMixin):
rows: Optional[Union[str, List[Dict[str, Any]]]] = None
format: UnitTestFormat = UnitTestFormat.Dict
fixture: Optional[str] = None
@dataclass
class UnitTestOverrides(dbtClassMixin):
macros: Dict[str, Any] = field(default_factory=dict)
vars: Dict[str, Any] = field(default_factory=dict)
env_vars: Dict[str, Any] = field(default_factory=dict)
@dataclass
class UnitTestNodeVersion(dbtClassMixin):
include: Optional[List[NodeVersion]] = None
exclude: Optional[List[NodeVersion]] = None
@dataclass
class UnparsedUnitTest(dbtClassMixin):
name: str
model: str # name of the model being unit tested
given: Sequence[UnitTestInputFixture]
expect: UnitTestOutputFixture
description: str = ""
versions: Optional[UnitTestNodeVersion] = None
overrides: Optional[UnitTestOverrides] = None
config: Dict[str, Any] = field(default_factory=dict)
@classmethod
def validate(cls, data):
super(UnparsedUnitTest, cls).validate(data)
if data.get("version", None):
if data["version"].get("include") and data["version"].get("exclude"):
raise ValidationError("Unit tests can not both include and exclude versions.")
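For orientation, a hedged example of the payload shape these dataclasses describe (field names from the definitions above, values invented):

unit_test_data = {
    "name": "test_valid_email",
    "model": "dim_customers",
    "given": [
        {"input": "ref('stg_customers')", "rows": [{"id": 1, "email": "a@b.com"}]},
    ],
    "expect": {"rows": [{"id": 1, "is_valid_email": True}]},
    "description": "Checks the email validation logic.",
}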
@dataclass
class UnitTestPatch(dbtClassMixin, Replaceable):
name: str
model: str
given: Sequence[UnitTestInputFixture]
expect: UnitTestOutputFixture
overrides: str = field(
metadata=dict(description="The package of the unit test to override"),
)
path: Path = field(
metadata=dict(description="The path to the patch-defining yml file"),
)
config: Dict[str, Any] = field(default_factory=dict)
versions: Optional[UnitTestNodeVersion] = None
description: Optional[str] = None
schema: Optional[str] = None
def to_patch_dict(self) -> Dict[str, Any]:
dct = self.to_dict(omit_none=True)
remove_keys = ("name", "overrides", "path")
for key in remove_keys:
if key in dct:
del dct[key]
return dct

View File

@@ -1,3 +1,4 @@
from dbt import deprecations
from dbt.contracts.util import Replaceable, Mergeable, list_str, Identifier
from dbt.adapters.contracts.connection import QueryComment, UserConfigContract
from dbt.common.helper_types import NoValue
@@ -194,7 +195,7 @@ class Project(dbtClassMixin, Replaceable):
source_paths: Optional[List[str]] = None
model_paths: Optional[List[str]] = None
macro_paths: Optional[List[str]] = None
-data_paths: Optional[List[str]] = None
+data_paths: Optional[List[str]] = None # deprecated
seed_paths: Optional[List[str]] = None
test_paths: Optional[List[str]] = None
analysis_paths: Optional[List[str]] = None
@@ -216,7 +217,9 @@ class Project(dbtClassMixin, Replaceable):
snapshots: Dict[str, Any] = field(default_factory=dict)
analyses: Dict[str, Any] = field(default_factory=dict)
sources: Dict[str, Any] = field(default_factory=dict)
-tests: Dict[str, Any] = field(default_factory=dict)
+tests: Dict[str, Any] = field(default_factory=dict) # deprecated
data_tests: Dict[str, Any] = field(default_factory=dict)
unit_tests: Dict[str, Any] = field(default_factory=dict)
metrics: Dict[str, Any] = field(default_factory=dict)
semantic_models: Dict[str, Any] = field(default_factory=dict)
saved_queries: Dict[str, Any] = field(default_factory=dict)
@@ -280,6 +283,14 @@ class Project(dbtClassMixin, Replaceable):
raise ValidationError(
f"Invalid dbt_cloud config. Expected a 'dict' but got '{type(data['dbt_cloud'])}'"
)
if data.get("tests", None) and data.get("data_tests", None):
raise ValidationError(
"Invalid project config: cannot have both 'tests' and 'data_tests' defined"
)
if "tests" in data:
deprecations.warn(
"project-test-config", deprecated_path="tests", exp_path="data_tests"
)
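A standalone sketch of the two branches above on a plain project dict (ValueError and warnings.warn as stand-ins for dbt's ValidationError and deprecations.warn):

import warnings

data = {"tests": {"+severity": "warn"}}
if data.get("tests", None) and data.get("data_tests", None):
    raise ValueError(
        "Invalid project config: cannot have both 'tests' and 'data_tests' defined"
    )
if "tests" in data:
    # dbt routes this through deprecations.warn("project-test-config", ...)
    warnings.warn("The `tests` config has been renamed to `data_tests`.")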
@dataclass

View File

@@ -1,7 +1,12 @@
import threading
from dbt.contracts.graph.unparsed import FreshnessThreshold
-from dbt.contracts.graph.nodes import CompiledNode, SourceDefinition, ResultNode
+from dbt.contracts.graph.nodes import (
+CompiledNode,
+SourceDefinition,
+ResultNode,
+UnitTestDefinition,
+)
from dbt.contracts.util import (
BaseArtifactMetadata,
ArtifactMixin,
@@ -154,7 +159,7 @@ class BaseResult(dbtClassMixin):
@dataclass
class NodeResult(BaseResult):
-node: ResultNode
+node: Union[ResultNode, UnitTestDefinition]
@dataclass

View File

@@ -51,6 +51,8 @@ class PackageInstallPathDeprecation(DBTDeprecation):
_event = "PackageInstallPathDeprecation" _event = "PackageInstallPathDeprecation"
# deprecations whose names follow the `project-config-*` pattern are not hardcoded;
# they are called programmatically via the pattern below
class ConfigSourcePathDeprecation(DBTDeprecation):
_name = "project-config-source-paths"
_event = "ConfigSourcePathDeprecation"
@@ -61,6 +63,26 @@ class ConfigDataPathDeprecation(DBTDeprecation):
_event = "ConfigDataPathDeprecation" _event = "ConfigDataPathDeprecation"
class ConfigLogPathDeprecation(DBTDeprecation):
_name = "project-config-log-path"
_event = "ConfigLogPathDeprecation"
class ConfigTargetPathDeprecation(DBTDeprecation):
_name = "project-config-target-path"
_event = "ConfigTargetPathDeprecation"
def renamed_method(old_name: str, new_name: str):
class AdapterDeprecationWarning(DBTDeprecation):
_name = "adapter:{}".format(old_name)
_event = "AdapterDeprecationWarning"
dep = AdapterDeprecationWarning()
deprecations_list.append(dep)
deprecations[dep.name] = dep
class MetricAttributesRenamed(DBTDeprecation):
_name = "metric-attr-renamed"
_event = "MetricAttributesRenamed"
@@ -71,14 +93,14 @@ class ExposureNameDeprecation(DBTDeprecation):
_event = "ExposureNameDeprecation" _event = "ExposureNameDeprecation"
-class ConfigLogPathDeprecation(DBTDeprecation):
-_name = "project-config-log-path"
-_event = "ConfigLogPathDeprecation"
+class CollectFreshnessReturnSignature(DBTDeprecation):
+_name = "collect-freshness-return-signature"
+_event = "CollectFreshnessReturnSignature"
-class ConfigTargetPathDeprecation(DBTDeprecation):
-_name = "project-config-target-path"
-_event = "ConfigTargetPathDeprecation"
+class TestsConfigDeprecation(DBTDeprecation):
+_name = "project-test-config"
+_event = "TestsConfigDeprecation"
def renamed_env_var(old_name: str, new_name: str):
@@ -114,10 +136,11 @@ deprecations_list: List[DBTDeprecation] = [
PackageInstallPathDeprecation(),
ConfigSourcePathDeprecation(),
ConfigDataPathDeprecation(),
MetricAttributesRenamed(),
ExposureNameDeprecation(),
ConfigLogPathDeprecation(),
ConfigTargetPathDeprecation(),
TestsConfigDeprecation(),
CollectFreshnessReturnSignature(),
]
deprecations: Dict[str, DBTDeprecation] = {d.name: d for d in deprecations_list}

View File

@@ -126,6 +126,19 @@ message ConfigTargetPathDeprecationMsg {
ConfigTargetPathDeprecation data = 2;
}
// D013
message TestsConfigDeprecation {
string deprecated_path = 1;
string exp_path = 2;
}
message TestsConfigDeprecationMsg {
CoreEventInfo info = 1;
TestsConfigDeprecation data = 2;
}
// I065
message DeprecatedModel {
string model_name = 1;

View File

@@ -2,10 +2,10 @@
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: core_types.proto
"""Generated protocol buffer code."""
-from google.protobuf.internal import builder as _builder
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import symbol_database as _symbol_database
+from google.protobuf.internal import builder as _builder
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
@@ -14,66 +14,69 @@ _sym_db = _symbol_database.Default()
from google.protobuf import timestamp_pb2 as google_dot_protobuf_dot_timestamp__pb2
-DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x10\x63ore_types.proto\x12\x0bproto_types\x1a\x1fgoogle/protobuf/timestamp.proto\"\x99\x02\n\rCoreEventInfo\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x0c\n\x04\x63ode\x18\x02 \x01(\t\x12\x0b\n\x03msg\x18\x03 \x01(\t\x12\r\n\x05level\x18\x04 \x01(\t\x12\x15\n\rinvocation_id\x18\x05 \x01(\t\x12\x0b\n\x03pid\x18\x06 \x01(\x05\x12\x0e\n\x06thread\x18\x07 \x01(\t\x12&\n\x02ts\x18\x08 \x01(\x0b\x32\x1a.google.protobuf.Timestamp\x12\x34\n\x05\x65xtra\x18\t \x03(\x0b\x32%.proto_types.CoreEventInfo.ExtraEntry\x12\x10\n\x08\x63\x61tegory\x18\n \x01(\t\x1a,\n\nExtraEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\"@\n\x1aPackageRedirectDeprecation\x12\x10\n\x08old_name\x18\x01 \x01(\t\x12\x10\n\x08new_name\x18\x02 \x01(\t\"\x80\x01\n\x1dPackageRedirectDeprecationMsg\x12(\n\x04info\x18\x01 \x01(\x0b\x32\x1a.proto_types.CoreEventInfo\x12\x35\n\x04\x64\x61ta\x18\x02 \x01(\x0b\x32\'.proto_types.PackageRedirectDeprecation\"\x1f\n\x1dPackageInstallPathDeprecation\"\x86\x01\n PackageInstallPathDeprecationMsg\x12(\n\x04info\x18\x01 \x01(\x0b\x32\x1a.proto_types.CoreEventInfo\x12\x38\n\x04\x64\x61ta\x18\x02 \x01(\x0b\x32*.proto_types.PackageInstallPathDeprecation\"H\n\x1b\x43onfigSourcePathDeprecation\x12\x17\n\x0f\x64\x65precated_path\x18\x01 \x01(\t\x12\x10\n\x08\x65xp_path\x18\x02 \x01(\t\"\x82\x01\n\x1e\x43onfigSourcePathDeprecationMsg\x12(\n\x04info\x18\x01 \x01(\x0b\x32\x1a.proto_types.CoreEventInfo\x12\x36\n\x04\x64\x61ta\x18\x02 \x01(\x0b\x32(.proto_types.ConfigSourcePathDeprecation\"F\n\x19\x43onfigDataPathDeprecation\x12\x17\n\x0f\x64\x65precated_path\x18\x01 \x01(\t\x12\x10\n\x08\x65xp_path\x18\x02 \x01(\t\"~\n\x1c\x43onfigDataPathDeprecationMsg\x12(\n\x04info\x18\x01 \x01(\x0b\x32\x1a.proto_types.CoreEventInfo\x12\x34\n\x04\x64\x61ta\x18\x02 \x01(\x0b\x32&.proto_types.ConfigDataPathDeprecation\".\n\x17MetricAttributesRenamed\x12\x13\n\x0bmetric_name\x18\x01 \x01(\t\"z\n\x1aMetricAttributesRenamedMsg\x12(\n\x04info\x18\x01 \x01(\x0b\x32\x1a.proto_types.CoreEventInfo\x12\x32\n\x04\x64\x61ta\x18\x02 \x01(\x0b\x32$.proto_types.MetricAttributesRenamed\"+\n\x17\x45xposureNameDeprecation\x12\x10\n\x08\x65xposure\x18\x01 \x01(\t\"z\n\x1a\x45xposureNameDeprecationMsg\x12(\n\x04info\x18\x01 \x01(\x0b\x32\x1a.proto_types.CoreEventInfo\x12\x32\n\x04\x64\x61ta\x18\x02 \x01(\x0b\x32$.proto_types.ExposureNameDeprecation\"^\n\x13InternalDeprecation\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x0e\n\x06reason\x18\x02 \x01(\t\x12\x18\n\x10suggested_action\x18\x03 \x01(\t\x12\x0f\n\x07version\x18\x04 \x01(\t\"r\n\x16InternalDeprecationMsg\x12(\n\x04info\x18\x01 \x01(\x0b\x32\x1a.proto_types.CoreEventInfo\x12.\n\x04\x64\x61ta\x18\x02 \x01(\x0b\x32 .proto_types.InternalDeprecation\"@\n\x1a\x45nvironmentVariableRenamed\x12\x10\n\x08old_name\x18\x01 \x01(\t\x12\x10\n\x08new_name\x18\x02 \x01(\t\"\x80\x01\n\x1d\x45nvironmentVariableRenamedMsg\x12(\n\x04info\x18\x01 \x01(\x0b\x32\x1a.proto_types.CoreEventInfo\x12\x35\n\x04\x64\x61ta\x18\x02 \x01(\x0b\x32\'.proto_types.EnvironmentVariableRenamed\"3\n\x18\x43onfigLogPathDeprecation\x12\x17\n\x0f\x64\x65precated_path\x18\x01 \x01(\t\"|\n\x1b\x43onfigLogPathDeprecationMsg\x12(\n\x04info\x18\x01 \x01(\x0b\x32\x1a.proto_types.CoreEventInfo\x12\x33\n\x04\x64\x61ta\x18\x02 \x01(\x0b\x32%.proto_types.ConfigLogPathDeprecation\"6\n\x1b\x43onfigTargetPathDeprecation\x12\x17\n\x0f\x64\x65precated_path\x18\x01 \x01(\t\"\x82\x01\n\x1e\x43onfigTargetPathDeprecationMsg\x12(\n\x04info\x18\x01 \x01(\x0b\x32\x1a.proto_types.CoreEventInfo\x12\x36\n\x04\x64\x61ta\x18\x02 \x01(\x0b\x32(.proto_types.ConfigTargetPathDeprecation\"V\n\x0f\x44\x65precatedModel\x12\x12\n\nmodel_name\x18\x01 \x01(\t\x12\x15\n\rmodel_version\x18\x02 \x01(\t\x12\x18\n\x10\x64\x65precation_date\x18\x03 \x01(\t\"j\n\x12\x44\x65precatedModelMsg\x12(\n\x04info\x18\x01 \x01(\x0b\x32\x1a.proto_types.CoreEventInfo\x12*\n\x04\x64\x61ta\x18\x02 \x01(\x0b\x32\x1c.proto_types.DeprecatedModel\"/\n\x17\x44\x65psScrubbedPackageName\x12\x14\n\x0cpackage_name\x18\x01 \x01(\t\"z\n\x1a\x44\x65psScrubbedPackageNameMsg\x12(\n\x04info\x18\x01 \x01(\x0b\x32\x1a.proto_types.CoreEventInfo\x12\x32\n\x04\x64\x61ta\x18\x02 \x01(\x0b\x32$.proto_types.DepsScrubbedPackageNameb\x06proto3')
+DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x10\x63ore_types.proto\x12\x0bproto_types\x1a\x1fgoogle/protobuf/timestamp.proto\"\x99\x02\n\rCoreEventInfo\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x0c\n\x04\x63ode\x18\x02 \x01(\t\x12\x0b\n\x03msg\x18\x03 \x01(\t\x12\r\n\x05level\x18\x04 \x01(\t\x12\x15\n\rinvocation_id\x18\x05 \x01(\t\x12\x0b\n\x03pid\x18\x06 \x01(\x05\x12\x0e\n\x06thread\x18\x07 \x01(\t\x12&\n\x02ts\x18\x08 \x01(\x0b\x32\x1a.google.protobuf.Timestamp\x12\x34\n\x05\x65xtra\x18\t \x03(\x0b\x32%.proto_types.CoreEventInfo.ExtraEntry\x12\x10\n\x08\x63\x61tegory\x18\n \x01(\t\x1a,\n\nExtraEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\"@\n\x1aPackageRedirectDeprecation\x12\x10\n\x08old_name\x18\x01 \x01(\t\x12\x10\n\x08new_name\x18\x02 \x01(\t\"\x80\x01\n\x1dPackageRedirectDeprecationMsg\x12(\n\x04info\x18\x01 \x01(\x0b\x32\x1a.proto_types.CoreEventInfo\x12\x35\n\x04\x64\x61ta\x18\x02 \x01(\x0b\x32\'.proto_types.PackageRedirectDeprecation\"\x1f\n\x1dPackageInstallPathDeprecation\"\x86\x01\n PackageInstallPathDeprecationMsg\x12(\n\x04info\x18\x01 \x01(\x0b\x32\x1a.proto_types.CoreEventInfo\x12\x38\n\x04\x64\x61ta\x18\x02 \x01(\x0b\x32*.proto_types.PackageInstallPathDeprecation\"H\n\x1b\x43onfigSourcePathDeprecation\x12\x17\n\x0f\x64\x65precated_path\x18\x01 \x01(\t\x12\x10\n\x08\x65xp_path\x18\x02 \x01(\t\"\x82\x01\n\x1e\x43onfigSourcePathDeprecationMsg\x12(\n\x04info\x18\x01 \x01(\x0b\x32\x1a.proto_types.CoreEventInfo\x12\x36\n\x04\x64\x61ta\x18\x02 \x01(\x0b\x32(.proto_types.ConfigSourcePathDeprecation\"F\n\x19\x43onfigDataPathDeprecation\x12\x17\n\x0f\x64\x65precated_path\x18\x01 \x01(\t\x12\x10\n\x08\x65xp_path\x18\x02 \x01(\t\"~\n\x1c\x43onfigDataPathDeprecationMsg\x12(\n\x04info\x18\x01 \x01(\x0b\x32\x1a.proto_types.CoreEventInfo\x12\x34\n\x04\x64\x61ta\x18\x02 \x01(\x0b\x32&.proto_types.ConfigDataPathDeprecation\".\n\x17MetricAttributesRenamed\x12\x13\n\x0bmetric_name\x18\x01 \x01(\t\"z\n\x1aMetricAttributesRenamedMsg\x12(\n\x04info\x18\x01 \x01(\x0b\x32\x1a.proto_types.CoreEventInfo\x12\x32\n\x04\x64\x61ta\x18\x02 \x01(\x0b\x32$.proto_types.MetricAttributesRenamed\"+\n\x17\x45xposureNameDeprecation\x12\x10\n\x08\x65xposure\x18\x01 \x01(\t\"z\n\x1a\x45xposureNameDeprecationMsg\x12(\n\x04info\x18\x01 \x01(\x0b\x32\x1a.proto_types.CoreEventInfo\x12\x32\n\x04\x64\x61ta\x18\x02 \x01(\x0b\x32$.proto_types.ExposureNameDeprecation\"^\n\x13InternalDeprecation\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x0e\n\x06reason\x18\x02 \x01(\t\x12\x18\n\x10suggested_action\x18\x03 \x01(\t\x12\x0f\n\x07version\x18\x04 \x01(\t\"r\n\x16InternalDeprecationMsg\x12(\n\x04info\x18\x01 \x01(\x0b\x32\x1a.proto_types.CoreEventInfo\x12.\n\x04\x64\x61ta\x18\x02 \x01(\x0b\x32 .proto_types.InternalDeprecation\"@\n\x1a\x45nvironmentVariableRenamed\x12\x10\n\x08old_name\x18\x01 \x01(\t\x12\x10\n\x08new_name\x18\x02 \x01(\t\"\x80\x01\n\x1d\x45nvironmentVariableRenamedMsg\x12(\n\x04info\x18\x01 \x01(\x0b\x32\x1a.proto_types.CoreEventInfo\x12\x35\n\x04\x64\x61ta\x18\x02 \x01(\x0b\x32\'.proto_types.EnvironmentVariableRenamed\"3\n\x18\x43onfigLogPathDeprecation\x12\x17\n\x0f\x64\x65precated_path\x18\x01 \x01(\t\"|\n\x1b\x43onfigLogPathDeprecationMsg\x12(\n\x04info\x18\x01 \x01(\x0b\x32\x1a.proto_types.CoreEventInfo\x12\x33\n\x04\x64\x61ta\x18\x02 \x01(\x0b\x32%.proto_types.ConfigLogPathDeprecation\"6\n\x1b\x43onfigTargetPathDeprecation\x12\x17\n\x0f\x64\x65precated_path\x18\x01 \x01(\t\"\x82\x01\n\x1e\x43onfigTargetPathDeprecationMsg\x12(\n\x04info\x18\x01 \x01(\x0b\x32\x1a.proto_types.CoreEventInfo\x12\x36\n\x04\x64\x61ta\x18\x02 \x01(\x0b\x32(.proto_types.ConfigTargetPathDeprecation\"C\n\x16TestsConfigDeprecation\x12\x17\n\x0f\x64\x65precated_path\x18\x01 \x01(\t\x12\x10\n\x08\x65xp_path\x18\x02 \x01(\t\"x\n\x19TestsConfigDeprecationMsg\x12(\n\x04info\x18\x01 \x01(\x0b\x32\x1a.proto_types.CoreEventInfo\x12\x31\n\x04\x64\x61ta\x18\x02 \x01(\x0b\x32#.proto_types.TestsConfigDeprecation\"V\n\x0f\x44\x65precatedModel\x12\x12\n\nmodel_name\x18\x01 \x01(\t\x12\x15\n\rmodel_version\x18\x02 \x01(\t\x12\x18\n\x10\x64\x65precation_date\x18\x03 \x01(\t\"j\n\x12\x44\x65precatedModelMsg\x12(\n\x04info\x18\x01 \x01(\x0b\x32\x1a.proto_types.CoreEventInfo\x12*\n\x04\x64\x61ta\x18\x02 \x01(\x0b\x32\x1c.proto_types.DeprecatedModel\"/\n\x17\x44\x65psScrubbedPackageName\x12\x14\n\x0cpackage_name\x18\x01 \x01(\t\"z\n\x1a\x44\x65psScrubbedPackageNameMsg\x12(\n\x04info\x18\x01 \x01(\x0b\x32\x1a.proto_types.CoreEventInfo\x12\x32\n\x04\x64\x61ta\x18\x02 \x01(\x0b\x32$.proto_types.DepsScrubbedPackageNameb\x06proto3')
-_globals = globals()
-_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
-_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'core_types_pb2', _globals)
+_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
+_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'core_types_pb2', globals())
if _descriptor._USE_C_DESCRIPTORS == False:
DESCRIPTOR._options = None
_COREEVENTINFO_EXTRAENTRY._options = None
_COREEVENTINFO_EXTRAENTRY._serialized_options = b'8\001'
-_globals['_COREEVENTINFO']._serialized_start=67
-_globals['_COREEVENTINFO']._serialized_end=348
-_globals['_COREEVENTINFO_EXTRAENTRY']._serialized_start=304
-_globals['_COREEVENTINFO_EXTRAENTRY']._serialized_end=348
-_globals['_PACKAGEREDIRECTDEPRECATION']._serialized_start=350
-_globals['_PACKAGEREDIRECTDEPRECATION']._serialized_end=414
-_globals['_PACKAGEREDIRECTDEPRECATIONMSG']._serialized_start=417
-_globals['_PACKAGEREDIRECTDEPRECATIONMSG']._serialized_end=545
-_globals['_PACKAGEINSTALLPATHDEPRECATION']._serialized_start=547
-_globals['_PACKAGEINSTALLPATHDEPRECATION']._serialized_end=578
-_globals['_PACKAGEINSTALLPATHDEPRECATIONMSG']._serialized_start=581
-_globals['_PACKAGEINSTALLPATHDEPRECATIONMSG']._serialized_end=715
-_globals['_CONFIGSOURCEPATHDEPRECATION']._serialized_start=717
-_globals['_CONFIGSOURCEPATHDEPRECATION']._serialized_end=789
-_globals['_CONFIGSOURCEPATHDEPRECATIONMSG']._serialized_start=792
-_globals['_CONFIGSOURCEPATHDEPRECATIONMSG']._serialized_end=922
-_globals['_CONFIGDATAPATHDEPRECATION']._serialized_start=924
-_globals['_CONFIGDATAPATHDEPRECATION']._serialized_end=994
-_globals['_CONFIGDATAPATHDEPRECATIONMSG']._serialized_start=996
-_globals['_CONFIGDATAPATHDEPRECATIONMSG']._serialized_end=1122
-_globals['_METRICATTRIBUTESRENAMED']._serialized_start=1124
-_globals['_METRICATTRIBUTESRENAMED']._serialized_end=1170
-_globals['_METRICATTRIBUTESRENAMEDMSG']._serialized_start=1172
-_globals['_METRICATTRIBUTESRENAMEDMSG']._serialized_end=1294
-_globals['_EXPOSURENAMEDEPRECATION']._serialized_start=1296
-_globals['_EXPOSURENAMEDEPRECATION']._serialized_end=1339
-_globals['_EXPOSURENAMEDEPRECATIONMSG']._serialized_start=1341
-_globals['_EXPOSURENAMEDEPRECATIONMSG']._serialized_end=1463
-_globals['_INTERNALDEPRECATION']._serialized_start=1465
-_globals['_INTERNALDEPRECATION']._serialized_end=1559
-_globals['_INTERNALDEPRECATIONMSG']._serialized_start=1561
-_globals['_INTERNALDEPRECATIONMSG']._serialized_end=1675
-_globals['_ENVIRONMENTVARIABLERENAMED']._serialized_start=1677
-_globals['_ENVIRONMENTVARIABLERENAMED']._serialized_end=1741
-_globals['_ENVIRONMENTVARIABLERENAMEDMSG']._serialized_start=1744
-_globals['_ENVIRONMENTVARIABLERENAMEDMSG']._serialized_end=1872
-_globals['_CONFIGLOGPATHDEPRECATION']._serialized_start=1874
-_globals['_CONFIGLOGPATHDEPRECATION']._serialized_end=1925
-_globals['_CONFIGLOGPATHDEPRECATIONMSG']._serialized_start=1927
-_globals['_CONFIGLOGPATHDEPRECATIONMSG']._serialized_end=2051
-_globals['_CONFIGTARGETPATHDEPRECATION']._serialized_start=2053
-_globals['_CONFIGTARGETPATHDEPRECATION']._serialized_end=2107
-_globals['_CONFIGTARGETPATHDEPRECATIONMSG']._serialized_start=2110
-_globals['_CONFIGTARGETPATHDEPRECATIONMSG']._serialized_end=2240
-_globals['_DEPRECATEDMODEL']._serialized_start=2242
-_globals['_DEPRECATEDMODEL']._serialized_end=2328
-_globals['_DEPRECATEDMODELMSG']._serialized_start=2330
-_globals['_DEPRECATEDMODELMSG']._serialized_end=2436
-_globals['_DEPSSCRUBBEDPACKAGENAME']._serialized_start=2438
-_globals['_DEPSSCRUBBEDPACKAGENAME']._serialized_end=2485
-_globals['_DEPSSCRUBBEDPACKAGENAMEMSG']._serialized_start=2487
-_globals['_DEPSSCRUBBEDPACKAGENAMEMSG']._serialized_end=2609
+_COREEVENTINFO._serialized_start=67
+_COREEVENTINFO._serialized_end=348
+_COREEVENTINFO_EXTRAENTRY._serialized_start=304
+_COREEVENTINFO_EXTRAENTRY._serialized_end=348
+_PACKAGEREDIRECTDEPRECATION._serialized_start=350
+_PACKAGEREDIRECTDEPRECATION._serialized_end=414
+_PACKAGEREDIRECTDEPRECATIONMSG._serialized_start=417
+_PACKAGEREDIRECTDEPRECATIONMSG._serialized_end=545
+_PACKAGEINSTALLPATHDEPRECATION._serialized_start=547
+_PACKAGEINSTALLPATHDEPRECATION._serialized_end=578
+_PACKAGEINSTALLPATHDEPRECATIONMSG._serialized_start=581
+_PACKAGEINSTALLPATHDEPRECATIONMSG._serialized_end=715
+_CONFIGSOURCEPATHDEPRECATION._serialized_start=717
+_CONFIGSOURCEPATHDEPRECATION._serialized_end=789
+_CONFIGSOURCEPATHDEPRECATIONMSG._serialized_start=792
+_CONFIGSOURCEPATHDEPRECATIONMSG._serialized_end=922
+_CONFIGDATAPATHDEPRECATION._serialized_start=924
+_CONFIGDATAPATHDEPRECATION._serialized_end=994
+_CONFIGDATAPATHDEPRECATIONMSG._serialized_start=996
+_CONFIGDATAPATHDEPRECATIONMSG._serialized_end=1122
+_METRICATTRIBUTESRENAMED._serialized_start=1124
+_METRICATTRIBUTESRENAMED._serialized_end=1170
+_METRICATTRIBUTESRENAMEDMSG._serialized_start=1172
+_METRICATTRIBUTESRENAMEDMSG._serialized_end=1294
+_EXPOSURENAMEDEPRECATION._serialized_start=1296
+_EXPOSURENAMEDEPRECATION._serialized_end=1339
+_EXPOSURENAMEDEPRECATIONMSG._serialized_start=1341
+_EXPOSURENAMEDEPRECATIONMSG._serialized_end=1463
+_INTERNALDEPRECATION._serialized_start=1465
+_INTERNALDEPRECATION._serialized_end=1559
+_INTERNALDEPRECATIONMSG._serialized_start=1561
+_INTERNALDEPRECATIONMSG._serialized_end=1675
+_ENVIRONMENTVARIABLERENAMED._serialized_start=1677
+_ENVIRONMENTVARIABLERENAMED._serialized_end=1741
+_ENVIRONMENTVARIABLERENAMEDMSG._serialized_start=1744
+_ENVIRONMENTVARIABLERENAMEDMSG._serialized_end=1872
+_CONFIGLOGPATHDEPRECATION._serialized_start=1874
+_CONFIGLOGPATHDEPRECATION._serialized_end=1925
+_CONFIGLOGPATHDEPRECATIONMSG._serialized_start=1927
+_CONFIGLOGPATHDEPRECATIONMSG._serialized_end=2051
+_CONFIGTARGETPATHDEPRECATION._serialized_start=2053
+_CONFIGTARGETPATHDEPRECATION._serialized_end=2107
+_CONFIGTARGETPATHDEPRECATIONMSG._serialized_start=2110
+_CONFIGTARGETPATHDEPRECATIONMSG._serialized_end=2240
+_TESTSCONFIGDEPRECATION._serialized_start=2242
+_TESTSCONFIGDEPRECATION._serialized_end=2309
+_TESTSCONFIGDEPRECATIONMSG._serialized_start=2311
+_TESTSCONFIGDEPRECATIONMSG._serialized_end=2431
+_DEPRECATEDMODEL._serialized_start=2433
+_DEPRECATEDMODEL._serialized_end=2519
+_DEPRECATEDMODELMSG._serialized_start=2521
+_DEPRECATEDMODELMSG._serialized_end=2627
+_DEPSSCRUBBEDPACKAGENAME._serialized_start=2629
+_DEPSSCRUBBEDPACKAGENAME._serialized_end=2676
+_DEPSSCRUBBEDPACKAGENAMEMSG._serialized_start=2678
+_DEPSSCRUBBEDPACKAGENAMEMSG._serialized_end=2800
# @@protoc_insertion_point(module_scope)

View File

@@ -164,6 +164,18 @@ class ConfigTargetPathDeprecation(WarnLevel):
return line_wrap_message(warning_tag(f"Deprecated functionality\n\n{description}"))
class TestsConfigDeprecation(WarnLevel):
def code(self) -> str:
return "D013"
def message(self) -> str:
description = (
f"The `{self.deprecated_path}` config has been renamed to `{self.exp_path}`. "
"Please update your `dbt_project.yml` configuration to reflect this change."
)
return line_wrap_message(warning_tag(f"Deprecated functionality\n\n{description}"))
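Rendered with plain formatting (dbt additionally wraps it via warning_tag and line_wrap_message), the D013 message body comes out as follows:

deprecated_path, exp_path = "tests", "data_tests"
description = (
    f"The `{deprecated_path}` config has been renamed to `{exp_path}`. "
    "Please update your `dbt_project.yml` configuration to reflect this change."
)
print(f"Deprecated functionality\n\n{description}")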
# =======================================================
# M - Deps generation
# =======================================================

View File

@@ -852,6 +852,12 @@ class InvalidAccessTypeError(ParsingError):
super().__init__(msg=msg)
class InvalidUnitTestGivenInput(ParsingError):
def __init__(self, input: str) -> None:
msg = f"Unit test given inputs must be either a 'ref', 'source' or 'this' call. Got: '{input}'."
super().__init__(msg=msg)
class SameKeyNestedError(CompilationError):
def __init__(self) -> None:
msg = "Test cannot have the same key at the top-level and in config"
@@ -1021,8 +1027,9 @@ class TargetNotFoundError(CompilationError):
return msg
-class DuplicateSourcePatchNameError(CompilationError):
-def __init__(self, patch_1, patch_2):
+class DuplicatePatchNameError(CompilationError):
+def __init__(self, node_type, patch_1, patch_2):
+self.node_type = node_type
self.patch_1 = patch_1
self.patch_2 = patch_2
super().__init__(msg=self.get_message())
@@ -1033,11 +1040,11 @@ class DuplicateSourcePatchNameError(CompilationError):
self.patch_1.path,
self.patch_2.path,
name,
-"sources",
+self.node_type.pluralize(),
)
msg = (
-f"dbt found two schema.yml entries for the same source named "
-f"{self.patch_1.name} in package {self.patch_1.overrides}. Sources may only be "
+f"dbt found two schema.yml entries for the same {self.node_type} named "
+f"{self.patch_1.name} in package {self.patch_1.overrides}. {self.node_type.pluralize()} may only be "
f"overridden a single time. To fix this, {fix}"
)
return msg

View File

@@ -14,7 +14,7 @@ class Graph:
""" """
def __init__(self, graph) -> None: def __init__(self, graph) -> None:
self.graph = graph self.graph: nx.DiGraph = graph
def nodes(self) -> Set[UniqueId]:
return set(self.graph.nodes())
@@ -83,10 +83,10 @@ class Graph:
removed nodes are preserved as explicit new edges.
"""
-new_graph = self.graph.copy()
-include_nodes = set(selected)
-still_removing = True
+new_graph: nx.DiGraph = self.graph.copy()
+include_nodes: Set[UniqueId] = set(selected)
+still_removing: bool = True
while still_removing:
nodes_to_remove = list(
node
@@ -129,6 +129,8 @@ class Graph:
return Graph(new_graph)
def subgraph(self, nodes: Iterable[UniqueId]) -> "Graph":
# Take the original networkx graph and return a subgraph containing only
# the selected unique_id nodes.
return Graph(self.graph.subgraph(nodes))
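A minimal networkx illustration of the subgraph call above (toy ids, not real unique_ids):

import networkx as nx

# Toy lineage: model_a -> model_b -> test_c
g = nx.DiGraph()
g.add_edges_from([("model_a", "model_b"), ("model_b", "test_c")])

# subgraph() keeps only the selected nodes and the edges between them.
sub = g.subgraph({"model_a", "model_b"})
assert set(sub.nodes()) == {"model_a", "model_b"}
assert list(sub.edges()) == [("model_a", "model_b")]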
def get_dependent_nodes(self, node: UniqueId):

View File

@@ -31,6 +31,8 @@ def can_select_indirectly(node):
""" """
if node.resource_type == NodeType.Test: if node.resource_type == NodeType.Test:
return True return True
elif node.resource_type == NodeType.Unit:
return True
else:
return False
@@ -46,8 +48,8 @@ class NodeSelector(MethodManager):
include_empty_nodes: bool = False,
) -> None:
super().__init__(manifest, previous_state)
-self.full_graph = graph
-self.include_empty_nodes = include_empty_nodes
+self.full_graph: Graph = graph
+self.include_empty_nodes: bool = include_empty_nodes
# build a subgraph containing only non-empty, enabled nodes and enabled
# sources.
@@ -171,9 +173,12 @@ class NodeSelector(MethodManager):
elif unique_id in self.manifest.semantic_models:
semantic_model = self.manifest.semantic_models[unique_id]
return semantic_model.config.enabled
elif unique_id in self.manifest.unit_tests:
return True
elif unique_id in self.manifest.saved_queries:
saved_query = self.manifest.saved_queries[unique_id]
return saved_query.config.enabled
node = self.manifest.nodes[unique_id]
if self.include_empty_nodes:
@@ -199,6 +204,8 @@ class NodeSelector(MethodManager):
node = self.manifest.metrics[unique_id]
elif unique_id in self.manifest.semantic_models:
node = self.manifest.semantic_models[unique_id]
elif unique_id in self.manifest.unit_tests:
node = self.manifest.unit_tests[unique_id]
elif unique_id in self.manifest.saved_queries:
node = self.manifest.saved_queries[unique_id]
else:
@@ -246,8 +253,13 @@ class NodeSelector(MethodManager):
)
for unique_id in self.graph.select_successors(selected):
-if unique_id in self.manifest.nodes:
-node = self.manifest.nodes[unique_id]
+if unique_id in self.manifest.nodes or unique_id in self.manifest.unit_tests:
+if unique_id in self.manifest.nodes:
+node = self.manifest.nodes[unique_id]
+elif unique_id in self.manifest.unit_tests:
+node = self.manifest.unit_tests[unique_id] # type: ignore
+# Test nodes that are not selected themselves, but whose parents are selected.
+# (Does not include unit tests because they can only have one parent.)
if can_select_indirectly(node):
# should we add it in directly?
if indirect_selection == IndirectSelection.Eager or set(
@@ -315,8 +327,11 @@ class NodeSelector(MethodManager):
"""Returns a queue over nodes in the graph that tracks progress of """Returns a queue over nodes in the graph that tracks progress of
dependecies. dependecies.
""" """
# Filtering hapens in get_selected
selected_nodes = self.get_selected(spec) selected_nodes = self.get_selected(spec)
# Save to global variable
selected_resources.set_selected_resources(selected_nodes) selected_resources.set_selected_resources(selected_nodes)
# Construct a new graph using the selected_nodes
new_graph = self.full_graph.get_subset_graph(selected_nodes) new_graph = self.full_graph.get_subset_graph(selected_nodes)
# should we give a way here for consumers to mutate the graph? # should we give a way here for consumers to mutate the graph?
return GraphQueue(new_graph.graph, self.manifest, selected_nodes) return GraphQueue(new_graph.graph, self.manifest, selected_nodes)

View File

@@ -18,6 +18,7 @@ from dbt.contracts.graph.nodes import (
ResultNode,
ManifestNode,
ModelNode,
UnitTestDefinition,
SavedQuery,
SemanticModel,
)
@@ -101,7 +102,9 @@ def is_selected_node(fqn: List[str], node_selector: str, is_versioned: bool) ->
return True
-SelectorTarget = Union[SourceDefinition, ManifestNode, Exposure, Metric]
+SelectorTarget = Union[
+SourceDefinition, ManifestNode, Exposure, Metric, SemanticModel, UnitTestDefinition
+]
class SelectorMethod(metaclass=abc.ABCMeta):
@@ -148,6 +151,21 @@ class SelectorMethod(metaclass=abc.ABCMeta):
continue
yield unique_id, metric
def unit_tests(
self, included_nodes: Set[UniqueId]
) -> Iterator[Tuple[UniqueId, UnitTestDefinition]]:
for unique_id, unit_test in self.manifest.unit_tests.items():
unique_id = UniqueId(unique_id)
if unique_id not in included_nodes:
continue
yield unique_id, unit_test
def parsed_and_unit_nodes(self, included_nodes: Set[UniqueId]):
yield from chain(
self.parsed_nodes(included_nodes),
self.unit_tests(included_nodes),
)
def semantic_model_nodes(
self, included_nodes: Set[UniqueId]
) -> Iterator[Tuple[UniqueId, SemanticModel]]:
@@ -176,6 +194,7 @@ class SelectorMethod(metaclass=abc.ABCMeta):
self.source_nodes(included_nodes),
self.exposure_nodes(included_nodes),
self.metric_nodes(included_nodes),
self.unit_tests(included_nodes),
self.semantic_model_nodes(included_nodes),
)
@@ -192,6 +211,7 @@ class SelectorMethod(metaclass=abc.ABCMeta):
self.parsed_nodes(included_nodes),
self.exposure_nodes(included_nodes),
self.metric_nodes(included_nodes),
self.unit_tests(included_nodes),
self.semantic_model_nodes(included_nodes),
self.saved_query_nodes(included_nodes),
)
@@ -519,30 +539,37 @@ class TestNameSelectorMethod(SelectorMethod):
__test__ = False
def search(self, included_nodes: Set[UniqueId], selector: str) -> Iterator[UniqueId]:
-for node, real_node in self.parsed_nodes(included_nodes):
-if real_node.resource_type == NodeType.Test and hasattr(real_node, "test_metadata"):
-if fnmatch(real_node.test_metadata.name, selector): # type: ignore[union-attr]
-yield node
+for unique_id, node in self.parsed_and_unit_nodes(included_nodes):
+if node.resource_type == NodeType.Test and hasattr(node, "test_metadata"):
+if fnmatch(node.test_metadata.name, selector): # type: ignore[union-attr]
+yield unique_id
+elif node.resource_type == NodeType.Unit:
+if fnmatch(node.name, selector):
+yield unique_id
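The name matching relies on fnmatch, so shell-style wildcards work against unit test names too; a quick standalone check (names invented):

from fnmatch import fnmatch

assert fnmatch("test_valid_email", "test_*")
assert not fnmatch("test_valid_email", "unit_*")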
class TestTypeSelectorMethod(SelectorMethod):
__test__ = False
def search(self, included_nodes: Set[UniqueId], selector: str) -> Iterator[UniqueId]:
-search_type: Type
+search_types: List[Any]
# continue supporting 'schema' + 'data' for backwards compatibility
if selector in ("generic", "schema"):
-search_type = GenericTestNode
-elif selector in ("singular", "data"):
-search_type = SingularTestNode
+search_types = [GenericTestNode]
+elif selector in ("data"):
+search_types = [GenericTestNode, SingularTestNode]
+elif selector in ("singular"):
+search_types = [SingularTestNode]
+elif selector in ("unit"):
+search_types = [UnitTestDefinition]
else:
raise DbtRuntimeError(
-f'Invalid test type selector {selector}: expected "generic" or ' '"singular"'
+f'Invalid test type selector {selector}: expected "generic", "singular", "unit", or "data"'
)
-for node, real_node in self.parsed_nodes(included_nodes):
-if isinstance(real_node, search_type):
-yield node
+for unique_id, node in self.parsed_and_unit_nodes(included_nodes):
+if isinstance(node, tuple(search_types)):
+yield unique_id
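The filtering idiom here is isinstance against a tuple of node classes; a toy standalone equivalent (stand-in classes, not dbt's):

class GenericTest: ...
class SingularTest: ...
class UnitTest: ...

search_types = [GenericTest, SingularTest]  # e.g. selector == "data"
nodes = [GenericTest(), UnitTest(), SingularTest()]
assert len([n for n in nodes if isinstance(n, tuple(search_types))]) == 2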
class StateSelectorMethod(SelectorMethod):
@@ -618,7 +645,9 @@ class StateSelectorMethod(SelectorMethod):
def check_modified_content(
self, old: Optional[SelectorTarget], new: SelectorTarget, adapter_type: str
) -> bool:
-if isinstance(new, (SourceDefinition, Exposure, Metric, SemanticModel)):
+if isinstance(
+new, (SourceDefinition, Exposure, Metric, SemanticModel, UnitTestDefinition)
+):
# these all overwrite `same_contents`
different_contents = not new.same_contents(old) # type: ignore
else:
@@ -698,17 +727,21 @@ class StateSelectorMethod(SelectorMethod):
manifest: WritableManifest = self.previous_state.manifest
-for node, real_node in self.all_nodes(included_nodes):
+for unique_id, node in self.all_nodes(included_nodes):
previous_node: Optional[SelectorTarget] = None
-if node in manifest.nodes:
-previous_node = manifest.nodes[node]
-elif node in manifest.sources:
-previous_node = manifest.sources[node]
-elif node in manifest.exposures:
-previous_node = manifest.exposures[node]
-elif node in manifest.metrics:
-previous_node = manifest.metrics[node]
+if unique_id in manifest.nodes:
+previous_node = manifest.nodes[unique_id]
+elif unique_id in manifest.sources:
+previous_node = manifest.sources[unique_id]
+elif unique_id in manifest.exposures:
+previous_node = manifest.exposures[unique_id]
+elif unique_id in manifest.metrics:
+previous_node = manifest.metrics[unique_id]
+elif unique_id in manifest.semantic_models:
+previous_node = manifest.semantic_models[unique_id]
+elif unique_id in manifest.unit_tests:
+previous_node = manifest.unit_tests[unique_id]
keyword_args = {}
if checker.__name__ in [
@@ -718,8 +751,8 @@ class StateSelectorMethod(SelectorMethod):
]:
keyword_args["adapter_type"] = adapter_type # type: ignore
-if checker(previous_node, real_node, **keyword_args): # type: ignore
-yield node
+if checker(previous_node, node, **keyword_args): # type: ignore
+yield unique_id
class ResultSelectorMethod(SelectorMethod):

View File

@@ -100,6 +100,7 @@ class SelectionCriteria:
except ValueError as exc:
raise InvalidSelectorError(f"'{method_parts[0]}' is not a valid method name") from exc
# Following is for cases like config.severity and config.materialized
method_arguments: List[str] = method_parts[1:]
return method_name, method_arguments

View File

@@ -7,7 +7,7 @@ models:
columns:
- name: id
description: "The primary key for this table"
-tests:
+data_tests:
- unique
- not_null
@@ -16,6 +16,6 @@ models:
columns:
- name: id
description: "The primary key for this table"
-tests:
+data_tests:
- unique
- not_null

View File

@@ -35,6 +35,8 @@ class NodeType(StrEnum):
Group = "group" Group = "group"
SavedQuery = "saved_query" SavedQuery = "saved_query"
SemanticModel = "semantic_model" SemanticModel = "semantic_model"
Unit = "unit_test"
Fixture = "fixture"
@classmethod
def executable(cls) -> List["NodeType"]:

View File

@@ -74,6 +74,10 @@ class TargetBlock(YamlBlock, Generic[Target]):
def columns(self):
return []
@property
def data_tests(self) -> List[TestDef]:
return []
@property
def tests(self) -> List[TestDef]:
return []
@@ -100,11 +104,11 @@ class TargetColumnsBlock(TargetBlock[ColumnTarget], Generic[ColumnTarget]):
@dataclass
class TestBlock(TargetColumnsBlock[Testable], Generic[Testable]):
@property
-def tests(self) -> List[TestDef]:
-if self.target.tests is None:
+def data_tests(self) -> List[TestDef]:
+if self.target.data_tests is None:
return []
else:
-return self.target.tests
+return self.target.data_tests
@property
def quote_columns(self) -> Optional[bool]:
@@ -129,11 +133,11 @@ class VersionedTestBlock(TestBlock, Generic[Versioned]):
raise DbtInternalError(".columns for VersionedTestBlock with versions") raise DbtInternalError(".columns for VersionedTestBlock with versions")
@property
-def tests(self) -> List[TestDef]:
+def data_tests(self) -> List[TestDef]:
if not self.target.versions:
-return super().tests
+return super().data_tests
else:
-raise DbtInternalError(".tests for VersionedTestBlock with versions")
+raise DbtInternalError(".data_tests for VersionedTestBlock with versions")
@classmethod
def from_yaml_block(cls, src: YamlBlock, target: Versioned) -> "VersionedTestBlock[Versioned]":
@@ -146,7 +150,7 @@ class VersionedTestBlock(TestBlock, Generic[Versioned]):
@dataclass
class GenericTestBlock(TestBlock[Testable], Generic[Testable]):
-test: Dict[str, Any]
+data_test: Dict[str, Any]
column_name: Optional[str]
tags: List[str]
version: Optional[NodeVersion]
@@ -155,7 +159,7 @@ class GenericTestBlock(TestBlock[Testable], Generic[Testable]):
def from_test_block(
cls,
src: TestBlock,
-test: Dict[str, Any],
+data_test: Dict[str, Any],
column_name: Optional[str],
tags: List[str],
version: Optional[NodeVersion],
@@ -164,7 +168,7 @@ class GenericTestBlock(TestBlock[Testable], Generic[Testable]):
file=src.file,
data=src.data,
target=src.target,
-test=test,
+data_test=data_test,
column_name=column_name,
tags=tags,
version=version,

View File

@@ -0,0 +1,46 @@
from typing import Optional, Dict, List, Any
from io import StringIO
import csv
from dbt.contracts.files import FixtureSourceFile
from dbt.contracts.graph.nodes import UnitTestFileFixture
from dbt.node_types import NodeType
from dbt.parser.base import Parser
from dbt.parser.search import FileBlock
class FixtureParser(Parser[UnitTestFileFixture]):
@property
def resource_type(self) -> NodeType:
return NodeType.Fixture
@classmethod
def get_compiled_path(cls, block: FileBlock):
# Is this necessary?
return block.path.relative_path
def generate_unique_id(self, resource_name: str, _: Optional[str] = None) -> str:
return f"fixture.{self.project.project_name}.{resource_name}"
def parse_file(self, file_block: FileBlock):
assert isinstance(file_block.file, FixtureSourceFile)
unique_id = self.generate_unique_id(file_block.name)
fixture = UnitTestFileFixture(
name=file_block.name,
path=file_block.file.path.relative_path,
original_file_path=file_block.path.original_file_path,
package_name=self.project.project_name,
unique_id=unique_id,
resource_type=NodeType.Fixture,
rows=self.get_rows(file_block.file.contents),
)
self.manifest.add_fixture(file_block.file, fixture)
def get_rows(self, contents) -> List[Dict[str, Any]]:
rows = []
dummy_file = StringIO(contents)
reader = csv.DictReader(dummy_file)
for row in reader:
rows.append(row)
return rows
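For orientation, a minimal standalone sketch of what this parser produces for a small fixture file; the project and fixture names are hypothetical, and only the stdlib pieces used above are involved. generate_unique_id yields an id like "fixture.my_project.users_fixture", and get_rows hands the raw contents to csv.DictReader:

    import csv
    from io import StringIO

    contents = "id,name\n1,gerda\n2,emily\n"
    rows = list(csv.DictReader(StringIO(contents)))
    # rows == [{"id": "1", "name": "gerda"}, {"id": "2", "name": "emily"}]
    # DictReader yields every value as a string; any type coercion happens
    # later, when the rows are rendered into fixture SQL.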


@@ -43,7 +43,7 @@ class GenericTestParser(BaseParser[GenericTestNode]):
t t
for t in jinja.extract_toplevel_blocks( for t in jinja.extract_toplevel_blocks(
base_node.raw_code, base_node.raw_code,
allowed_blocks={"test"}, allowed_blocks={"test", "data_test"},
collect_raw_data=False, collect_raw_data=False,
) )
if isinstance(t, jinja.BlockTag) if isinstance(t, jinja.BlockTag)


@@ -110,14 +110,14 @@ class TestBuilder(Generic[Testable]):
def __init__( def __init__(
self, self,
test: Dict[str, Any], data_test: Dict[str, Any],
target: Testable, target: Testable,
package_name: str, package_name: str,
render_ctx: Dict[str, Any], render_ctx: Dict[str, Any],
column_name: Optional[str] = None, column_name: Optional[str] = None,
version: Optional[NodeVersion] = None, version: Optional[NodeVersion] = None,
) -> None: ) -> None:
test_name, test_args = self.extract_test_args(test, column_name) test_name, test_args = self.extract_test_args(data_test, column_name)
self.args: Dict[str, Any] = test_args self.args: Dict[str, Any] = test_args
if "model" in self.args: if "model" in self.args:
raise TestArgIncludesModelError() raise TestArgIncludesModelError()
@@ -154,6 +154,7 @@ class TestBuilder(Generic[Testable]):
try: try:
value = get_rendered(value, render_ctx, native=True) value = get_rendered(value, render_ctx, native=True)
except UndefinedMacroError as e: except UndefinedMacroError as e:
raise CustomMacroPopulatingConfigValueError( raise CustomMacroPopulatingConfigValueError(
target_name=self.target.name, target_name=self.target.name,
column_name=column_name, column_name=column_name,
@@ -195,24 +196,24 @@ class TestBuilder(Generic[Testable]):
return TypeError('invalid target type "{}"'.format(type(self.target))) return TypeError('invalid target type "{}"'.format(type(self.target)))
@staticmethod @staticmethod
def extract_test_args(test, name=None) -> Tuple[str, Dict[str, Any]]: def extract_test_args(data_test, name=None) -> Tuple[str, Dict[str, Any]]:
if not isinstance(test, dict): if not isinstance(data_test, dict):
raise TestTypeError(test) raise TestTypeError(data_test)
# If the test is a dictionary with top-level keys, the test name is "test_name" # If the test is a dictionary with top-level keys, the test name is "test_name"
# and the rest are arguments # and the rest are arguments
# {'name': 'my_favorite_test', 'test_name': 'unique', 'config': {'where': '1=1'}} # {'name': 'my_favorite_test', 'test_name': 'unique', 'config': {'where': '1=1'}}
if "test_name" in test.keys(): if "test_name" in data_test.keys():
test_name = test.pop("test_name") test_name = data_test.pop("test_name")
test_args = test test_args = data_test
# If the test is a nested dictionary with one top-level key, the test name # If the test is a nested dictionary with one top-level key, the test name
# is the dict name, and nested keys are arguments # is the dict name, and nested keys are arguments
# {'unique': {'name': 'my_favorite_test', 'config': {'where': '1=1'}}} # {'unique': {'name': 'my_favorite_test', 'config': {'where': '1=1'}}}
else: else:
test = list(test.items()) data_test = list(data_test.items())
if len(test) != 1: if len(data_test) != 1:
raise TestDefinitionDictLengthError(test) raise TestDefinitionDictLengthError(data_test)
test_name, test_args = test[0] test_name, test_args = data_test[0]
if not isinstance(test_args, dict): if not isinstance(test_args, dict):
raise TestArgsNotDictError(test_args) raise TestArgsNotDictError(test_args)
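To make the two accepted shapes concrete, here is a sketch of extract_test_args on the example dicts from the comments above. The outputs follow the branches shown in this hunk; lines outside it may adjust the args further (e.g. column_name handling), so treat the results as approximate:

    # Flat form: the generic test name is under "test_name"; the rest are args.
    TestBuilder.extract_test_args(
        {"name": "my_favorite_test", "test_name": "unique", "config": {"where": "1=1"}}
    )
    # ~ ("unique", {"name": "my_favorite_test", "config": {"where": "1=1"}})

    # Nested form: the single top-level key is the test name.
    TestBuilder.extract_test_args(
        {"unique": {"name": "my_favorite_test", "config": {"where": "1=1"}}}
    )
    # ~ ("unique", {"name": "my_favorite_test", "config": {"where": "1=1"}})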


@@ -49,7 +49,7 @@ class MacroParser(BaseParser[Macro]):
t t
for t in jinja.extract_toplevel_blocks( for t in jinja.extract_toplevel_blocks(
base_node.raw_code, base_node.raw_code,
allowed_blocks={"macro", "materialization", "test"}, allowed_blocks={"macro", "materialization", "test", "data_test"},
collect_raw_data=False, collect_raw_data=False,
) )
if isinstance(t, jinja.BlockTag) if isinstance(t, jinja.BlockTag)


@@ -104,6 +104,7 @@ from dbt.contracts.graph.nodes import (
ResultNode, ResultNode,
ModelNode, ModelNode,
NodeRelation, NodeRelation,
UnitTestDefinition,
) )
from dbt.contracts.graph.unparsed import NodeVersion from dbt.contracts.graph.unparsed import NodeVersion
from dbt.contracts.util import Writable from dbt.contracts.util import Writable
@@ -117,6 +118,7 @@ from dbt.parser.analysis import AnalysisParser
from dbt.parser.generic_test import GenericTestParser from dbt.parser.generic_test import GenericTestParser
from dbt.parser.singular_test import SingularTestParser from dbt.parser.singular_test import SingularTestParser
from dbt.parser.docs import DocumentationParser from dbt.parser.docs import DocumentationParser
from dbt.parser.fixtures import FixtureParser
from dbt.parser.hooks import HookParser from dbt.parser.hooks import HookParser
from dbt.parser.macros import MacroParser from dbt.parser.macros import MacroParser
from dbt.parser.models import ModelParser from dbt.parser.models import ModelParser
@@ -125,6 +127,7 @@ from dbt.parser.search import FileBlock
from dbt.parser.seeds import SeedParser from dbt.parser.seeds import SeedParser
from dbt.parser.snapshots import SnapshotParser from dbt.parser.snapshots import SnapshotParser
from dbt.parser.sources import SourcePatcher from dbt.parser.sources import SourcePatcher
from dbt.parser.unit_tests import UnitTestPatcher
from dbt.version import __version__ from dbt.version import __version__
from dbt.common.dataclass_schema import StrEnum, dbtClassMixin from dbt.common.dataclass_schema import StrEnum, dbtClassMixin
@@ -219,6 +222,7 @@ class ManifestLoaderInfo(dbtClassMixin, Writable):
load_macros_elapsed: Optional[float] = None load_macros_elapsed: Optional[float] = None
parse_project_elapsed: Optional[float] = None parse_project_elapsed: Optional[float] = None
patch_sources_elapsed: Optional[float] = None patch_sources_elapsed: Optional[float] = None
patch_unit_tests_elapsed: Optional[float] = None
process_manifest_elapsed: Optional[float] = None process_manifest_elapsed: Optional[float] = None
load_all_elapsed: Optional[float] = None load_all_elapsed: Optional[float] = None
projects: List[ProjectLoaderInfo] = field(default_factory=list) projects: List[ProjectLoaderInfo] = field(default_factory=list)
@@ -474,6 +478,7 @@ class ManifestLoader:
SeedParser, SeedParser,
DocumentationParser, DocumentationParser,
HookParser, HookParser,
FixtureParser,
] ]
for project in self.all_projects.values(): for project in self.all_projects.values():
if project.project_name not in project_parser_files: if project.project_name not in project_parser_files:
@@ -512,6 +517,16 @@ class ManifestLoader:
self.manifest.sources = patcher.sources self.manifest.sources = patcher.sources
self._perf_info.patch_sources_elapsed = time.perf_counter() - start_patch self._perf_info.patch_sources_elapsed = time.perf_counter() - start_patch
# patch_unit_tests converts the UnparsedUnitTestDefinitions in the
# manifest.unit_tests to UnitTestDefinitions via 'patch_unit_test'
# in UnitTestPatcher
# TODO: is this needed?
start_patch = time.perf_counter()
unit_test_patcher = UnitTestPatcher(self.root_project, self.manifest)
unit_test_patcher.construct_unit_tests()
self.manifest.unit_tests = unit_test_patcher.unit_tests
self._perf_info.patch_unit_tests_elapsed = time.perf_counter() - start_patch
# We need to rebuild disabled in order to include disabled sources # We need to rebuild disabled in order to include disabled sources
self.manifest.rebuild_disabled_lookup() self.manifest.rebuild_disabled_lookup()
@@ -529,6 +544,8 @@ class ManifestLoader:
# determine whether they need processing. # determine whether they need processing.
start_process = time.perf_counter() start_process = time.perf_counter()
self.process_sources(self.root_project.project_name) self.process_sources(self.root_project.project_name)
# TODO: does this need to be done? I think it's done when we loop through versions
self.process_unit_tests(self.root_project.project_name)
self.process_refs(self.root_project.project_name, self.root_project.dependencies) self.process_refs(self.root_project.project_name, self.root_project.dependencies)
self.process_docs(self.root_project) self.process_docs(self.root_project)
self.process_metrics(self.root_project) self.process_metrics(self.root_project)
@@ -665,7 +682,7 @@ class ManifestLoader:
for file_id in parser_files[parser_name]: for file_id in parser_files[parser_name]:
block = FileBlock(self.manifest.files[file_id]) block = FileBlock(self.manifest.files[file_id])
if isinstance(parser, SchemaParser): if isinstance(parser, SchemaParser):
assert isinstance(block.file, SchemaSourceFile) assert isinstance(block.file, (SchemaSourceFile))
if self.partially_parsing: if self.partially_parsing:
dct = block.file.pp_dict dct = block.file.pp_dict
else: else:
@@ -1064,6 +1081,7 @@ class ManifestLoader:
"load_macros_elapsed": self._perf_info.load_macros_elapsed, "load_macros_elapsed": self._perf_info.load_macros_elapsed,
"parse_project_elapsed": self._perf_info.parse_project_elapsed, "parse_project_elapsed": self._perf_info.parse_project_elapsed,
"patch_sources_elapsed": self._perf_info.patch_sources_elapsed, "patch_sources_elapsed": self._perf_info.patch_sources_elapsed,
"patch_unit_tests_elapsed": self._perf_info.patch_unit_tests_elapsed,
"process_manifest_elapsed": (self._perf_info.process_manifest_elapsed), "process_manifest_elapsed": (self._perf_info.process_manifest_elapsed),
"load_all_elapsed": self._perf_info.load_all_elapsed, "load_all_elapsed": self._perf_info.load_all_elapsed,
"is_partial_parse_enabled": (self._perf_info.is_partial_parse_enabled), "is_partial_parse_enabled": (self._perf_info.is_partial_parse_enabled),
@@ -1224,6 +1242,18 @@ class ManifestLoader:
continue continue
_process_sources_for_exposure(self.manifest, current_project, exposure) _process_sources_for_exposure(self.manifest, current_project, exposure)
# Loops through all nodes; for each model node's 'unit_tests' list,
# finds the unit test definition and updates the node's
# 'depends_on.nodes' array with its unique id
def process_unit_tests(self, current_project: str):
for node in self.manifest.nodes.values():
if node.resource_type == NodeType.Unit:
continue
assert not isinstance(node, UnitTestDefinition)
if node.created_at < self.started_at:
continue
_process_unit_tests_for_node(self.manifest, current_project, node)
def cleanup_disabled(self): def cleanup_disabled(self):
# make sure the nodes are in the manifest.nodes or the disabled dict, # make sure the nodes are in the manifest.nodes or the disabled dict,
# correctly now that the schema files are also parsed # correctly now that the schema files are also parsed
@@ -1752,7 +1782,7 @@ def _process_sources_for_node(manifest: Manifest, current_project: str, node: Ma
) )
if target_source is None or isinstance(target_source, Disabled): if target_source is None or isinstance(target_source, Disabled):
# this folows the same pattern as refs # this follows the same pattern as refs
node.config.enabled = False node.config.enabled = False
invalid_target_fail_unless_test( invalid_target_fail_unless_test(
node=node, node=node,
@@ -1765,6 +1795,30 @@ def _process_sources_for_node(manifest: Manifest, current_project: str, node: Ma
node.depends_on.add_node(target_source_id) node.depends_on.add_node(target_source_id)
def _process_unit_tests_for_node(manifest: Manifest, current_project: str, node: ManifestNode):
if not isinstance(node, ModelNode):
return
target_unit_test: Optional[UnitTestDefinition] = None
for unit_test_name in node.unit_tests:
# TODO: loop through tests and build all the versioned nodes...
target_unit_test = manifest.resolve_unit_tests(
unit_test_name,
current_project,
node.package_name,
)
if target_unit_test is None:
# this follows the same pattern as refs
node.config.enabled = False
continue
# TODO: the below will change depending on whether versions are involved or not.
# target_unit_test_id = target_unit_test.unique_id
node.depends_on.add_node(target_unit_test.unique_id)
# This is called in task.rpc.sql_commands when a "dynamic" node is # This is called in task.rpc.sql_commands when a "dynamic" node is
# created in the manifest, in 'add_refs' # created in the manifest, in 'add_refs'
def process_macro(config: RuntimeConfig, manifest: Manifest, macro: Macro) -> None: def process_macro(config: RuntimeConfig, manifest: Manifest, macro: Macro) -> None:
@@ -1793,8 +1847,9 @@ def write_semantic_manifest(manifest: Manifest, target_path: str) -> None:
semantic_manifest.write_json_to_file(path) semantic_manifest.write_json_to_file(path)
-def write_manifest(manifest: Manifest, target_path: str):
-    path = os.path.join(target_path, MANIFEST_FILE_NAME)
+def write_manifest(manifest: Manifest, target_path: str, which: Optional[str] = None):
+    file_name = MANIFEST_FILE_NAME
+    path = os.path.join(target_path, file_name)
     manifest.write(path)
     write_semantic_manifest(manifest=manifest, target_path=target_path)


@@ -280,6 +280,10 @@ class PartialParsing:
if saved_source_file.parse_file_type == ParseFileType.Documentation: if saved_source_file.parse_file_type == ParseFileType.Documentation:
self.delete_doc_node(saved_source_file) self.delete_doc_node(saved_source_file)
# fixtures
if saved_source_file.parse_file_type == ParseFileType.Fixture:
self.delete_fixture_node(saved_source_file)
fire_event(PartialParsingFile(operation="deleted", file_id=file_id)) fire_event(PartialParsingFile(operation="deleted", file_id=file_id))
# Updates for non-schema files # Updates for non-schema files
@@ -293,6 +297,8 @@ class PartialParsing:
self.update_macro_in_saved(new_source_file, old_source_file) self.update_macro_in_saved(new_source_file, old_source_file)
elif new_source_file.parse_file_type == ParseFileType.Documentation: elif new_source_file.parse_file_type == ParseFileType.Documentation:
self.update_doc_in_saved(new_source_file, old_source_file) self.update_doc_in_saved(new_source_file, old_source_file)
elif new_source_file.parse_file_type == ParseFileType.Fixture:
self.update_fixture_in_saved(new_source_file, old_source_file)
else: else:
raise Exception(f"Invalid parse_file_type in source_file {file_id}") raise Exception(f"Invalid parse_file_type in source_file {file_id}")
fire_event(PartialParsingFile(operation="updated", file_id=file_id)) fire_event(PartialParsingFile(operation="updated", file_id=file_id))
@@ -377,6 +383,13 @@ class PartialParsing:
self.saved_files[new_source_file.file_id] = deepcopy(new_source_file) self.saved_files[new_source_file.file_id] = deepcopy(new_source_file)
self.add_to_pp_files(new_source_file) self.add_to_pp_files(new_source_file)
def update_fixture_in_saved(self, new_source_file, old_source_file):
if self.already_scheduled_for_parsing(old_source_file):
return
self.delete_fixture_node(old_source_file)
self.saved_files[new_source_file.file_id] = deepcopy(new_source_file)
self.add_to_pp_files(new_source_file)
def remove_mssat_file(self, source_file): def remove_mssat_file(self, source_file):
# nodes [unique_ids] -- SQL files # nodes [unique_ids] -- SQL files
# There should always be a node for a SQL file # There should always be a node for a SQL file
@@ -579,6 +592,20 @@ class PartialParsing:
# Remove the file object # Remove the file object
self.saved_manifest.files.pop(source_file.file_id) self.saved_manifest.files.pop(source_file.file_id)
def delete_fixture_node(self, source_file):
# remove fixtures from the "fixtures" dictionary
fixture_unique_id = source_file.fixture
self.saved_manifest.fixtures.pop(fixture_unique_id)
unit_tests = source_file.unit_tests.copy()
for unique_id in unit_tests:
unit_test = self.saved_manifest.unit_tests.pop(unique_id)
# schedule unit_test for parsing
self._schedule_for_parsing(
"unit_tests", unit_test, unit_test.name, self.delete_schema_unit_test
)
source_file.unit_tests.remove(unique_id)
self.saved_manifest.files.pop(source_file.file_id)
# Schema files ----------------------- # Schema files -----------------------
# Changed schema files # Changed schema files
def change_schema_file(self, file_id): def change_schema_file(self, file_id):
@@ -608,7 +635,7 @@ class PartialParsing:
self.saved_manifest.files.pop(file_id) self.saved_manifest.files.pop(file_id)
# For each key in a schema file dictionary, process the changed, deleted, and added # For each key in a schema file dictionary, process the changed, deleted, and added
# elemnts for the key lists # elements for the key lists
def handle_schema_file_changes(self, schema_file, saved_yaml_dict, new_yaml_dict): def handle_schema_file_changes(self, schema_file, saved_yaml_dict, new_yaml_dict):
# loop through comparing previous dict_from_yaml with current dict_from_yaml # loop through comparing previous dict_from_yaml with current dict_from_yaml
# Need to do the deleted/added/changed thing, just like the files lists # Need to do the deleted/added/changed thing, just like the files lists
@@ -681,6 +708,7 @@ class PartialParsing:
handle_change("metrics", self.delete_schema_metric) handle_change("metrics", self.delete_schema_metric)
handle_change("groups", self.delete_schema_group) handle_change("groups", self.delete_schema_group)
handle_change("semantic_models", self.delete_schema_semantic_model) handle_change("semantic_models", self.delete_schema_semantic_model)
handle_change("unit_tests", self.delete_schema_unit_test)
handle_change("saved_queries", self.delete_schema_saved_query) handle_change("saved_queries", self.delete_schema_saved_query)
def _handle_element_change( def _handle_element_change(
@@ -938,6 +966,17 @@ class PartialParsing:
elif unique_id in self.saved_manifest.disabled: elif unique_id in self.saved_manifest.disabled:
self.delete_disabled(unique_id, schema_file.file_id) self.delete_disabled(unique_id, schema_file.file_id)
def delete_schema_unit_test(self, schema_file, unit_test_dict):
unit_test_name = unit_test_dict["name"]
unit_tests = schema_file.unit_tests.copy()
for unique_id in unit_tests:
if unique_id in self.saved_manifest.unit_tests:
unit_test = self.saved_manifest.unit_tests[unique_id]
if unit_test.name == unit_test_name:
self.saved_manifest.unit_tests.pop(unique_id)
schema_file.unit_tests.remove(unique_id)
# No disabled unit tests yet
def get_schema_element(self, elem_list, elem_name): def get_schema_element(self, elem_list, elem_name):
for element in elem_list: for element in elem_list:
if "name" in element and element["name"] == elem_name: if "name" in element and element["name"] == elem_name:
@@ -1009,6 +1048,8 @@ class PartialParsing:
# Create a list of file_ids for source_files that need to be reparsed, and # Create a list of file_ids for source_files that need to be reparsed, and
# a dictionary of file_ids to yaml_keys to names. # a dictionary of file_ids to yaml_keys to names.
for source_file in self.saved_files.values(): for source_file in self.saved_files.values():
if source_file.parse_file_type == ParseFileType.Fixture:
continue
file_id = source_file.file_id file_id = source_file.file_id
if not source_file.env_vars: if not source_file.env_vars:
continue continue


@@ -10,6 +10,7 @@ from dbt.contracts.files import (
FileHash, FileHash,
AnySourceFile, AnySourceFile,
SchemaSourceFile, SchemaSourceFile,
FixtureSourceFile,
) )
from dbt.config import Project from dbt.config import Project
from dbt.common.dataclass_schema import dbtClassMixin from dbt.common.dataclass_schema import dbtClassMixin
@@ -46,7 +47,13 @@ def load_source_file(
saved_files, saved_files,
) -> Optional[AnySourceFile]: ) -> Optional[AnySourceFile]:
-    sf_cls = SchemaSourceFile if parse_file_type == ParseFileType.Schema else SourceFile
+    if parse_file_type == ParseFileType.Schema:
+        sf_cls = SchemaSourceFile
+    elif parse_file_type == ParseFileType.Fixture:
+        sf_cls = FixtureSourceFile  # type:ignore[assignment]
+    else:
+        sf_cls = SourceFile  # type:ignore[assignment]
source_file = sf_cls( source_file = sf_cls(
path=path, path=path,
checksum=FileHash.empty(), checksum=FileHash.empty(),
@@ -422,5 +429,10 @@ def get_file_types_for_project(project):
"extensions": [".yml", ".yaml"], "extensions": [".yml", ".yaml"],
"parser": "SchemaParser", "parser": "SchemaParser",
}, },
ParseFileType.Fixture: {
"paths": project.fixture_paths,
"extensions": [".csv"],
"parser": "FixtureParser",
},
} }
return file_types return file_types


@@ -72,11 +72,11 @@ class SchemaGenericTestParser(SimpleParser):
def parse_column_tests( def parse_column_tests(
self, block: TestBlock, column: UnparsedColumn, version: Optional[NodeVersion] self, block: TestBlock, column: UnparsedColumn, version: Optional[NodeVersion]
) -> None: ) -> None:
if not column.tests: if not column.data_tests:
return return
for test in column.tests: for data_test in column.data_tests:
self.parse_test(block, test, column, version) self.parse_test(block, data_test, column, version)
def create_test_node( def create_test_node(
self, self,
@@ -148,7 +148,7 @@ class SchemaGenericTestParser(SimpleParser):
def parse_generic_test( def parse_generic_test(
self, self,
target: Testable, target: Testable,
test: Dict[str, Any], data_test: Dict[str, Any],
tags: List[str], tags: List[str],
column_name: Optional[str], column_name: Optional[str],
schema_file_id: str, schema_file_id: str,
@@ -156,7 +156,7 @@ class SchemaGenericTestParser(SimpleParser):
) -> GenericTestNode: ) -> GenericTestNode:
try: try:
builder = TestBuilder( builder = TestBuilder(
test=test, data_test=data_test,
target=target, target=target,
column_name=column_name, column_name=column_name,
version=version, version=version,
@@ -321,7 +321,7 @@ class SchemaGenericTestParser(SimpleParser):
""" """
node = self.parse_generic_test( node = self.parse_generic_test(
target=block.target, target=block.target,
test=block.test, data_test=block.data_test,
tags=block.tags, tags=block.tags,
column_name=block.column_name, column_name=block.column_name,
schema_file_id=block.file.file_id, schema_file_id=block.file.file_id,
@@ -357,12 +357,12 @@ class SchemaGenericTestParser(SimpleParser):
def parse_test( def parse_test(
self, self,
target_block: TestBlock, target_block: TestBlock,
test: TestDef, data_test: TestDef,
column: Optional[UnparsedColumn], column: Optional[UnparsedColumn],
version: Optional[NodeVersion], version: Optional[NodeVersion],
) -> None: ) -> None:
if isinstance(test, str): if isinstance(data_test, str):
test = {test: {}} data_test = {data_test: {}}
if column is None: if column is None:
column_name: Optional[str] = None column_name: Optional[str] = None
@@ -376,7 +376,7 @@ class SchemaGenericTestParser(SimpleParser):
block = GenericTestBlock.from_test_block( block = GenericTestBlock.from_test_block(
src=target_block, src=target_block,
test=test, data_test=data_test,
column_name=column_name, column_name=column_name,
tags=column_tags, tags=column_tags,
version=version, version=version,
@@ -387,8 +387,8 @@ class SchemaGenericTestParser(SimpleParser):
for column in block.columns: for column in block.columns:
self.parse_column_tests(block, column, None) self.parse_column_tests(block, column, None)
for test in block.tests: for data_test in block.data_tests:
self.parse_test(block, test, None, None) self.parse_test(block, data_test, None, None)
def parse_versioned_tests(self, block: VersionedTestBlock) -> None: def parse_versioned_tests(self, block: VersionedTestBlock) -> None:
if not block.target.versions: if not block.target.versions:


@@ -25,21 +25,22 @@ class SchemaYamlRenderer(BaseRenderer):
models: models:
- name: blah - name: blah
description: blah description: blah
tests: ... data_tests: ...
columns: columns:
- name: - name:
description: blah description: blah
tests: ... data_tests: ...
Return True if it's tests or description - those aren't rendered now Return True if it's tests, data_tests or description - those aren't rendered now
because they're rendered later in parse_generic_tests or process_docs. because they're rendered later in parse_generic_tests or process_docs.
"tests" and "data_tests" are both currently supported but "tests" has been deprecated
""" """
# top level descriptions and tests # top level descriptions and data_tests
if len(keypath) >= 1 and keypath[0] in ("tests", "description"): if len(keypath) >= 1 and keypath[0] in ("tests", "data_tests", "description"):
return True return True
# columns descriptions and tests # columns descriptions and data_tests
if len(keypath) == 2 and keypath[1] in ("tests", "description"): if len(keypath) == 2 and keypath[1] in ("tests", "data_tests", "description"):
return True return True
# versions # versions
@@ -49,7 +50,7 @@ class SchemaYamlRenderer(BaseRenderer):
if ( if (
len(keypath) >= 3 len(keypath) >= 3
and keypath[0] in ("columns", "dimensions", "measures", "entities") and keypath[0] in ("columns", "dimensions", "measures", "entities")
and keypath[2] in ("tests", "description") and keypath[2] in ("tests", "data_tests", "description")
): ):
return True return True
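For intuition, a few keypaths that should_render_keypath now skips, based on the branches above (integers are list indices in the parsed YAML; the exact shapes are illustrative):

    ("tests",)                    # top-level tests on the entry
    ("data_tests",)               # top-level data_tests
    ("description",)              # top-level description
    ("columns", 0, "data_tests")  # data_tests under the first column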


@@ -5,6 +5,7 @@ from abc import ABCMeta, abstractmethod
from typing import Any, Callable, Dict, Generic, Iterable, List, Optional, Type, TypeVar from typing import Any, Callable, Dict, Generic, Iterable, List, Optional, Type, TypeVar
from dataclasses import dataclass, field from dataclasses import dataclass, field
from dbt import deprecations
from dbt.common.contracts.constraints import ConstraintType, ModelLevelConstraint from dbt.common.contracts.constraints import ConstraintType, ModelLevelConstraint
from dbt.common.dataclass_schema import ValidationError, dbtClassMixin from dbt.common.dataclass_schema import ValidationError, dbtClassMixin
@@ -33,7 +34,7 @@ from dbt.contracts.graph.unparsed import (
from dbt.exceptions import ( from dbt.exceptions import (
DuplicateMacroPatchNameError, DuplicateMacroPatchNameError,
DuplicatePatchPathError, DuplicatePatchPathError,
DuplicateSourcePatchNameError, DuplicatePatchNameError,
JSONValidationError, JSONValidationError,
DbtInternalError, DbtInternalError,
ParsingError, ParsingError,
@@ -139,6 +140,11 @@ class SchemaParser(SimpleParser[YamlBlock, ModelNode]):
self.root_project, self.project.project_name, self.schema_yaml_vars self.root_project, self.project.project_name, self.schema_yaml_vars
) )
# This is unnecessary, but mypy was requiring it. Clean up parser code so
# we don't have to do this.
def parse_from_dict(self, dct):
pass
@classmethod @classmethod
def get_compiled_path(cls, block: FileBlock) -> str: def get_compiled_path(cls, block: FileBlock) -> str:
# should this raise an error? # should this raise an error?
@@ -226,6 +232,12 @@ class SchemaParser(SimpleParser[YamlBlock, ModelNode]):
semantic_model_parser = SemanticModelParser(self, yaml_block) semantic_model_parser = SemanticModelParser(self, yaml_block)
semantic_model_parser.parse() semantic_model_parser.parse()
if "unit_tests" in dct:
from dbt.parser.unit_tests import UnitTestParser
unit_test_parser = UnitTestParser(self, yaml_block)
unit_test_parser.parse()
if "saved_queries" in dct: if "saved_queries" in dct:
from dbt.parser.schema_yaml_readers import SavedQueryParser from dbt.parser.schema_yaml_readers import SavedQueryParser
@@ -251,12 +263,13 @@ class ParseResult:
# abstract base class (ABCMeta) # abstract base class (ABCMeta)
# Four subclasses: MetricParser, ExposureParser, GroupParser, SourceParser, PatchParser # Many subclasses: MetricParser, ExposureParser, GroupParser, SourceParser,
# PatchParser, SemanticModelParser, SavedQueryParser, UnitTestParser
class YamlReader(metaclass=ABCMeta): class YamlReader(metaclass=ABCMeta):
def __init__(self, schema_parser: SchemaParser, yaml: YamlBlock, key: str) -> None: def __init__(self, schema_parser: SchemaParser, yaml: YamlBlock, key: str) -> None:
self.schema_parser = schema_parser self.schema_parser = schema_parser
# key: models, seeds, snapshots, sources, macros, # key: models, seeds, snapshots, sources, macros,
# analyses, exposures # analyses, exposures, unit_tests
self.key = key self.key = key
self.yaml = yaml self.yaml = yaml
self.schema_yaml_vars = SchemaYamlVars() self.schema_yaml_vars = SchemaYamlVars()
@@ -304,10 +317,10 @@ class YamlReader(metaclass=ABCMeta):
if coerce_dict_str(entry) is None: if coerce_dict_str(entry) is None:
raise YamlParseListError(path, self.key, data, "expected a dict with string keys") raise YamlParseListError(path, self.key, data, "expected a dict with string keys")
if "name" not in entry: if "name" not in entry and "model" not in entry:
raise ParsingError("Entry did not contain a name") raise ParsingError("Entry did not contain a name")
# Render the data (except for tests and descriptions). # Render the data (except for tests, data_tests and descriptions).
# See the SchemaYamlRenderer # See the SchemaYamlRenderer
entry = self.render_entry(entry) entry = self.render_entry(entry)
if self.schema_yaml_vars.env_vars: if self.schema_yaml_vars.env_vars:
@@ -367,7 +380,9 @@ class SourceParser(YamlReader):
# source patches must be unique # source patches must be unique
key = (patch.overrides, patch.name) key = (patch.overrides, patch.name)
if key in self.manifest.source_patches: if key in self.manifest.source_patches:
raise DuplicateSourcePatchNameError(patch, self.manifest.source_patches[key]) raise DuplicatePatchNameError(
NodeType.Source, patch, self.manifest.source_patches[key]
)
self.manifest.source_patches[key] = patch self.manifest.source_patches[key] = patch
source_file.source_patches.append(key) source_file.source_patches.append(key)
else: else:
@@ -477,6 +492,8 @@ class PatchParser(YamlReader, Generic[NonSourceTarget, Parsed]):
self.normalize_group_attribute(data, path) self.normalize_group_attribute(data, path)
self.normalize_contract_attribute(data, path) self.normalize_contract_attribute(data, path)
self.normalize_access_attribute(data, path) self.normalize_access_attribute(data, path)
# `tests` has been deprecated, convert to `data_tests` here if present
self.validate_data_tests(data)
node = self._target_type().from_dict(data) node = self._target_type().from_dict(data)
except (ValidationError, JSONValidationError) as exc: except (ValidationError, JSONValidationError) as exc:
raise YamlParseDictError(path, self.key, data, exc) raise YamlParseDictError(path, self.key, data, exc)
@@ -514,6 +531,21 @@ class PatchParser(YamlReader, Generic[NonSourceTarget, Parsed]):
def normalize_access_attribute(self, data, path): def normalize_access_attribute(self, data, path):
return self.normalize_attribute(data, path, "access") return self.normalize_attribute(data, path, "access")
def validate_data_tests(self, data):
if data.get("columns"):
for column in data["columns"]:
if "tests" in column and "data_tests" in column:
raise ValidationError(
"Invalid test config: cannot have both 'tests' and 'data_tests' defined"
)
if "tests" in column:
deprecations.warn(
"project-test-config",
deprecated_path="tests",
exp_path="data_tests",
)
column["data_tests"] = column.pop("tests")
def patch_node_config(self, node, patch): def patch_node_config(self, node, patch):
# Get the ContextConfig that's used in calculating the config # Get the ContextConfig that's used in calculating the config
# This must match the model resource_type that's being patched # This must match the model resource_type that's being patched


@@ -221,10 +221,10 @@ class SourcePatcher:
return generic_test_parser return generic_test_parser
def get_source_tests(self, target: UnpatchedSourceDefinition) -> Iterable[GenericTestNode]: def get_source_tests(self, target: UnpatchedSourceDefinition) -> Iterable[GenericTestNode]:
for test, column in target.get_tests(): for data_test, column in target.get_tests():
yield self.parse_source_test( yield self.parse_source_test(
target=target, target=target,
test=test, data_test=data_test,
column=column, column=column,
) )
@@ -249,7 +249,7 @@ class SourcePatcher:
def parse_source_test( def parse_source_test(
self, self,
target: UnpatchedSourceDefinition, target: UnpatchedSourceDefinition,
test: Dict[str, Any], data_test: Dict[str, Any],
column: Optional[UnparsedColumn], column: Optional[UnparsedColumn],
) -> GenericTestNode: ) -> GenericTestNode:
column_name: Optional[str] column_name: Optional[str]
@@ -269,7 +269,7 @@ class SourcePatcher:
generic_test_parser = self.get_generic_test_parser_for(target.package_name) generic_test_parser = self.get_generic_test_parser_for(target.package_name)
node = generic_test_parser.parse_generic_test( node = generic_test_parser.parse_generic_test(
target=target, target=target,
test=test, data_test=data_test,
tags=tags, tags=tags,
column_name=column_name, column_name=column_name,
schema_file_id=target.file_id, schema_file_id=target.file_id,


@@ -0,0 +1,637 @@
from csv import DictReader
from pathlib import Path
from typing import List, Set, Dict, Any, Optional, Type, TypeVar
import os
from io import StringIO
import csv
from dbt_extractor import py_extract_from_source, ExtractionError # type: ignore
from dbt import utils
from dbt.config import RuntimeConfig
from dbt.context.context_config import (
ContextConfig,
BaseContextConfigGenerator,
ContextConfigGenerator,
UnrenderedConfigGenerator,
)
from dbt.context.providers import generate_parse_exposure, get_rendered
from dbt.contracts.files import FileHash, SchemaSourceFile
from dbt.contracts.graph.manifest import Manifest
from dbt.contracts.graph.model_config import UnitTestNodeConfig, ModelConfig, UnitTestConfig
from dbt.contracts.graph.nodes import (
ModelNode,
UnitTestNode,
UnitTestDefinition,
DependsOn,
UnitTestSourceDefinition,
UnpatchedUnitTestDefinition,
)
from dbt.contracts.graph.unparsed import (
UnparsedUnitTest,
UnitTestFormat,
UnitTestNodeVersion,
UnitTestPatch,
NodeVersion,
)
from dbt.common.dataclass_schema import dbtClassMixin
from dbt.exceptions import (
ParsingError,
InvalidUnitTestGivenInput,
DbtInternalError,
DuplicatePatchNameError,
)
from dbt.graph import UniqueId
from dbt.node_types import NodeType
from dbt.parser.schemas import (
SchemaParser,
YamlBlock,
ValidationError,
JSONValidationError,
YamlParseDictError,
YamlReader,
ParseResult,
)
from dbt.utils import get_pseudo_test_path
class UnitTestManifestLoader:
def __init__(self, manifest, root_project, selected) -> None:
self.manifest: Manifest = manifest
self.root_project: RuntimeConfig = root_project
# selected comes from the initial selection against a "regular" manifest
self.selected: Set[UniqueId] = selected
self.unit_test_manifest = Manifest(macros=manifest.macros)
def load(self) -> Manifest:
for unique_id in self.selected:
if unique_id in self.manifest.unit_tests:
unit_test_case: UnitTestDefinition = self.manifest.unit_tests[unique_id]
self.parse_unit_test_case(unit_test_case)
return self.unit_test_manifest
def parse_unit_test_case(self, test_case: UnitTestDefinition, version: Optional[str] = None):
# Create unit test node based on the node being tested
unique_id = self.manifest.ref_lookup.get_unique_id(
key=test_case.model, package=test_case.package_name, version=version
)
tested_node = self.manifest.ref_lookup.perform_lookup(unique_id, self.manifest)
assert isinstance(tested_node, ModelNode)
# Create UnitTestNode based on model being tested. Since selection has
# already been done, we don't have to care about fields that are necessary
# for selection.
# Note: no depends_on, that's added later using input nodes
name = f"{test_case.model}__{test_case.name}"
unit_test_node = UnitTestNode(
name=name,
resource_type=NodeType.Unit,
package_name=test_case.package_name,
path=get_pseudo_test_path(name, test_case.original_file_path),
original_file_path=test_case.original_file_path,
unique_id=test_case.unique_id,
config=UnitTestNodeConfig(
materialized="unit",
expected_rows=test_case.expect.rows, # type:ignore
),
raw_code=tested_node.raw_code,
database=tested_node.database,
schema=tested_node.schema,
alias=name,
fqn=test_case.unique_id.split("."),
checksum=FileHash.empty(),
tested_node_unique_id=tested_node.unique_id,
overrides=test_case.overrides,
)
ctx = generate_parse_exposure(
unit_test_node, # type: ignore
self.root_project,
self.manifest,
test_case.package_name,
)
get_rendered(unit_test_node.raw_code, ctx, unit_test_node, capture_macros=True)
# unit_test_node now has a populated refs/sources
self.unit_test_manifest.nodes[unit_test_node.unique_id] = unit_test_node
# Now create input_nodes for the test inputs
"""
given:
- input: ref('my_model_a')
rows: []
- input: ref('my_model_b')
rows:
- {id: 1, b: 2}
- {id: 2, b: 2}
"""
# Add the model "input" nodes, consisting of all referenced models in the unit test.
# This creates an ephemeral model for every input in every test, so there may be multiple
# input models substituting for the same input ref'd model. Note that since these are
# always "ephemeral" they just wrap the tested_node SQL in additional CTEs. No actual table
# or view is created.
for given in test_case.given:
# extract the original_input_node from the ref in the "input" key of the given list
original_input_node = self._get_original_input_node(given.input, tested_node)
common_fields = {
"resource_type": NodeType.Model,
"package_name": test_case.package_name,
"original_file_path": original_input_node.original_file_path,
"config": ModelConfig(materialized="ephemeral"),
"database": original_input_node.database,
"alias": original_input_node.identifier,
"schema": original_input_node.schema,
"fqn": original_input_node.fqn,
"checksum": FileHash.empty(),
"raw_code": self._build_fixture_raw_code(given.rows, None),
}
if original_input_node.resource_type in (
NodeType.Model,
NodeType.Seed,
NodeType.Snapshot,
):
input_name = f"{unit_test_node.name}__{original_input_node.name}"
input_node = ModelNode(
**common_fields,
unique_id=f"model.{test_case.package_name}.{input_name}",
name=input_name,
path=original_input_node.path,
)
elif original_input_node.resource_type == NodeType.Source:
# We are reusing the database/schema/identifier from the original source,
# but that shouldn't matter since this acts as an ephemeral model which just
# wraps a CTE around the unit test node.
input_name = f"{unit_test_node.name}__{original_input_node.search_name}__{original_input_node.name}"
input_node = UnitTestSourceDefinition(
**common_fields,
unique_id=f"model.{test_case.package_name}.{input_name}",
name=original_input_node.name, # must be the same name for source lookup to work
path=input_name + ".sql", # for writing out compiled_code
source_name=original_input_node.source_name, # needed for source lookup
)
# Sources need to go in the sources dictionary in order to create the right lookup
# TODO: I think this should be model_name.version
self.unit_test_manifest.sources[input_node.unique_id] = input_node # type: ignore
# Both ModelNode and UnitTestSourceDefinition need to go in nodes dictionary
self.unit_test_manifest.nodes[input_node.unique_id] = input_node
# Populate this_input_node_unique_id if input fixture represents node being tested
if original_input_node == tested_node:
unit_test_node.this_input_node_unique_id = input_node.unique_id
# Add unique ids of input_nodes to depends_on
unit_test_node.depends_on.nodes.append(input_node.unique_id)
def _build_fixture_raw_code(self, rows, column_name_to_data_types) -> str:
# We're not currently using column_name_to_data_types, but leaving here for
# possible future use.
return ("{{{{ get_fixture_sql({rows}, {column_name_to_data_types}) }}}}").format(
rows=rows, column_name_to_data_types=column_name_to_data_types
)
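For example, evaluating that format string directly (self omitted for brevity):

    _build_fixture_raw_code([{"id": 1, "b": 2}], None)
    # -> "{{ get_fixture_sql([{'id': 1, 'b': 2}], None) }}"
    # i.e. the fixture rows are expanded by the get_fixture_sql macro at compile time.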
def _get_original_input_node(self, input: str, tested_node: ModelNode):
"""
Returns the original input node as defined in the project given an input reference
and the node being tested.
input: str representing how input node is referenced in tested model sql
* examples:
- "ref('my_model_a')"
- "source('my_source_schema', 'my_source_name')"
- "this"
tested_node: ModelNode representing the node being tested
"""
if input.strip() == "this":
original_input_node = tested_node
else:
try:
statically_parsed = py_extract_from_source(f"{{{{ {input} }}}}")
except ExtractionError:
raise InvalidUnitTestGivenInput(input=input)
if statically_parsed["refs"]:
ref = list(statically_parsed["refs"])[0]
name = ref.get("name")
package = ref.get("package")
version = ref.get("version")
# TODO: disabled lookup, versioned lookup, public models
original_input_node = self.manifest.ref_lookup.find(
name, package, version, self.manifest
)
elif statically_parsed["sources"]:
source = list(statically_parsed["sources"])[0]
input_source_name, input_name = source
original_input_node = self.manifest.source_lookup.find(
f"{input_source_name}.{input_name}",
None,
self.manifest,
)
else:
raise InvalidUnitTestGivenInput(input=input)
return original_input_node
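A sketch of the static parsing this method relies on; the result shapes here are inferred from how the code above unpacks them, not from dbt_extractor's documentation:

    py_extract_from_source("{{ ref('my_model_a') }}")["refs"]
    # -> [{"name": "my_model_a"}]   (plus "package"/"version" keys when given)

    py_extract_from_source("{{ source('my_source_schema', 'my_source_name') }}")["sources"]
    # -> [("my_source_schema", "my_source_name")]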
T = TypeVar("T", bound=dbtClassMixin)
class UnitTestParser(YamlReader):
def __init__(self, schema_parser: SchemaParser, yaml: YamlBlock) -> None:
super().__init__(schema_parser, yaml, "unit_tests")
self.schema_parser = schema_parser
self.yaml = yaml
def _target_from_dict(self, cls: Type[T], data: Dict[str, Any]) -> T:
path = self.yaml.path.original_file_path
try:
cls.validate(data)
return cls.from_dict(data)
except (ValidationError, JSONValidationError) as exc:
raise YamlParseDictError(path, self.key, data, exc)
# This should create the UnparsedUnitTest object. Then it should be turned into an UnpatchedUnitTestDefinition
def parse(self) -> ParseResult:
for data in self.get_key_dicts():
is_override = "overrides" in data
if is_override:
data["path"] = self.yaml.path.original_file_path
patch = self._target_from_dict(UnitTestPatch, data)
assert isinstance(self.yaml.file, SchemaSourceFile)
source_file = self.yaml.file
# unit test patches must be unique
key = (patch.overrides, patch.name)
if key in self.manifest.unit_test_patches:
raise DuplicatePatchNameError(
NodeType.Unit, patch, self.manifest.unit_test_patches[key]
)
self.manifest.unit_test_patches[key] = patch
source_file.unit_test_patches.append(key)
else:
unit_test = self._target_from_dict(UnparsedUnitTest, data)
self.add_unit_test_definition(unit_test)
return ParseResult()
def add_unit_test_definition(self, unit_test: UnparsedUnitTest) -> None:
unit_test_case_unique_id = (
f"{NodeType.Unit}.{self.project.project_name}.{unit_test.model}.{unit_test.name}"
)
unit_test_fqn = self._build_fqn(
self.project.project_name,
self.yaml.path.original_file_path,
unit_test.model,
unit_test.name,
)
# unit_test_config = self._build_unit_test_config(unit_test_fqn, unit_test.config)
unit_test_definition = UnpatchedUnitTestDefinition(
name=unit_test.name,
package_name=self.project.project_name,
path=self.yaml.path.relative_path,
original_file_path=self.yaml.path.original_file_path,
unique_id=unit_test_case_unique_id,
resource_type=NodeType.Unit,
fqn=unit_test_fqn,
model=unit_test.model,
given=unit_test.given,
expect=unit_test.expect,
versions=unit_test.versions,
description=unit_test.description,
overrides=unit_test.overrides,
config=unit_test.config,
)
# Check that format and type of rows matches for each given input,
# convert rows to a list of dictionaries, and add the unique_id of
# the unit_test_definition to the fixture source_file for partial parsing.
self._validate_and_normalize_given(unit_test_definition)
self._validate_and_normalize_expect(unit_test_definition)
# # for calculating state:modified
# unit_test_definition.build_unit_test_checksum()
self.manifest.add_unit_test(self.yaml.file, unit_test_definition)
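For reference, a hypothetical unit_tests: entry (already parsed into a dict) that would flow through this method; the field names mirror the UnparsedUnitTest attributes read above, everything else is illustrative:

    data = {
        "name": "test_valid_totals",
        "model": "my_model",
        "given": [
            {"input": "ref('my_model_a')", "rows": [{"id": 1, "b": 2}]},
            {"input": "ref('my_model_b')", "rows": [{"id": 1, "b": 2}, {"id": 2, "b": 2}]},
        ],
        "expect": {"rows": [{"id": 1, "b": 2}]},
        "description": "hypothetical example",
    }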
def _build_unit_test_config(
self, unit_test_fqn: List[str], config_dict: Dict[str, Any]
) -> UnitTestConfig:
config = ContextConfig(
self.schema_parser.root_project,
unit_test_fqn,
NodeType.Unit,
self.schema_parser.project.project_name,
)
unit_test_config_dict = config.build_config_dict(patch_config_dict=config_dict)
unit_test_config_dict = self.render_entry(unit_test_config_dict)
return UnitTestConfig.from_dict(unit_test_config_dict)
def _build_fqn(self, package_name, original_file_path, model_name, test_name):
# This code comes from "get_fqn" and "get_fqn_prefix" in the base parser.
# We need to get the directories underneath the model-path.
path = Path(original_file_path)
relative_path = str(path.relative_to(*path.parts[:1]))
no_ext = os.path.splitext(relative_path)[0]
fqn = [package_name]
fqn.extend(utils.split_path(no_ext)[:-1])
fqn.append(model_name)
fqn.append(test_name)
return fqn
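A worked example of the path handling, with hypothetical names:

    # original_file_path = "models/staging/schema.yml", package "my_package",
    # model "my_model", test "test_a":
    #   path.parts[:1] -> ("models",)           (stripped by relative_to)
    #   relative_path  -> "staging/schema.yml"
    #   no_ext         -> "staging/schema"
    #   fqn            -> ["my_package", "staging", "my_model", "test_a"]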
def _get_fixture(self, fixture_name: str, project_name: str):
fixture_unique_id = f"{NodeType.Fixture}.{project_name}.{fixture_name}"
if fixture_unique_id in self.manifest.fixtures:
fixture = self.manifest.fixtures[fixture_unique_id]
return fixture
else:
raise ParsingError(
f"File not found for fixture '{fixture_name}' in unit tests in {self.yaml.path.original_file_path}"
)
def _validate_and_normalize_given(self, unit_test_definition):
for ut_input in unit_test_definition.given:
self._validate_and_normalize_rows(ut_input, unit_test_definition, "input")
def _validate_and_normalize_expect(self, unit_test_definition):
self._validate_and_normalize_rows(
unit_test_definition.expect, unit_test_definition, "expected"
)
def _validate_and_normalize_rows(self, ut_fixture, unit_test_definition, fixture_type) -> None:
if ut_fixture.format == UnitTestFormat.Dict:
if ut_fixture.rows is None and ut_fixture.fixture is None: # This is a seed
ut_fixture.rows = self._load_rows_from_seed(ut_fixture.input)
if not isinstance(ut_fixture.rows, list):
raise ParsingError(
f"Unit test {unit_test_definition.name} has {fixture_type} rows "
f"which do not match format {ut_fixture.format}"
)
elif ut_fixture.format == UnitTestFormat.CSV:
if not (isinstance(ut_fixture.rows, str) or isinstance(ut_fixture.fixture, str)):
raise ParsingError(
f"Unit test {unit_test_definition.name} has {fixture_type} rows or fixtures "
f"which do not match format {ut_fixture.format}. Expected string."
)
if ut_fixture.fixture:
# find fixture file object and store unit_test_definition unique_id
fixture = self._get_fixture(ut_fixture.fixture, self.project.project_name)
fixture_source_file = self.manifest.files[fixture.file_id]
fixture_source_file.unit_tests.append(unit_test_definition.unique_id)
ut_fixture.rows = fixture.rows
else:
ut_fixture.rows = self._convert_csv_to_list_of_dicts(ut_fixture.rows)
def _convert_csv_to_list_of_dicts(self, csv_string: str) -> List[Dict[str, Any]]:
dummy_file = StringIO(csv_string)
reader = csv.DictReader(dummy_file)
rows = []
for row in reader:
rows.append(row)
return rows
def _load_rows_from_seed(self, ref_str: str) -> List[Dict[str, Any]]:
"""Read rows from seed file on disk if not specified in YAML config. If seed file doesn't exist, return empty list."""
ref = py_extract_from_source("{{ " + ref_str + " }}")["refs"][0]
rows: List[Dict[str, Any]] = []
seed_name = ref["name"]
package_name = ref.get("package", self.project.project_name)
seed_node = self.manifest.ref_lookup.find(seed_name, package_name, None, self.manifest)
if not seed_node or seed_node.resource_type != NodeType.Seed:
# Seed not found in custom package specified
if package_name != self.project.project_name:
raise ParsingError(
f"Unable to find seed '{package_name}.{seed_name}' for unit tests in '{package_name}' package"
)
else:
raise ParsingError(
f"Unable to find seed '{package_name}.{seed_name}' for unit tests in directories: {self.project.seed_paths}"
)
seed_path = Path(seed_node.root_path) / seed_node.original_file_path
with open(seed_path, "r") as f:
for row in DictReader(f):
rows.append(row)
return rows
# TODO: add more context for why we patch unit tests
class UnitTestPatcher:
def __init__(
self,
root_project: RuntimeConfig,
manifest: Manifest,
) -> None:
self.root_project = root_project
self.manifest = manifest
self.patches_used: Dict[str, Set[str]] = {}
self.unit_tests: Dict[str, UnitTestDefinition] = {}
# This method calls the 'parse_unit_test' method which takes
# the UnpatchedUnitTestDefinitions in the manifest and combines them
# with what we know about versioned models to generate appropriate
# unit tests
def construct_unit_tests(self) -> None:
for unique_id, unpatched in self.manifest.unit_tests.items():
# schema_file = self.manifest.files[unpatched.file_id]
if isinstance(unpatched, UnitTestDefinition):
# In partial parsing, there will be UnitTestDefinition
# which must be retained.
self.unit_tests[unpatched.unique_id] = unpatched
continue
# returns None if there is no patch
patch = self.get_patch_for(unpatched)
# returns unpatched if there is no patch
patched = self.patch_unit_test(unpatched, patch)
# Convert UnpatchedUnitTestDefinition to a list of UnitTestDefinition based on model versions
parsed_unit_tests = self.parse_unit_test(patched)
for unit_test in parsed_unit_tests:
self.unit_tests[unit_test.unique_id] = unit_test
def patch_unit_test(
self,
unpatched: UnpatchedUnitTestDefinition,
patch: Optional[UnitTestPatch],
) -> UnpatchedUnitTestDefinition:
# This skips patching if no patch exists because of the
# performance overhead of converting to and from dicts
if patch is None:
return unpatched
unit_test_dct = unpatched.to_dict(omit_none=True)
patch_path: Optional[Path] = None
if patch is not None:
unit_test_dct.update(patch.to_patch_dict())
patch_path = patch.path
unit_test = UnparsedUnitTest.from_dict(unit_test_dct)
return unpatched.replace(unit_test=unit_test, patch_path=patch_path)
# This converts an UnpatchedUnitTestDefinition to a UnitTestDefinition
# It returns a list of UnitTestDefinitions because a single UnpatchedUnitTestDefinition
# may correspond to multiple unit tests if the model is versioned.
def parse_unit_test(self, unit_test: UnpatchedUnitTestDefinition) -> List[UnitTestDefinition]:
version_list = self.get_unit_test_versions(
model_name=unit_test.model, versions=unit_test.versions
)
if not version_list:
return [self.build_unit_test_definition(unit_test=unit_test, version=None)]
return [
self.build_unit_test_definition(unit_test=unit_test, version=v) for v in version_list
]
def _find_tested_model_node(
self, unit_test: UnpatchedUnitTestDefinition, model_version: Optional[NodeVersion]
) -> ModelNode:
package_name = unit_test.package_name
# TODO: does this work when `define_id` is used in the yaml?
model_name_split = unit_test.model.split()
model_name = model_name_split[0]
tested_node = self.manifest.ref_lookup.find(
model_name, package_name, model_version, self.manifest
)
if not tested_node:
raise ParsingError(
f"Unable to find model '{package_name}.{unit_test.model}' for unit tests in {unit_test.original_file_path}"
)
return tested_node
def build_unit_test_definition(
self, unit_test: UnpatchedUnitTestDefinition, version: Optional[NodeVersion]
) -> UnitTestDefinition:
config = self._generate_unit_test_config(
target=unit_test,
rendered=True,
)
unit_test_config = config.finalize_and_validate()
if not isinstance(config, UnitTestConfig):
raise DbtInternalError(
f"Calculated a {type(config)} for a unit test, but expected a UnitTestConfig"
)
tested_model_node = self._find_tested_model_node(unit_test, model_version=version)
unit_test_name = f"{unit_test.name}.v{version}" if version else unit_test.name
unit_test_case_unique_id = (
f"{NodeType.Unit}.{unit_test.package_name}.{unit_test.model}.{unit_test_name}"
)
unit_test_model_name = f"{unit_test.model}.v{version}" if version else unit_test.model
unit_test_fqn = self._build_fqn(
unit_test.package_name,
unit_test.original_file_path,
unit_test_model_name,
unit_test_name,
)
parsed_unit_test = UnitTestDefinition(
name=unit_test_name,
model=unit_test_model_name,
resource_type=NodeType.Unit,
package_name=unit_test.package_name,
path=unit_test.path,
original_file_path=unit_test.original_file_path,
unique_id=unit_test_case_unique_id,
version=version,
given=unit_test.given,
expect=unit_test.expect,
description=unit_test.description,
overrides=unit_test.overrides,
depends_on=DependsOn(nodes=[tested_model_node.unique_id]),
fqn=unit_test_fqn,
config=unit_test_config,
schema=tested_model_node.schema,
)
# relation name is added after instantiation because the adapter does
# not provide the relation name for a UnpatchedSourceDefinition object
return parsed_unit_test
def _build_fqn(self, package_name, original_file_path, model_name, test_name):
# This code comes from "get_fqn" and "get_fqn_prefix" in the base parser.
# We need to get the directories underneath the model-path.
path = Path(original_file_path)
relative_path = str(path.relative_to(*path.parts[:1]))
no_ext = os.path.splitext(relative_path)[0]
fqn = [package_name]
fqn.extend(utils.split_path(no_ext)[:-1])
fqn.append(model_name)
fqn.append(test_name)
return fqn
def get_unit_test_versions(
self, model_name: str, versions: Optional[UnitTestNodeVersion]
) -> List[Optional[NodeVersion]]:
version_list = []
if versions is None:
for node in self.manifest.nodes.values():
# only modelnodes have unit tests
if isinstance(node, ModelNode) and node.is_versioned:
if node.name == model_name:
version_list.append(node.version)
elif versions.exclude is not None:
for node in self.manifest.nodes.values():
# only modelnodes have unit tests
if isinstance(node, ModelNode) and node.is_versioned:
if node.name == model_name:
# this version is not explicitly excluded, so the unit test applies to it
if node.version not in versions.exclude:
version_list.append(node.version)
# versions were explicitly included
elif versions.include is not None:
for i in versions.include:
# TODO: does this actually need reformatting?
version_list.append(i)
return version_list
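Sketching the three branches, assuming a manifest where my_model is versioned with versions 1, 2 and 3 (the patcher instance and the UnitTestNodeVersion kwargs are assumed from the attribute access above):

    patcher.get_unit_test_versions("my_model", versions=None)
    # -> [1, 2, 3]   every version of the model gets the unit test
    patcher.get_unit_test_versions("my_model", UnitTestNodeVersion(exclude=[2]))
    # -> [1, 3]      versions not explicitly excluded
    patcher.get_unit_test_versions("my_model", UnitTestNodeVersion(include=[2]))
    # -> [2]         only the versions explicitly included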
def get_patch_for(
self,
unpatched: UnpatchedUnitTestDefinition,
) -> Optional[UnitTestPatch]:
if isinstance(unpatched, UnitTestDefinition):
return None
key = unpatched.name
patch: Optional[UnitTestPatch] = self.manifest.unit_test_patches.get(key)
if patch is None:
return None
if key not in self.patches_used:
# mark the key as used
self.patches_used[key] = set()
return patch
def _generate_unit_test_config(self, target: UnpatchedUnitTestDefinition, rendered: bool):
generator: BaseContextConfigGenerator
if rendered:
generator = ContextConfigGenerator(self.root_project)
else:
generator = UnrenderedConfigGenerator(self.root_project)
# configs with precedence set
precedence_configs = dict()
precedence_configs.update(target.config)
return generator.calculate_node_config(
config_call_dict={},
fqn=target.fqn,
resource_type=NodeType.Unit,
project_name=target.package_name,
base=False,
patch_config_dict=precedence_configs,
)


@@ -310,7 +310,7 @@ class BaseRunner(metaclass=ABCMeta):
with collect_timing_info("compile", ctx.timing.append): with collect_timing_info("compile", ctx.timing.append):
# if we fail here, we still have a compiled node to return # if we fail here, we still have a compiled node to return
# this has the benefit of showing a build path for the errant # this has the benefit of showing a build path for the errant
# model # model. This calls the 'compile' method in CompileTask
ctx.node = self.compile(manifest) ctx.node = self.compile(manifest)
# for ephemeral nodes, we only want to compile, not run # for ephemeral nodes, we only want to compile, not run


@@ -1,4 +1,5 @@
import threading import threading
from typing import Dict, List, Set
from .run import RunTask, ModelRunner as run_model_runner from .run import RunTask, ModelRunner as run_model_runner
from .snapshot import SnapshotRunner as snapshot_model_runner from .snapshot import SnapshotRunner as snapshot_model_runner
@@ -7,7 +8,7 @@ from .test import TestRunner as test_runner
from dbt.contracts.results import NodeStatus from dbt.contracts.results import NodeStatus
from dbt.common.exceptions import DbtInternalError from dbt.common.exceptions import DbtInternalError
from dbt.graph import ResourceTypeSelector from dbt.graph import ResourceTypeSelector, GraphQueue, Graph
from dbt.node_types import NodeType from dbt.node_types import NodeType
from dbt.task.test import TestSelector from dbt.task.test import TestSelector
from dbt.task.base import BaseRunner from dbt.task.base import BaseRunner
@@ -76,38 +77,144 @@ class BuildTask(RunTask):
     I.E. a resource of type Model is handled by the ModelRunner which is
     imported as run_model_runner."""

-    MARK_DEPENDENT_ERRORS_STATUSES = [NodeStatus.Error, NodeStatus.Fail]
+    MARK_DEPENDENT_ERRORS_STATUSES = [NodeStatus.Error, NodeStatus.Fail, NodeStatus.Skipped]
     RUNNER_MAP = {
         NodeType.Model: run_model_runner,
         NodeType.Snapshot: snapshot_model_runner,
         NodeType.Seed: seed_runner,
         NodeType.Test: test_runner,
+        NodeType.Unit: test_runner,
     }
     ALL_RESOURCE_VALUES = frozenset({x for x in RUNNER_MAP.keys()})

-    @property
-    def resource_types(self):
+    def __init__(self, args, config, manifest) -> None:
+        super().__init__(args, config, manifest)
+        self.selected_unit_tests: Set = set()
+        self.model_to_unit_test_map: Dict[str, List] = {}
+
+    def resource_types(self, no_unit_tests=False):
         if self.args.include_saved_query:
             self.RUNNER_MAP[NodeType.SavedQuery] = SavedQueryRunner
             self.ALL_RESOURCE_VALUES = self.ALL_RESOURCE_VALUES.union({NodeType.SavedQuery})
         if not self.args.resource_types:
-            return list(self.ALL_RESOURCE_VALUES)
-        values = set(self.args.resource_types)
-        if "all" in values:
-            values.remove("all")
-            values.update(self.ALL_RESOURCE_VALUES)
-        return list(values)
+            resource_types = list(self.ALL_RESOURCE_VALUES)
+        else:
+            resource_types = set(self.args.resource_types)
+            if "all" in resource_types:
+                resource_types.remove("all")
+                resource_types.update(self.ALL_RESOURCE_VALUES)
+        # First we get selected_nodes including unit tests, then without,
+        # and do a set difference.
+        if no_unit_tests is True and NodeType.Unit in resource_types:
+            resource_types.remove(NodeType.Unit)
+        return list(resource_types)

-    def get_node_selector(self) -> ResourceTypeSelector:
+    # overrides get_graph_queue in runnable.py
+    def get_graph_queue(self) -> GraphQueue:
+        # Following uses self.selection_arg and self.exclusion_arg
+        spec = self.get_selection_spec()
+        # selector including unit tests
full_selector = self.get_node_selector(no_unit_tests=False)
# selected node unique_ids with unit_tests
full_selected_nodes = full_selector.get_selected(spec)
# This selector removes the unit_tests from the selector
selector_wo_unit_tests = self.get_node_selector(no_unit_tests=True)
# selected node unique_ids without unit_tests
selected_nodes_wo_unit_tests = selector_wo_unit_tests.get_selected(spec)
# Get the difference in the sets of nodes with and without unit tests and
# save it
selected_unit_tests = full_selected_nodes - selected_nodes_wo_unit_tests
self.selected_unit_tests = selected_unit_tests
self.build_model_to_unit_test_map(selected_unit_tests)
# get_graph_queue in the selector will remove NodeTypes not specified
# in the node_selector (filter_selection).
return selector_wo_unit_tests.get_graph_queue(spec)
# overrides handle_job_queue in runnable.py
def handle_job_queue(self, pool, callback):
if self.run_count == 0:
self.num_nodes = self.num_nodes + len(self.selected_unit_tests)
node = self.job_queue.get()
if (
node.resource_type == NodeType.Model
and self.model_to_unit_test_map
and node.unique_id in self.model_to_unit_test_map
):
self.handle_model_with_unit_tests_node(node, pool, callback)
else:
self.handle_job_queue_node(node, pool, callback)
def handle_model_with_unit_tests_node(self, node, pool, callback):
self._raise_set_error()
args = [node]
if self.config.args.single_threaded:
callback(self.call_model_and_unit_tests_runner(*args))
else:
pool.apply_async(self.call_model_and_unit_tests_runner, args=args, callback=callback)
def call_model_and_unit_tests_runner(self, node) -> RunResult:
assert self.manifest
for unit_test_unique_id in self.model_to_unit_test_map[node.unique_id]:
unit_test_node = self.manifest.unit_tests[unit_test_unique_id]
unit_test_runner = self.get_runner(unit_test_node)
# If the model is marked skip, also skip the unit tests
if node.unique_id in self._skipped_children:
# cause is only for ephemeral nodes
unit_test_runner.do_skip(cause=None)
result = self.call_runner(unit_test_runner)
self._handle_result(result)
if result.status in self.MARK_DEPENDENT_ERRORS_STATUSES:
# The _skipped_children dictionary can contain a run_result for ephemeral nodes,
# but that should never be the case here.
self._skipped_children[node.unique_id] = None
runner = self.get_runner(node)
if runner.node.unique_id in self._skipped_children:
cause = self._skipped_children.pop(runner.node.unique_id)
runner.do_skip(cause=cause)
return self.call_runner(runner)
# handle non-model-plus-unit-tests nodes
def handle_job_queue_node(self, node, pool, callback):
self._raise_set_error()
runner = self.get_runner(node)
# we finally know what we're running! Make sure we haven't decided
# to skip it due to upstream failures
if runner.node.unique_id in self._skipped_children:
cause = self._skipped_children.pop(runner.node.unique_id)
runner.do_skip(cause=cause)
args = [runner]
if self.config.args.single_threaded:
callback(self.call_runner(*args))
else:
pool.apply_async(self.call_runner, args=args, callback=callback)
# Make a map of model unique_ids to selected unit test unique_ids,
# for processing before the model.
def build_model_to_unit_test_map(self, selected_unit_tests):
dct = {}
for unit_test_unique_id in selected_unit_tests:
unit_test = self.manifest.unit_tests[unit_test_unique_id]
model_unique_id = unit_test.depends_on.nodes[0]
if model_unique_id not in dct:
dct[model_unique_id] = []
dct[model_unique_id].append(unit_test.unique_id)
self.model_to_unit_test_map = dct
# We return two different kinds of selectors, one with unit tests and one without
def get_node_selector(self, no_unit_tests=False) -> ResourceTypeSelector:
if self.manifest is None or self.graph is None: if self.manifest is None or self.graph is None:
raise DbtInternalError("manifest and graph must be set to get node selection") raise DbtInternalError("manifest and graph must be set to get node selection")
resource_types = self.resource_types resource_types = self.resource_types(no_unit_tests)
if resource_types == [NodeType.Test]: if resource_types == [NodeType.Test]:
return TestSelector( return TestSelector(
@@ -125,7 +232,8 @@ class BuildTask(RunTask):
def get_runner_type(self, node): def get_runner_type(self, node):
return self.RUNNER_MAP.get(node.resource_type) return self.RUNNER_MAP.get(node.resource_type)
def compile_manifest(self): # Special build compile_manifest method to pass add_test_edges to the compiler
def compile_manifest(self) -> None:
if self.manifest is None: if self.manifest is None:
raise DbtInternalError("compile_manifest called before manifest was loaded") raise DbtInternalError("compile_manifest called before manifest was loaded")
self.graph = self.compiler.compile(self.manifest, add_test_edges=True) self.graph: Graph = self.compiler.compile(self.manifest, add_test_edges=True)
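As a minimal standalone sketch of the grouping that build_model_to_unit_test_map performs above, assuming the tested model is the first entry in each unit test's depends_on list (the unique_ids below are invented for illustration):

# Minimal sketch of the model-to-unit-test grouping above; hypothetical ids.
from collections import defaultdict
from typing import Dict, List

def group_unit_tests_by_model(depends_on_map: Dict[str, List[str]]) -> Dict[str, List[str]]:
    grouped: Dict[str, List[str]] = defaultdict(list)
    for unit_test_id, depends_on in depends_on_map.items():
        # the tested model is assumed to be the first dependency
        grouped[depends_on[0]].append(unit_test_id)
    return dict(grouped)

print(group_unit_tests_by_model({
    "unit_test.jaffle.my_model.test_a": ["model.jaffle.my_model"],
    "unit_test.jaffle.my_model.test_b": ["model.jaffle.my_model"],
}))
# -> {'model.jaffle.my_model': ['unit_test.jaffle.my_model.test_a',
#     'unit_test.jaffle.my_model.test_b']}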

View File

@@ -54,6 +54,7 @@ from dbt.logger import (
     ModelMetadata,
     NodeCount,
 )
+from dbt.node_types import NodeType
 from dbt.parser.manifest import write_manifest
 from dbt.task.base import ConfiguredTask, BaseRunner
 from .printer import (

@@ -123,6 +124,7 @@ class GraphRunnableTask(ConfiguredTask):
             fire_event(DefaultSelector(name=default_selector_name))
             spec = self.config.get_selector(default_selector_name)
         else:
+            # This is what's used with no default selector and no selection
             # use --select and --exclude args
             spec = parse_difference(self.selection_arg, self.exclusion_arg, indirect_selection)
         return spec

@@ -137,6 +139,7 @@ class GraphRunnableTask(ConfiguredTask):
     def get_graph_queue(self) -> GraphQueue:
         selector = self.get_node_selector()
+        # Following uses self.selection_arg and self.exclusion_arg
         spec = self.get_selection_spec()
         return selector.get_graph_queue(spec)

@@ -156,9 +159,11 @@ class GraphRunnableTask(ConfiguredTask):
                 self._flattened_nodes.append(self.manifest.sources[uid])
             elif uid in self.manifest.saved_queries:
                 self._flattened_nodes.append(self.manifest.saved_queries[uid])
+            elif uid in self.manifest.unit_tests:
+                self._flattened_nodes.append(self.manifest.unit_tests[uid])
             else:
                 raise DbtInternalError(
-                    f"Node selection returned {uid}, expected a node or a source"
+                    f"Node selection returned {uid}, expected a node, a source, or a unit test"
                 )

         self.num_nodes = len([n for n in self._flattened_nodes if not n.is_ephemeral_model])

@@ -207,6 +212,8 @@ class GraphRunnableTask(ConfiguredTask):
         status: Dict[str, str] = {}
         try:
             result = runner.run_with_hooks(self.manifest)
+        except Exception as exc:
+            raise DbtInternalError(f"Unable to execute node: {exc}")
         finally:
             finishctx = TimestampNamed("finished_at")
             with finishctx, DbtModelState(status):

@@ -217,8 +224,9 @@ class GraphRunnableTask(ConfiguredTask):
                 )
             )
         # `_event_status` dict is only used for logging. Make sure
-        # it gets deleted when we're done with it
-        runner.node.clear_event_status()
+        # it gets deleted when we're done with it, except for unit tests
+        if not runner.node.resource_type == NodeType.Unit:
+            runner.node.clear_event_status()

         fail_fast = get_flags().FAIL_FAST

@@ -270,16 +278,7 @@ class GraphRunnableTask(ConfiguredTask):
             self.job_queue.mark_done(result.node.unique_id)

         while not self.job_queue.empty():
-            node = self.job_queue.get()
-            self._raise_set_error()
-            runner = self.get_runner(node)
-            # we finally know what we're running! Make sure we haven't decided
-            # to skip it due to upstream failures
-            if runner.node.unique_id in self._skipped_children:
-                cause = self._skipped_children.pop(runner.node.unique_id)
-                runner.do_skip(cause=cause)
-            args = (runner,)
-            self._submit(pool, args, callback)
+            self.handle_job_queue(pool, callback)

         # block on completion
         if get_flags().FAIL_FAST:

@@ -296,6 +295,19 @@ class GraphRunnableTask(ConfiguredTask):
             return

+    # The build command overrides this
+    def handle_job_queue(self, pool, callback):
+        node = self.job_queue.get()
+        self._raise_set_error()
+        runner = self.get_runner(node)
+        # we finally know what we're running! Make sure we haven't decided
+        # to skip it due to upstream failures
+        if runner.node.unique_id in self._skipped_children:
+            cause = self._skipped_children.pop(runner.node.unique_id)
+            runner.do_skip(cause=cause)
+        args = [runner]
+        self._submit(pool, args, callback)
+
     def _handle_result(self, result: RunResult):
         """Mark the result as completed, insert the `CompileResultNode` into
         the manifest, and mark any descendants (potentially with a 'cause' if

@@ -310,6 +322,7 @@ class GraphRunnableTask(ConfiguredTask):
         if self.manifest is None:
             raise DbtInternalError("manifest was None in _handle_result")

+        # If result.status == NodeStatus.Error, plus Fail for build command
        if result.status in self.MARK_DEPENDENT_ERRORS_STATUSES:
             if is_ephemeral:
                 cause = result
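The refactor above turns the queue-draining loop into a hook: runnable.py keeps a default handle_job_queue, and the build task overrides it. A minimal sketch of that template-method shape, with illustrative stand-in names rather than the actual dbt classes:

# Minimal sketch of the handle_job_queue override pattern; names invented.
from typing import List

class BaseTaskSketch:
    def run_queue(self, jobs: List[str]) -> None:
        # drain the queue through an overridable hook
        while jobs:
            self.handle_job_queue(jobs)

    def handle_job_queue(self, jobs: List[str]) -> None:
        print("running", jobs.pop(0))

class BuildTaskSketch(BaseTaskSketch):
    def handle_job_queue(self, jobs: List[str]) -> None:
        job = jobs.pop(0)
        if job.startswith("model."):
            # build runs a model's selected unit tests before the model itself
            print("running unit tests before", job)
        print("running", job)

BuildTaskSketch().run_queue(["model.a", "seed.b"])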

View File

@@ -1,22 +1,24 @@
 from distutils.util import strtobool

+import agate
+import daff
+import re
 from dataclasses import dataclass
 from dbt.utils import _coerce_decimal
 from dbt.common.events.format import pluralize
 from dbt.common.dataclass_schema import dbtClassMixin
 import threading
-from typing import Dict, Any
+from typing import Dict, Any, Optional, Union, List

 from .compile import CompileRunner
 from .run import RunTask

-from dbt.contracts.graph.nodes import (
-    TestNode,
-)
+from dbt.contracts.graph.nodes import TestNode, UnitTestDefinition, UnitTestNode
 from dbt.contracts.graph.manifest import Manifest
 from dbt.contracts.results import TestStatus, PrimitiveDict, RunResult
 from dbt.context.providers import generate_runtime_model_context
 from dbt.clients.jinja import MacroGenerator
+from dbt.common.clients.agate_helper import list_rows_from_table, json_rows_from_table
 from dbt.common.events.functions import fire_event
 from dbt.common.events.types import (
     LogTestResult,

@@ -31,7 +33,16 @@ from dbt.graph import (
     ResourceTypeSelector,
 )
 from dbt.node_types import NodeType
+from dbt.parser.unit_tests import UnitTestManifestLoader
 from dbt.flags import get_flags
+from dbt.common.ui import green, red
+
+
+@dataclass
+class UnitTestDiff(dbtClassMixin):
+    actual: List[Dict[str, Any]]
+    expected: List[Dict[str, Any]]
+    rendered: str


 @dataclass

@@ -59,10 +70,18 @@ class TestResultData(dbtClassMixin):
         return bool(field)


+@dataclass
+class UnitTestResultData(dbtClassMixin):
+    should_error: bool
+    adapter_response: Dict[str, Any]
+    diff: Optional[UnitTestDiff] = None
+
+
 class TestRunner(CompileRunner):
+    _ANSI_ESCAPE = re.compile(r"\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])")
+
     def describe_node(self):
-        node_name = self.node.name
-        return "test {}".format(node_name)
+        return f"{self.node.resource_type} {self.node.name}"

     def print_result_line(self, result):
         model = result.node

@@ -93,16 +112,16 @@ class TestRunner(CompileRunner):
     def before_execute(self):
         self.print_start_line()

-    def execute_test(self, test: TestNode, manifest: Manifest) -> TestResultData:
-        context = generate_runtime_model_context(test, self.config, manifest)
+    def execute_data_test(self, data_test: TestNode, manifest: Manifest) -> TestResultData:
+        context = generate_runtime_model_context(data_test, self.config, manifest)

         materialization_macro = manifest.find_materialization_macro_by_name(
-            self.config.project_name, test.get_materialization(), self.adapter.type()
+            self.config.project_name, data_test.get_materialization(), self.adapter.type()
         )

         if materialization_macro is None:
             raise MissingMaterializationError(
-                materialization=test.get_materialization(), adapter_type=self.adapter.type()
+                materialization=data_test.get_materialization(), adapter_type=self.adapter.type()
             )

         if "config" not in context:

@@ -121,14 +140,14 @@ class TestRunner(CompileRunner):
         num_rows = len(table.rows)
         if num_rows != 1:
             raise DbtInternalError(
-                f"dbt internally failed to execute {test.unique_id}: "
+                f"dbt internally failed to execute {data_test.unique_id}: "
                 f"Returned {num_rows} rows, but expected "
                 f"1 row"
             )
         num_cols = len(table.columns)
         if num_cols != 3:
             raise DbtInternalError(
-                f"dbt internally failed to execute {test.unique_id}: "
+                f"dbt internally failed to execute {data_test.unique_id}: "
                 f"Returned {num_cols} columns, but expected "
                 f"3 columns"
             )

@@ -143,9 +162,87 @@ class TestRunner(CompileRunner):
         TestResultData.validate(test_result_dct)
         return TestResultData.from_dict(test_result_dct)

-    def execute(self, test: TestNode, manifest: Manifest):
-        result = self.execute_test(test, manifest)
+    def build_unit_test_manifest_from_test(
+        self, unit_test_def: UnitTestDefinition, manifest: Manifest
+    ) -> Manifest:
+        # build a unit test manifest with only the test from this UnitTestDefinition
+        loader = UnitTestManifestLoader(manifest, self.config, {unit_test_def.unique_id})
+        return loader.load()
+
+    def execute_unit_test(
+        self, unit_test_def: UnitTestDefinition, manifest: Manifest
+    ) -> UnitTestResultData:
+        unit_test_manifest = self.build_unit_test_manifest_from_test(unit_test_def, manifest)
+
+        # The unit test node and definition have the same unique_id
+        unit_test_node = unit_test_manifest.nodes[unit_test_def.unique_id]
+        assert isinstance(unit_test_node, UnitTestNode)
+
+        # Compile the node
+        unit_test_node = self.compiler.compile_node(unit_test_node, unit_test_manifest, {})
+
+        # generate_runtime_unit_test_context not strictly needed - this is to run the 'unit'
+        # materialization, not compile the node.compiled_code
+        context = generate_runtime_model_context(unit_test_node, self.config, unit_test_manifest)
+
+        materialization_macro = unit_test_manifest.find_materialization_macro_by_name(
+            self.config.project_name, unit_test_node.get_materialization(), self.adapter.type()
+        )
+
+        if materialization_macro is None:
+            raise MissingMaterializationError(
+                materialization=unit_test_node.get_materialization(),
+                adapter_type=self.adapter.type(),
+            )
+
+        if "config" not in context:
+            raise DbtInternalError(
+                "Invalid materialization context generated, missing config: {}".format(context)
+            )
+
+        # generate materialization macro
+        macro_func = MacroGenerator(materialization_macro, context)
+        # execute materialization macro
+        macro_func()
+        # load results from context
+        # could eventually be returned directly by materialization
+        result = context["load_result"]("main")
+        adapter_response = result["response"].to_dict(omit_none=True)
+        table = result["table"]
+        actual = self._get_unit_test_agate_table(table, "actual")
+        expected = self._get_unit_test_agate_table(table, "expected")
+
+        # generate diff, if exists
+        should_error, diff = False, None
+        daff_diff = self._get_daff_diff(expected, actual)
+        if daff_diff.hasDifference():
+            should_error = True
+            rendered = self._render_daff_diff(daff_diff)
+            rendered = f"\n\n{red('expected')} differs from {green('actual')}:\n\n{rendered}\n"
+            diff = UnitTestDiff(
+                actual=json_rows_from_table(actual),
+                expected=json_rows_from_table(expected),
+                rendered=rendered,
+            )
+
+        return UnitTestResultData(
+            diff=diff,
+            should_error=should_error,
+            adapter_response=adapter_response,
+        )
+
+    def execute(self, test: Union[TestNode, UnitTestDefinition], manifest: Manifest):
+        if isinstance(test, UnitTestDefinition):
+            unit_test_result = self.execute_unit_test(test, manifest)
+            return self.build_unit_test_run_result(test, unit_test_result)
+        else:
+            # Note: manifest here is a normal manifest
+            test_result = self.execute_data_test(test, manifest)
+            return self.build_test_run_result(test, test_result)
+
+    def build_test_run_result(self, test: TestNode, result: TestResultData) -> RunResult:
         severity = test.config.severity.upper()
         thread_id = threading.current_thread().name
         num_errors = pluralize(result.failures, "result")

@@ -167,6 +264,31 @@ class TestRunner(CompileRunner):
         else:
             status = TestStatus.Pass

+        run_result = RunResult(
+            node=test,
+            status=status,
+            timing=[],
+            thread_id=thread_id,
+            execution_time=0,
+            message=message,
+            adapter_response=result.adapter_response,
+            failures=failures,
+        )
+        return run_result
+
+    def build_unit_test_run_result(
+        self, test: UnitTestDefinition, result: UnitTestResultData
+    ) -> RunResult:
+        thread_id = threading.current_thread().name
+
+        status = TestStatus.Pass
+        message = None
+        failures = 0
+        if result.should_error:
+            status = TestStatus.Fail
+            message = result.diff.rendered if result.diff else None
+            failures = 1
+
         return RunResult(
             node=test,
             status=status,

@@ -181,6 +303,41 @@ class TestRunner(CompileRunner):
     def after_execute(self, result):
         self.print_result_line(result)

+    def _get_unit_test_agate_table(self, result_table, actual_or_expected: str):
+        unit_test_table = result_table.where(
+            lambda row: row["actual_or_expected"] == actual_or_expected
+        )
+        columns = list(unit_test_table.columns.keys())
+        columns.remove("actual_or_expected")
+        return unit_test_table.select(columns)
+
+    def _get_daff_diff(
+        self, expected: agate.Table, actual: agate.Table, ordered: bool = False
+    ) -> daff.TableDiff:
+        expected_daff_table = daff.PythonTableView(list_rows_from_table(expected))
+        actual_daff_table = daff.PythonTableView(list_rows_from_table(actual))
+
+        alignment = daff.Coopy.compareTables(expected_daff_table, actual_daff_table).align()
+        result = daff.PythonTableView([])
+
+        flags = daff.CompareFlags()
+        flags.ordered = ordered
+
+        diff = daff.TableDiff(alignment, flags)
+        diff.hilite(result)
+        return diff
+
+    def _render_daff_diff(self, daff_diff: daff.TableDiff) -> str:
+        result = daff.PythonTableView([])
+        daff_diff.hilite(result)
+        rendered = daff.TerminalDiffRender().render(result)
+        # strip colors if necessary
+        if not self.config.args.use_colors:
+            rendered = self._ANSI_ESCAPE.sub("", rendered)
+        return rendered
+

 class TestSelector(ResourceTypeSelector):
     def __init__(self, graph, manifest, previous_state) -> None:

@@ -188,7 +345,7 @@ class TestSelector(ResourceTypeSelector):
             graph=graph,
             manifest=manifest,
             previous_state=previous_state,
-            resource_types=[NodeType.Test],
+            resource_types=[NodeType.Test, NodeType.Unit],
         )
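The _get_daff_diff and _render_daff_diff helpers above lean entirely on the daff API. As a standalone sketch of the same calls (the two sample tables here are invented for illustration):

# Minimal standalone sketch of the daff calls used above; assumes
# `pip install daff`. The sample tables are invented for illustration.
import daff

expected = daff.PythonTableView([["id", "color"], [1, "red"], [2, "blue"]])
actual = daff.PythonTableView([["id", "color"], [1, "red"], [2, "green"]])

alignment = daff.Coopy.compareTables(expected, actual).align()
flags = daff.CompareFlags()
flags.ordered = False  # ignore row order, as the runner does by default

diff = daff.TableDiff(alignment, flags)
result = daff.PythonTableView([])
diff.hilite(result)

if diff.hasDifference():
    print(daff.TerminalDiffRender().render(result))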

View File

@@ -81,6 +81,7 @@ setup(
         "protobuf>=4.0.0",
         "pytz>=2015.7",
         "pyyaml>=6.0",
+        "daff>=1.3.46",
         "typing-extensions>=4.4",
         # ----
         # Match snowflake-connector-python, to ensure compatibility in dbt-snowflake

View File

@@ -30,7 +30,3 @@ services:
     working_dir: /usr/app
     depends_on:
       - database
-
-networks:
-  default:
-    name: dbt-net

View File

@@ -3401,6 +3401,510 @@
"config" "config"
] ]
}, },
"UnitTestNodeConfig": {
"type": "object",
"title": "UnitTestNodeConfig",
"properties": {
"_extra": {
"type": "object",
"propertyNames": {
"type": "string"
}
},
"enabled": {
"type": "boolean",
"default": true
},
"alias": {
"anyOf": [
{
"type": "string"
},
{
"type": "null"
}
],
"default": null
},
"schema": {
"anyOf": [
{
"type": "string"
},
{
"type": "null"
}
],
"default": null
},
"database": {
"anyOf": [
{
"type": "string"
},
{
"type": "null"
}
],
"default": null
},
"tags": {
"anyOf": [
{
"type": "array",
"items": {
"type": "string"
}
},
{
"type": "string"
}
]
},
"meta": {
"type": "object",
"propertyNames": {
"type": "string"
}
},
"group": {
"anyOf": [
{
"type": "string"
},
{
"type": "null"
}
],
"default": null
},
"materialized": {
"type": "string",
"default": "view"
},
"incremental_strategy": {
"anyOf": [
{
"type": "string"
},
{
"type": "null"
}
],
"default": null
},
"persist_docs": {
"type": "object",
"propertyNames": {
"type": "string"
}
},
"post-hook": {
"type": "array",
"items": {
"$ref": "#/$defs/Hook"
}
},
"pre-hook": {
"type": "array",
"items": {
"$ref": "#/$defs/Hook"
}
},
"quoting": {
"type": "object",
"propertyNames": {
"type": "string"
}
},
"column_types": {
"type": "object",
"propertyNames": {
"type": "string"
}
},
"full_refresh": {
"anyOf": [
{
"type": "boolean"
},
{
"type": "null"
}
],
"default": null
},
"unique_key": {
"anyOf": [
{
"type": "string"
},
{
"type": "array",
"items": {
"type": "string"
}
},
{
"type": "null"
}
],
"default": null
},
"on_schema_change": {
"anyOf": [
{
"type": "string"
},
{
"type": "null"
}
],
"default": "ignore"
},
"on_configuration_change": {
"enum": [
"apply",
"continue",
"fail"
]
},
"grants": {
"type": "object",
"propertyNames": {
"type": "string"
}
},
"packages": {
"type": "array",
"items": {
"type": "string"
}
},
"docs": {
"$ref": "#/$defs/Docs"
},
"contract": {
"$ref": "#/$defs/ContractConfig"
},
"expected_rows": {
"type": "array",
"items": {
"type": "object",
"propertyNames": {
"type": "string"
}
}
}
},
"additionalProperties": true
},
"UnitTestOverrides": {
"type": "object",
"title": "UnitTestOverrides",
"properties": {
"macros": {
"type": "object",
"propertyNames": {
"type": "string"
}
},
"vars": {
"type": "object",
"propertyNames": {
"type": "string"
}
},
"env_vars": {
"type": "object",
"propertyNames": {
"type": "string"
}
}
},
"additionalProperties": false
},
"UnitTestNode": {
"type": "object",
"title": "UnitTestNode",
"properties": {
"database": {
"anyOf": [
{
"type": "string"
},
{
"type": "null"
}
]
},
"schema": {
"type": "string"
},
"name": {
"type": "string"
},
"resource_type": {
"const": "unit_test"
},
"package_name": {
"type": "string"
},
"path": {
"type": "string"
},
"original_file_path": {
"type": "string"
},
"unique_id": {
"type": "string"
},
"fqn": {
"type": "array",
"items": {
"type": "string"
}
},
"alias": {
"type": "string"
},
"checksum": {
"$ref": "#/$defs/FileHash"
},
"config": {
"$ref": "#/$defs/UnitTestNodeConfig"
},
"_event_status": {
"type": "object",
"propertyNames": {
"type": "string"
}
},
"tags": {
"type": "array",
"items": {
"type": "string"
}
},
"description": {
"type": "string",
"default": ""
},
"columns": {
"type": "object",
"additionalProperties": {
"$ref": "#/$defs/ColumnInfo"
},
"propertyNames": {
"type": "string"
}
},
"meta": {
"type": "object",
"propertyNames": {
"type": "string"
}
},
"group": {
"anyOf": [
{
"type": "string"
},
{
"type": "null"
}
],
"default": null
},
"docs": {
"$ref": "#/$defs/Docs"
},
"patch_path": {
"anyOf": [
{
"type": "string"
},
{
"type": "null"
}
],
"default": null
},
"build_path": {
"anyOf": [
{
"type": "string"
},
{
"type": "null"
}
],
"default": null
},
"deferred": {
"type": "boolean",
"default": false
},
"unrendered_config": {
"type": "object",
"propertyNames": {
"type": "string"
}
},
"created_at": {
"type": "number"
},
"config_call_dict": {
"type": "object",
"propertyNames": {
"type": "string"
}
},
"relation_name": {
"anyOf": [
{
"type": "string"
},
{
"type": "null"
}
],
"default": null
},
"raw_code": {
"type": "string",
"default": ""
},
"language": {
"type": "string",
"default": "sql"
},
"refs": {
"type": "array",
"items": {
"$ref": "#/$defs/RefArgs"
}
},
"sources": {
"type": "array",
"items": {
"type": "array",
"items": {
"type": "string"
}
}
},
"metrics": {
"type": "array",
"items": {
"type": "array",
"items": {
"type": "string"
}
}
},
"depends_on": {
"$ref": "#/$defs/DependsOn"
},
"compiled_path": {
"anyOf": [
{
"type": "string"
},
{
"type": "null"
}
],
"default": null
},
"compiled": {
"type": "boolean",
"default": false
},
"compiled_code": {
"anyOf": [
{
"type": "string"
},
{
"type": "null"
}
],
"default": null
},
"extra_ctes_injected": {
"type": "boolean",
"default": false
},
"extra_ctes": {
"type": "array",
"items": {
"$ref": "#/$defs/InjectedCTE"
}
},
"_pre_injected_sql": {
"anyOf": [
{
"type": "string"
},
{
"type": "null"
}
],
"default": null
},
"contract": {
"$ref": "#/$defs/Contract"
},
"tested_node_unique_id": {
"anyOf": [
{
"type": "string"
},
{
"type": "null"
}
],
"default": null
},
"this_input_node_unique_id": {
"anyOf": [
{
"type": "string"
},
{
"type": "null"
}
],
"default": null
},
"overrides": {
"anyOf": [
{
"$ref": "#/$defs/UnitTestOverrides"
},
{
"type": "null"
}
],
"default": null
}
},
"additionalProperties": false,
"required": [
"database",
"schema",
"name",
"resource_type",
"package_name",
"path",
"original_file_path",
"unique_id",
"fqn",
"alias",
"checksum"
]
},
"SeedConfig": { "SeedConfig": {
"type": "object", "type": "object",
"title": "SeedConfig", "title": "SeedConfig",
@@ -5223,6 +5727,31 @@
"propertyNames": { "propertyNames": {
"type": "string" "type": "string"
} }
},
"export_as": {
"anyOf": [
{
"enum": [
"table",
"view"
]
},
{
"type": "null"
}
],
"default": null
},
"schema": {
"anyOf": [
{
"type": "string"
},
{
"type": "null"
}
],
"default": null
}
},
"additionalProperties": true
@@ -5251,7 +5780,8 @@
"metric", "metric",
"group", "group",
"saved_query", "saved_query",
"semantic_model" "semantic_model",
"unit_test"
] ]
}, },
"package_name": { "package_name": {
@@ -5281,6 +5811,12 @@
"$ref": "#/$defs/Export" "$ref": "#/$defs/Export"
} }
}, },
"_event_status": {
"type": "object",
"propertyNames": {
"type": "string"
}
},
"description": { "description": {
"anyOf": [ "anyOf": [
{ {
@@ -5822,7 +6358,8 @@
"metric", "metric",
"group", "group",
"saved_query", "saved_query",
"semantic_model" "semantic_model",
"unit_test"
] ]
}, },
"package_name": { "package_name": {
@@ -5975,6 +6512,256 @@
"node_relation" "node_relation"
] ]
}, },
"UnitTestInputFixture": {
"type": "object",
"title": "UnitTestInputFixture",
"properties": {
"input": {
"type": "string"
},
"rows": {
"anyOf": [
{
"type": "string"
},
{
"type": "array",
"items": {
"type": "object",
"propertyNames": {
"type": "string"
}
}
},
{
"type": "null"
}
],
"default": null
},
"format": {
"enum": [
"csv",
"dict"
],
"default": "dict"
},
"fixture": {
"anyOf": [
{
"type": "string"
},
{
"type": "null"
}
],
"default": null
}
},
"additionalProperties": false,
"required": [
"input"
]
},
"UnitTestOutputFixture": {
"type": "object",
"title": "UnitTestOutputFixture",
"properties": {
"rows": {
"anyOf": [
{
"type": "string"
},
{
"type": "array",
"items": {
"type": "object",
"propertyNames": {
"type": "string"
}
}
},
{
"type": "null"
}
],
"default": null
},
"format": {
"enum": [
"csv",
"dict"
],
"default": "dict"
},
"fixture": {
"anyOf": [
{
"type": "string"
},
{
"type": "null"
}
],
"default": null
}
},
"additionalProperties": false
},
"UnitTestConfig": {
"type": "object",
"title": "UnitTestConfig",
"properties": {
"_extra": {
"type": "object",
"propertyNames": {
"type": "string"
}
},
"tags": {
"anyOf": [
{
"type": "string"
},
{
"type": "array",
"items": {
"type": "string"
}
}
]
},
"meta": {
"type": "object",
"propertyNames": {
"type": "string"
}
}
},
"additionalProperties": true
},
"UnitTestDefinition": {
"type": "object",
"title": "UnitTestDefinition",
"properties": {
"model": {
"type": "string"
},
"given": {
"type": "array",
"items": {
"$ref": "#/$defs/UnitTestInputFixture"
}
},
"expect": {
"$ref": "#/$defs/UnitTestOutputFixture"
},
"name": {
"type": "string"
},
"resource_type": {
"enum": [
"model",
"analysis",
"test",
"snapshot",
"operation",
"seed",
"rpc",
"sql_operation",
"doc",
"source",
"macro",
"exposure",
"metric",
"group",
"saved_query",
"semantic_model",
"unit_test"
]
},
"package_name": {
"type": "string"
},
"path": {
"type": "string"
},
"original_file_path": {
"type": "string"
},
"unique_id": {
"type": "string"
},
"fqn": {
"type": "array",
"items": {
"type": "string"
}
},
"_event_status": {
"type": "object",
"propertyNames": {
"type": "string"
}
},
"description": {
"type": "string",
"default": ""
},
"overrides": {
"anyOf": [
{
"$ref": "#/$defs/UnitTestOverrides"
},
{
"type": "null"
}
],
"default": null
},
"depends_on": {
"$ref": "#/$defs/DependsOn"
},
"config": {
"$ref": "#/$defs/UnitTestConfig"
},
"checksum": {
"anyOf": [
{
"type": "string"
},
{
"type": "null"
}
],
"default": null
},
"schema": {
"anyOf": [
{
"type": "string"
},
{
"type": "null"
}
],
"default": null
}
},
"additionalProperties": false,
"required": [
"model",
"given",
"expect",
"name",
"resource_type",
"package_name",
"path",
"original_file_path",
"unique_id",
"fqn"
]
},
"WritableManifest": { "WritableManifest": {
"type": "object", "type": "object",
"title": "WritableManifest", "title": "WritableManifest",
@@ -6012,6 +6799,9 @@
{
"$ref": "#/$defs/SnapshotNode"
},
{
"$ref": "#/$defs/UnitTestNode"
},
{
"$ref": "#/$defs/SeedNode"
}
@@ -6121,6 +6911,9 @@
{
"$ref": "#/$defs/SnapshotNode"
},
{
"$ref": "#/$defs/UnitTestNode"
},
{
"$ref": "#/$defs/SeedNode"
},
@@ -6138,6 +6931,9 @@
},
{
"$ref": "#/$defs/SemanticModel"
},
{
"$ref": "#/$defs/UnitTestDefinition"
} }
]
}

@@ -6230,6 +7026,16 @@
"propertyNames": {
"type": "string"
}
},
"unit_tests": {
"type": "object",
"description": "The unit tests defined in the project",
"additionalProperties": {
"$ref": "#/$defs/UnitTestDefinition"
},
"propertyNames": {
"type": "string"
}
} }
}, },
"additionalProperties": false, "additionalProperties": false,
@@ -6248,7 +7054,8 @@
"child_map", "child_map",
"group_map", "group_map",
"saved_queries", "saved_queries",
"semantic_models" "semantic_models",
"unit_tests"
] ]
} }
}, },
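For orientation, here is a hypothetical object that satisfies the UnitTestDefinition schema above, written as a Python dict and restricted to the required keys (all values are invented for illustration):

# Hypothetical instance of the UnitTestDefinition schema; required keys only,
# all values invented for illustration.
unit_test_definition = {
    "model": "my_model",
    "given": [
        {"input": "ref('input_model')", "format": "dict", "rows": [{"id": 1}]},
    ],
    "expect": {"format": "dict", "rows": [{"id": 1}]},
    "name": "test_my_model",
    "resource_type": "unit_test",
    "package_name": "my_project",
    "path": "models/schema.yml",
    "original_file_path": "models/schema.yml",
    "unique_id": "unit_test.my_project.my_model.test_my_model",
    "fqn": ["my_project", "my_model", "test_my_model"],
}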

File diff suppressed because it is too large

View File

@@ -31,22 +31,22 @@ MODELS__SCHEMA_YML = """
 version: 2
 models:
   - name: foo_alias
-    tests:
+    data_tests:
       - expect_value:
           field: tablename
           value: foo
   - name: ref_foo_alias
-    tests:
+    data_tests:
       - expect_value:
           field: tablename
           value: ref_foo_alias
   - name: alias_in_project
-    tests:
+    data_tests:
       - expect_value:
           field: tablename
           value: project_alias
   - name: alias_in_project_with_override
-    tests:
+    data_tests:
       - expect_value:
           field: tablename
           value: override_alias

@@ -128,12 +128,12 @@ MODELS_DUPE_CUSTOM_DATABASE__SCHEMA_YML = """
 version: 2
 models:
   - name: model_a
-    tests:
+    data_tests:
       - expect_value:
           field: tablename
           value: duped_alias
   - name: model_b
-    tests:
+    data_tests:
       - expect_value:
           field: tablename
           value: duped_alias

@@ -161,17 +161,17 @@ MODELS_DUPE_CUSTOM_SCHEMA__SCHEMA_YML = """
 version: 2
 models:
   - name: model_a
-    tests:
+    data_tests:
       - expect_value:
           field: tablename
           value: duped_alias
   - name: model_b
-    tests:
+    data_tests:
       - expect_value:
           field: tablename
           value: duped_alias
   - name: model_c
-    tests:
+    data_tests:
       - expect_value:
           field: tablename
           value: duped_alias

View File

@@ -59,7 +59,7 @@ models:
   - name: base
     columns:
       - name: id
-        tests:
+        data_tests:
           - not_null
 """

@@ -69,7 +69,7 @@ models:
   - name: view_model
     columns:
       - name: id
-        tests:
+        data_tests:
           - not_null
 """

@@ -79,7 +79,7 @@ models:
   - name: table_model
     columns:
       - name: id
-        tests:
+        data_tests:
           - not_null
 """

View File

@@ -22,7 +22,7 @@ models:
     columns:
       - name: id
         description: The user ID number
-        tests:
+        data_tests:
           - unique
           - not_null
       - name: first_name

@@ -33,7 +33,7 @@ models:
         description: The user's IP address
       - name: updated_at
         description: The last time this user's email was updated
-        tests:
+        data_tests:
           - test.nothing
   - name: second_model

View File

@@ -99,7 +99,7 @@ schema_yml = """
 version: 2
 models:
   - name: model
-    tests:
+    data_tests:
       - is_type:
           column_map:
             smallint_col: ['integer', 'number']

View File

@@ -279,7 +279,7 @@ models:
             expression: (id > 0)
           - type: check
             expression: id >= 1
-        tests:
+        data_tests:
           - unique
       - name: color
         data_type: text

@@ -298,7 +298,7 @@ models:
           - type: primary_key
           - type: check
             expression: (id > 0)
-        tests:
+        data_tests:
           - unique
       - name: color
         data_type: text

@@ -317,7 +317,7 @@ models:
           - type: primary_key
           - type: check
             expression: (id > 0)
-        tests:
+        data_tests:
           - unique
       - name: color
         data_type: text

@@ -336,7 +336,7 @@ models:
           - type: primary_key
           - type: check
             expression: (id > 0)
-        tests:
+        data_tests:
           - unique
       - name: color
         data_type: text

@@ -365,7 +365,7 @@ models:
           - type: foreign_key
             expression: {schema}.foreign_key_model (id)
           - type: unique
-        tests:
+        data_tests:
           - unique
       - name: color
         data_type: text

@@ -384,7 +384,7 @@ models:
           - type: primary_key
           - type: check
             expression: (id > 0)
-        tests:
+        data_tests:
           - unique
       - name: color
         data_type: text

@@ -403,7 +403,7 @@ models:
           - type: primary_key
           - type: check
             expression: (id > 0)
-        tests:
+        data_tests:
           - unique
       - name: color
         data_type: text

@@ -422,7 +422,7 @@ models:
           - type: primary_key
           - type: check
             expression: (id > 0)
-        tests:
+        data_tests:
           - unique
       - name: color
         data_type: text

@@ -466,7 +466,7 @@ models:
         description: hello
         constraints:
           - type: not_null
-        tests:
+        data_tests:
           - unique
       - name: color
         data_type: text

@@ -517,7 +517,7 @@ models:
         description: hello
         constraints:
           - type: not_null
-        tests:
+        data_tests:
           - unique
       - name: from # reserved word
         quote: true

View File

@@ -40,7 +40,7 @@ models:
   - name: view_model
     columns:
       - name: id
-        tests:
+        data_tests:
           - unique:
               severity: error
           - not_null

View File

@@ -341,7 +341,7 @@ seeds:
   - name: example_seed
     columns:
       - name: new_col
-        tests:
+        data_tests:
           - not_null
 """

@@ -351,7 +351,7 @@ snapshots:
   - name: example_snapshot
     columns:
       - name: new_col
-        tests:
+        data_tests:
           - not_null
 """

View File

@@ -4,7 +4,7 @@ models:
   - name: disabled
     columns:
       - name: id
-        tests:
+        data_tests:
           - unique
 """

View File

@@ -54,46 +54,46 @@ seeds:
   - name: seed_enabled
     columns:
       - name: birthday
-        tests:
+        data_tests:
           - column_type:
               type: date
       - name: seed_id
-        tests:
+        data_tests:
           - column_type:
               type: text
   - name: seed_tricky
     columns:
       - name: seed_id
-        tests:
+        data_tests:
          - column_type:
               type: integer
       - name: seed_id_str
-        tests:
+        data_tests:
           - column_type:
               type: text
       - name: a_bool
-        tests:
+        data_tests:
           - column_type:
               type: boolean
       - name: looks_like_a_bool
-        tests:
+        data_tests:
           - column_type:
               type: text
       - name: a_date
-        tests:
+        data_tests:
           - column_type:
               type: timestamp without time zone
       - name: looks_like_a_date
-        tests:
+        data_tests:
           - column_type:
               type: text
       - name: relative
-        tests:
+        data_tests:
           - column_type:
               type: text
       - name: weekday
-        tests:
+        data_tests:
           - column_type:
               type: text
 """

View File

@@ -132,7 +132,7 @@ models:
   - name: chipmunks
     columns:
       - name: name
-        tests:
+        data_tests:
           - not_null:
               store_failures_as: view
           - accepted_values:

@@ -143,7 +143,7 @@ models:
               - simon
               - theodore
       - name: shirt
-        tests:
+        data_tests:
           - not_null:
               store_failures: true
               store_failures_as: view

View File

@@ -167,7 +167,7 @@ class StoreTestFailuresAsProjectLevelOff(StoreTestFailuresAsBase):
     @pytest.fixture(scope="class")
     def project_config_update(self):
-        return {"tests": {"store_failures": False}}
+        return {"data_tests": {"store_failures": False}}

     def test_tests_run_successfully_and_are_stored_as_expected(self, project):
         expected_results = {

@@ -204,7 +204,7 @@ class StoreTestFailuresAsProjectLevelView(StoreTestFailuresAsBase):
     @pytest.fixture(scope="class")
     def project_config_update(self):
-        return {"tests": {"store_failures_as": "view"}}
+        return {"data_tests": {"store_failures_as": "view"}}

     def test_tests_run_successfully_and_are_stored_as_expected(self, project):
         expected_results = {

@@ -242,7 +242,7 @@ class StoreTestFailuresAsProjectLevelEphemeral(StoreTestFailuresAsBase):
     @pytest.fixture(scope="class")
     def project_config_update(self):
-        return {"tests": {"store_failures_as": "ephemeral", "store_failures": True}}
+        return {"data_tests": {"store_failures_as": "ephemeral", "store_failures": True}}

     def test_tests_run_successfully_and_are_stored_as_expected(self, project):
         expected_results = {

View File

@@ -53,19 +53,19 @@ models:
   - name: fine_model
     columns:
       - name: id
-        tests:
+        data_tests:
           - unique
           - not_null
   - name: problematic_model
     columns:
       - name: id
-        tests:
+        data_tests:
           - unique:
               store_failures: true
           - not_null
       - name: first_name
-        tests:
+        data_tests:
           # test truncation of really long test name
           - accepted_values:
               values:

@@ -83,7 +83,7 @@ models:
   - name: fine_model_but_with_a_no_good_very_long_name
     columns:
       - name: quite_long_column_name
-        tests:
+        data_tests:
           # test truncation of really long test name with builtin
           - unique
 """

View File

@@ -68,7 +68,7 @@ class StoreTestFailuresBase:
                 "quote_columns": False,
                 "test": self.column_type_overrides(),
             },
-            "tests": {"+schema": TEST_AUDIT_SCHEMA_SUFFIX},
+            "data_tests": {"+schema": TEST_AUDIT_SCHEMA_SUFFIX},
         }

     def column_type_overrides(self):
def column_type_overrides(self): def column_type_overrides(self):

View File

@@ -53,7 +53,7 @@ models__test_any_value_yml = """
 version: 2
 models:
   - name: test_any_value
-    tests:
+    data_tests:
       - assert_equal:
           actual: actual
           expected: expected

View File

@@ -55,7 +55,7 @@ models__test_bool_or_yml = """
 version: 2
 models:
   - name: test_bool_or
-    tests:
+    data_tests:
       - assert_equal:
           actual: actual
           expected: expected

View File

@@ -22,7 +22,7 @@ models__test_cast_bool_to_text_yml = """
 version: 2
 models:
   - name: test_cast_bool_to_text
-    tests:
+    data_tests:
       - assert_equal:
           actual: actual
           expected: expected

View File

@@ -38,7 +38,7 @@ models__test_concat_yml = """
 version: 2
 models:
   - name: test_concat
-    tests:
+    data_tests:
       - assert_equal:
           actual: actual
           expected: expected

View File

@@ -85,7 +85,7 @@ models__test_date_spine_yml = """
 version: 2
 models:
   - name: test_date_spine
-    tests:
+    data_tests:
       - assert_equal:
           actual: date_day
           expected: expected

View File

@@ -33,7 +33,7 @@ models__test_date_trunc_yml = """
 version: 2
 models:
   - name: test_date_trunc
-    tests:
+    data_tests:
       - assert_equal:
           actual: actual
           expected: expected

View File

@@ -33,7 +33,7 @@ models__test_dateadd_yml = """
 version: 2
 models:
   - name: test_dateadd
-    tests:
+    data_tests:
       - assert_equal:
           actual: actual
           expected: expected

View File

@@ -58,7 +58,7 @@ models__test_datediff_yml = """
 version: 2
 models:
   - name: test_datediff
-    tests:
+    data_tests:
       - assert_equal:
           actual: actual
           expected: expected

View File

@@ -39,7 +39,7 @@ models__test_escape_single_quotes_yml = """
 version: 2
 models:
   - name: test_escape_single_quotes
-    tests:
+    data_tests:
       - assert_equal:
           actual: actual
           expected: expected

View File

@@ -38,7 +38,7 @@ models__test_generate_series_yml = """
 version: 2
 models:
   - name: test_generate_series
-    tests:
+    data_tests:
       - assert_equal:
           actual: generated_number
           expected: expected

View File

@@ -13,7 +13,7 @@ models__test_get_intervals_between_yml = """
 version: 2
 models:
   - name: test_get_intervals_between
-    tests:
+    data_tests:
       - assert_equal:
           actual: intervals
           expected: expected

View File

@@ -32,7 +32,7 @@ models__test_get_powers_of_two_yml = """
 version: 2
 models:
   - name: test_powers_of_two
-    tests:
+    data_tests:
       - assert_equal:
           actual: actual
           expected: expected

View File

@@ -37,7 +37,7 @@ models__test_hash_yml = """
 version: 2
 models:
   - name: test_hash
-    tests:
+    data_tests:
       - assert_equal:
           actual: actual
           expected: expected

View File

@@ -32,7 +32,7 @@ models__test_last_day_yml = """
 version: 2
 models:
   - name: test_last_day
-    tests:
+    data_tests:
       - assert_equal:
           actual: actual
           expected: expected

View File

@@ -28,7 +28,7 @@ models__test_length_yml = """
 version: 2
 models:
   - name: test_length
-    tests:
+    data_tests:
       - assert_equal:
           actual: actual
           expected: expected

Some files were not shown because too many files have changed in this diff.