Compare commits

...

1 Commits

Author SHA1 Message Date
MichelleArk
64a5b19c95 first pass - not working yet 2025-11-26 15:34:52 -05:00
4 changed files with 79 additions and 16 deletions

View File

@@ -335,6 +335,7 @@ class ConfiguredParser(
patch_config_dict=None, patch_config_dict=None,
patch_file_id=None, patch_file_id=None,
validate_config_call_dict: bool = False, validate_config_call_dict: bool = False,
patch_file_index: Optional[str] = None,
) -> None: ) -> None:
"""Given the ContextConfig used for parsing and the parsed node, """Given the ContextConfig used for parsing and the parsed node,
generate and set the true values to use, overriding the temporary parse generate and set the true values to use, overriding the temporary parse
@@ -413,8 +414,20 @@ class ConfiguredParser(
if patch_file and isinstance(patch_file, SchemaSourceFile): if patch_file and isinstance(patch_file, SchemaSourceFile):
schema_key = resource_types_to_schema_file_keys.get(parsed_node.resource_type) schema_key = resource_types_to_schema_file_keys.get(parsed_node.resource_type)
if schema_key: if schema_key:
lookup_name = parsed_node.name
lookup_version = getattr(parsed_node, "version", None)
# Test lookup needs to consider attached node and indexing
if (
parsed_node.resource_type == NodeType.Test
and hasattr(parsed_node, "attached_node") and parsed_node.attached_node
):
if attached_node := self.manifest.nodes.get(parsed_node.attached_node):
lookup_name = f"{attached_node.name}_{patch_file_index}"
lookup_version = getattr(attached_node, "version", None)
if unrendered_patch_config := patch_file.get_unrendered_config( if unrendered_patch_config := patch_file.get_unrendered_config(
schema_key, parsed_node.name, getattr(parsed_node, "version", None) schema_key, lookup_name, lookup_version
): ):
patch_config_dict = deep_merge( patch_config_dict = deep_merge(
patch_config_dict, unrendered_patch_config patch_config_dict, unrendered_patch_config

View File

@@ -35,6 +35,7 @@ schema_file_keys_to_resource_types = {
"semantic_models": NodeType.SemanticModel, "semantic_models": NodeType.SemanticModel,
"saved_queries": NodeType.SavedQuery, "saved_queries": NodeType.SavedQuery,
"functions": NodeType.Function, "functions": NodeType.Function,
"data_tests": NodeType.Test,
} }
resource_types_to_schema_file_keys = { resource_types_to_schema_file_keys = {
@@ -183,6 +184,7 @@ class GenericTestBlock(TestBlock[Testable], Generic[Testable]):
column_name: Optional[str] column_name: Optional[str]
tags: List[str] tags: List[str]
version: Optional[NodeVersion] version: Optional[NodeVersion]
test_index: int
@classmethod @classmethod
def from_test_block( def from_test_block(
@@ -192,6 +194,7 @@ class GenericTestBlock(TestBlock[Testable], Generic[Testable]):
column_name: Optional[str], column_name: Optional[str],
tags: List[str], tags: List[str],
version: Optional[NodeVersion], version: Optional[NodeVersion],
test_index: int,
) -> "GenericTestBlock": ) -> "GenericTestBlock":
return cls( return cls(
file=src.file, file=src.file,
@@ -201,6 +204,7 @@ class GenericTestBlock(TestBlock[Testable], Generic[Testable]):
column_name=column_name, column_name=column_name,
tags=tags, tags=tags,
version=version, version=version,
test_index=test_index,
) )

View File

@@ -81,8 +81,8 @@ class SchemaGenericTestParser(SimpleParser):
if not column.data_tests: if not column.data_tests:
return return
for data_test in column.data_tests: for data_test_idx, data_test in enumerate(column.data_tests):
self.parse_test(block, data_test, column, version) self.parse_test(block, data_test, column, version, data_test_idx)
def create_test_node( def create_test_node(
self, self,
@@ -161,6 +161,7 @@ class SchemaGenericTestParser(SimpleParser):
column_name: Optional[str], column_name: Optional[str],
schema_file_id: str, schema_file_id: str,
version: Optional[NodeVersion], version: Optional[NodeVersion],
test_index: Optional[int] = None,
) -> GenericTestNode: ) -> GenericTestNode:
try: try:
builder = TestBuilder( builder = TestBuilder(
@@ -233,7 +234,7 @@ class SchemaGenericTestParser(SimpleParser):
file_key_name=file_key_name, file_key_name=file_key_name,
description=builder.description, description=builder.description,
) )
self.render_test_update(node, config, builder, schema_file_id) self.render_test_update(node, config, builder, schema_file_id, test_index)
return node return node
@@ -278,18 +279,33 @@ class SchemaGenericTestParser(SimpleParser):
# In the future we will look at generalizing this # In the future we will look at generalizing this
# more to handle additional macros or to use static # more to handle additional macros or to use static
# parsing to avoid jinja overhead. # parsing to avoid jinja overhead.
def render_test_update(self, node, config, builder, schema_file_id): def render_test_update(self, node, config, builder, schema_file_id, test_index: int):
macro_unique_id = self.macro_resolver.get_macro_id( macro_unique_id = self.macro_resolver.get_macro_id(
node.package_name, "test_" + builder.name node.package_name, "test_" + builder.name
) )
# Add the depends_on here so we can limit the macros added # Add the depends_on here so we can limit the macros added
# to the context in rendering processing # to the context in rendering processing
node.depends_on.add_macro(macro_unique_id) node.depends_on.add_macro(macro_unique_id)
# Set attached_node for generic test nodes, if available.
# Generic test node inherits attached node's group config value.
attached_node = self._lookup_attached_node(builder.target, builder.version)
if attached_node:
node.attached_node = attached_node.unique_id
node.group, node.group = attached_node.group, attached_node.group
# Index for lookups on patch file, used when setting unrendered_config for tests
patch_file_index = (
f"{node.column_name}_{test_index}" if node.column_name else str(test_index)
)
if macro_unique_id in ["macro.dbt.test_not_null", "macro.dbt.test_unique"]: if macro_unique_id in ["macro.dbt.test_not_null", "macro.dbt.test_unique"]:
config_call_dict = builder.config config_call_dict = builder.config
config._config_call_dict = config_call_dict config._config_call_dict = config_call_dict
# This sets the config from dbt_project # This sets the config from dbt_project
self.update_parsed_node_config(node, config) self.update_parsed_node_config(
node, config, patch_file_id=schema_file_id, patch_file_index=patch_file_index
)
# source node tests are processed at patch_source time # source node tests are processed at patch_source time
if isinstance(builder.target, UnpatchedSourceDefinition): if isinstance(builder.target, UnpatchedSourceDefinition):
sources = [builder.target.fqn[-2], builder.target.fqn[-1]] sources = [builder.target.fqn[-2], builder.target.fqn[-1]]
@@ -312,19 +328,14 @@ class SchemaGenericTestParser(SimpleParser):
add_rendered_test_kwargs(context, node, capture_macros=True) add_rendered_test_kwargs(context, node, capture_macros=True)
# the parsed node is not rendered in the native context. # the parsed node is not rendered in the native context.
get_rendered(node.raw_code, context, node, capture_macros=True) get_rendered(node.raw_code, context, node, capture_macros=True)
self.update_parsed_node_config(node, config) self.update_parsed_node_config(
node, config, patch_file_id=schema_file_id, patch_file_index=patch_file_index
)
# env_vars should have been updated in the context env_var method # env_vars should have been updated in the context env_var method
except ValidationError as exc: except ValidationError as exc:
# we got a ValidationError - probably bad types in config() # we got a ValidationError - probably bad types in config()
raise SchemaConfigError(exc, node=node) from exc raise SchemaConfigError(exc, node=node) from exc
# Set attached_node for generic test nodes, if available.
# Generic test node inherits attached node's group config value.
attached_node = self._lookup_attached_node(builder.target, builder.version)
if attached_node:
node.attached_node = attached_node.unique_id
node.group, node.group = attached_node.group, attached_node.group
def parse_node(self, block: GenericTestBlock) -> GenericTestNode: def parse_node(self, block: GenericTestBlock) -> GenericTestNode:
"""In schema parsing, we rewrite most of the part of parse_node that """In schema parsing, we rewrite most of the part of parse_node that
builds the initial node to be parsed, but rendering is basically the builds the initial node to be parsed, but rendering is basically the
@@ -337,6 +348,7 @@ class SchemaGenericTestParser(SimpleParser):
column_name=block.column_name, column_name=block.column_name,
schema_file_id=block.file.file_id, schema_file_id=block.file.file_id,
version=block.version, version=block.version,
test_index=block.test_index,
) )
self.add_test_node(block, node) self.add_test_node(block, node)
return node return node
@@ -371,6 +383,7 @@ class SchemaGenericTestParser(SimpleParser):
data_test: TestDef, data_test: TestDef,
column: Optional[UnparsedColumn], column: Optional[UnparsedColumn],
version: Optional[NodeVersion], version: Optional[NodeVersion],
test_index: int,
) -> None: ) -> None:
if isinstance(data_test, str): if isinstance(data_test, str):
data_test = {data_test: {}} data_test = {data_test: {}}
@@ -395,15 +408,17 @@ class SchemaGenericTestParser(SimpleParser):
column_name=column_name, column_name=column_name,
tags=column_tags, tags=column_tags,
version=version, version=version,
test_index=test_index,
) )
self.parse_node(block) self.parse_node(block)
def parse_tests(self, block: TestBlock) -> None: def parse_tests(self, block: TestBlock) -> None:
# TODO: plumb indexing here
for column in block.columns: for column in block.columns:
self.parse_column_tests(block, column, None) self.parse_column_tests(block, column, None)
for data_test in block.data_tests: for data_test_idx, data_test in enumerate(block.data_tests):
self.parse_test(block, data_test, None, None) self.parse_test(block, data_test, None, None, data_test_idx)
def parse_versioned_tests(self, block: VersionedTestBlock) -> None: def parse_versioned_tests(self, block: VersionedTestBlock) -> None:
if not block.target.versions: if not block.target.versions:

View File

@@ -466,6 +466,31 @@ class YamlReader(metaclass=ABCMeta):
if "config" in entry: if "config" in entry:
unrendered_config = entry["config"] unrendered_config = entry["config"]
unrendered_data_test_configs = {}
if "data_tests" in entry:
for data_test_idx, data_test in enumerate(entry["data_tests"]):
if isinstance(data_test, dict) and len(data_test):
data_test_definition = list(data_test.values())[0]
if isinstance(data_test_definition, dict) and data_test_definition.get(
"config"
):
unrendered_data_test_configs[f"{entry['name']}_{data_test_idx}"] = (
data_test_definition["config"]
)
if "columns" in entry:
for column in entry["columns"]:
if isinstance(column, dict) and column.get("data_tests"):
for data_test_idx, data_test in enumerate(column["data_tests"]):
if isinstance(data_test, dict) and len(data_test) == 1:
data_test_definition = list(data_test.values())[0]
if isinstance(
data_test_definition, dict
) and data_test_definition.get("config"):
unrendered_data_test_configs[
f"{entry['name']}_{column['name']}_{data_test_idx}"
] = data_test_definition["config"]
unrendered_version_configs = {} unrendered_version_configs = {}
if "versions" in entry: if "versions" in entry:
for version in entry["versions"]: for version in entry["versions"]:
@@ -486,6 +511,12 @@ class YamlReader(metaclass=ABCMeta):
if unrendered_config: if unrendered_config:
schema_file.add_unrendered_config(unrendered_config, self.key, entry["name"]) schema_file.add_unrendered_config(unrendered_config, self.key, entry["name"])
for test_name, unrendered_data_test_config in unrendered_data_test_configs.items():
print(f"ADD: {test_name}")
schema_file.add_unrendered_config(
unrendered_data_test_config, "data_tests", test_name
)
for version, unrendered_version_config in unrendered_version_configs.items(): for version, unrendered_version_config in unrendered_version_configs.items():
schema_file.add_unrendered_config( schema_file.add_unrendered_config(
unrendered_version_config, self.key, entry["name"], version unrendered_version_config, self.key, entry["name"], version