Compare commits

...

20 Commits

Author SHA1 Message Date
Kshitij Aranke
423ffe7517 Add runtime prefix 2024-09-11 12:47:48 +01:00
Kshitij Aranke
f2215ef511 rm: run_started_at, selected_resources 2024-09-10 17:05:47 +01:00
Kshitij Aranke
fc6c58c98b rm: run_started_at, selected_resources 2024-09-10 16:57:51 +01:00
Kshitij Aranke
976c538e49 rm: leaf 2024-09-10 16:51:49 +01:00
Kshitij Aranke
b8d1f635a6 rm: INVOCATION_COMMAND 2024-09-10 16:48:39 +01:00
Kshitij Aranke
231b61f51c test: leaf_contexts_set 2024-09-10 15:34:44 +01:00
Kshitij Aranke
792349aeb6 Merge branch 'main' into fix_core_399 2024-09-10 14:18:48 +01:00
Kshitij Aranke
ebb587602e remove typing.NoneType 2024-09-10 14:18:28 +01:00
Kshitij Aranke
ce29b68d92 get model and docs runtime tests working 2024-09-10 14:11:19 +01:00
Kshitij Aranke
d315f01c3e COMMON_FLAGS_INVOCATION_ARGS 2024-09-10 11:55:48 +01:00
Kshitij Aranke
a88f0f6d69 common_builtins 2024-09-10 11:38:07 +01:00
Kshitij Aranke
5aaa10254d get_module_exports 2024-09-10 11:28:13 +01:00
Kshitij Aranke
76edf5569f add_prefix 2024-09-09 18:56:25 +01:00
Kshitij Aranke
e3e7e3d844 expand_builtins3 2024-09-09 18:03:11 +01:00
Kshitij Aranke
eeb3e247b8 expand_builtins2 2024-09-09 17:47:40 +01:00
Kshitij Aranke
1926c28b20 expand_builtins 2024-09-09 17:18:04 +01:00
Kshitij Aranke
2cd6b2959f bring out model flags 2024-09-09 15:06:51 +01:00
Kshitij Aranke
9f5dea302e simplify pytz pt2 2024-09-09 15:04:26 +01:00
Kshitij Aranke
61184aa931 simplify pytz pt1 2024-09-09 14:59:52 +01:00
Kshitij Aranke
179ef6bda3 First pass: expected model context 2024-09-09 14:49:18 +01:00

View File

@@ -1,14 +1,20 @@
import importlib
import os
from typing import Any, Dict, Set
import re
from argparse import Namespace
from copy import deepcopy
from typing import Any, Dict, Mapping, Optional, Set
from unittest import mock
import pytest
import pytz
import dbt_common.exceptions
from dbt.adapters import factory, postgres
from dbt.clients.jinja import MacroStack
from dbt.config.project import VarProvider
from dbt.context import base, docs, macros, providers, query_header
from dbt.context.base import Var
from dbt.contracts.files import FileHash
from dbt.contracts.graph.nodes import (
DependsOn,
@@ -20,6 +26,7 @@ from dbt.contracts.graph.nodes import (
)
from dbt.node_types import NodeType
from dbt_common.events.functions import reset_metadata_vars
from dbt_common.helper_types import WarnErrorOptions
from tests.unit.mock_adapter import adapter_factory
from tests.unit.utils import clear_plugin, config_from_parts_or_dicts, inject_adapter
@@ -277,6 +284,352 @@ PROJECT_DATA = {
}
def clean_value(value):
if isinstance(value, set):
return set(value)
elif isinstance(value, Namespace):
return value.__dict__
elif isinstance(value, Var):
return {k: v for k, v in value._merged.items()}
elif isinstance(value, bool):
return value
elif value is None:
return None
elif isinstance(value, int):
return value
else:
value_str = str(value)
value_str = re.sub(r" at 0x[0-9a-fA-F]+>", ">", value_str)
value_str = re.sub(r" id='[0-9]+'>", ">", value_str)
return value_str
def walk_dict(dictionary):
skip_paths = [
["invocation_id"],
["builtins", "invocation_id"],
["dbt_version"],
["builtins", "dbt_version"],
["invocation_args_dict", "invocation_command"],
["run_started_at"],
["builtins", "run_started_at"],
["selected_resources"],
["builtins", "selected_resources"],
]
stack = [(dictionary, [])]
visited = set() # Set to keep track of visited dictionary objects
while stack:
current_dict, path = stack.pop(0)
if id(current_dict) in visited:
continue
visited.add(id(current_dict))
for key, value in current_dict.items():
current_path = path + [key]
if isinstance(value, Mapping):
stack.append((value, current_path))
else:
if current_path not in skip_paths:
cv = clean_value(value)
if current_path == ["flags"]:
del cv["INVOCATION_COMMAND"]
yield (tuple(current_path), cv)
def add_prefix(path_dict, prefix):
return {prefix + k: v for k, v in path_dict.items()}
def get_module_exports(module_name: str, filter_set: Optional[Set[str]] = None):
module = importlib.import_module(module_name)
export_names = filter_set or module.__all__
return {
("modules", module_name, export): clean_value(getattr(module, export))
for export in export_names
}
PYTZ_COUNTRY_TIMEZONES = {
("modules", "pytz", "country_timezones", country_code): str(timezones)
for country_code, timezones in pytz.country_timezones.items()
}
PYTZ_COUNTRY_NAMES = {
("modules", "pytz", "country_names", country_code): country_name
for country_code, country_name in pytz.country_names.items()
}
COMMON_FLAGS_INVOCATION_ARGS = {
"CACHE_SELECTED_ONLY": False,
"LOG_FORMAT": "default",
"LOG_PATH": "logs",
"SEND_ANONYMOUS_USAGE_STATS": True,
"INDIRECT_SELECTION": "eager",
"INTROSPECT": True,
"PARTIAL_PARSE": True,
"PRINTER_WIDTH": 80,
"QUIET": False,
"STATIC_PARSER": True,
"USE_COLORS": True,
"VERSION_CHECK": True,
"WRITE_JSON": True,
}
COMMON_FLAGS = {
**COMMON_FLAGS_INVOCATION_ARGS,
"LOG_CACHE_EVENTS": False,
"FAIL_FAST": False,
"DEBUG": False,
"WARN_ERROR": None,
"WARN_ERROR_OPTIONS": WarnErrorOptions(include=[], exclude=[]),
"USE_EXPERIMENTAL_PARSER": False,
"NO_PRINT": None,
"PROFILES_DIR": None,
"TARGET_PATH": None,
"EMPTY": None,
"FULL_REFRESH": False,
"STORE_FAILURES": False,
"WHICH": "run",
}
COMMON_BUILTINS = {
("diff_of_two_dicts",): "<function BaseContext.diff_of_two_dicts>",
("flags",): COMMON_FLAGS,
("fromjson",): "<function BaseContext.fromjson>",
("fromyaml",): "<function BaseContext.fromyaml>",
("local_md5",): "<function BaseContext.local_md5>",
("log",): "<function BaseContext.log>",
("print",): "<function BaseContext.print>",
("project_name",): "root",
("return",): "<function BaseContext._return>",
("set",): "<function BaseContext._set>",
("set_strict",): "<function BaseContext.set_strict>",
("thread_id",): "MainThread",
("tojson",): "<function BaseContext.tojson>",
("toyaml",): "<function BaseContext.toyaml>",
("var",): {},
("zip",): "<function BaseContext._zip>",
("zip_strict",): "<function BaseContext.zip_strict>",
}
COMMON_RUNTIME_CONTEXT = {
**COMMON_BUILTINS,
**add_prefix(COMMON_BUILTINS, ("builtins",)),
("target", "host"): "localhost",
("target", "port"): 1,
("target", "user"): "test",
("target", "database"): "test",
("target", "schema"): "analytics",
("target", "connect_timeout"): 10,
("target", "role"): None,
("target", "search_path"): None,
("target", "keepalives_idle"): 0,
("target", "sslmode"): None,
("target", "sslcert"): None,
("target", "sslkey"): None,
("target", "sslrootcert"): None,
("target", "application_name"): "dbt",
("target", "retries"): 1,
("target", "dbname"): "test",
("target", "type"): "postgres",
("target", "threads"): 1,
("target", "name"): "test",
("target", "target_name"): "test",
("target", "profile_name"): "test",
**get_module_exports("datetime", {"date", "datetime", "time", "timedelta", "tzinfo"}),
**get_module_exports("re"),
**get_module_exports(
"itertools",
{
"count",
"cycle",
"repeat",
"accumulate",
"chain",
"compress",
"islice",
"starmap",
"tee",
"zip_longest",
"product",
"permutations",
"combinations",
"combinations_with_replacement",
},
),
("modules", "pytz", "timezone"): "<function timezone>",
("modules", "pytz", "utc"): "UTC",
("modules", "pytz", "AmbiguousTimeError"): "<class 'pytz.exceptions.AmbiguousTimeError'>",
("modules", "pytz", "InvalidTimeError"): "<class 'pytz.exceptions.InvalidTimeError'>",
("modules", "pytz", "NonExistentTimeError"): "<class 'pytz.exceptions.NonExistentTimeError'>",
("modules", "pytz", "UnknownTimeZoneError"): "<class 'pytz.exceptions.UnknownTimeZoneError'>",
("modules", "pytz", "all_timezones"): str(pytz.all_timezones),
("modules", "pytz", "all_timezones_set"): set(pytz.all_timezones_set),
("modules", "pytz", "common_timezones"): str(pytz.common_timezones),
("modules", "pytz", "common_timezones_set"): set(),
("modules", "pytz", "BaseTzInfo"): "<class 'pytz.tzinfo.BaseTzInfo'>",
("modules", "pytz", "FixedOffset"): "<function FixedOffset>",
**PYTZ_COUNTRY_TIMEZONES,
**PYTZ_COUNTRY_NAMES,
}
MODEL_BUILTINS = {
("adapter",): "<dbt.context.providers.RuntimeDatabaseWrapper object>",
(
"adapter_macro",
): "<bound method ProviderContext.adapter_macro of <dbt.context.providers.ModelContext object>>",
("column",): "<MagicMock name='get_adapter().Column'>",
("compiled_code",): "<MagicMock name='model_one.compiled_code'>",
("config",): "<dbt.context.providers.RuntimeConfigObject object>",
("context_macro_stack",): "<dbt.clients.jinja.MacroStack object>",
("database",): "dbt",
("defer_relation",): "<MagicMock name='get_adapter().Relation.create_from()'>",
(
"env_var",
): "<bound method ProviderContext.env_var of <dbt.context.providers.ModelContext object>>",
("execute",): True,
("graph",): "<MagicMock name='mock.flat_graph'>",
(
"load_agate_table",
): "<bound method ProviderContext.load_agate_table of <dbt.context.providers.ModelContext object>>",
(
"load_result",
): "<bound method ProviderContext.load_result of <dbt.context.providers.ModelContext object>>",
("metric",): "<dbt.context.providers.RuntimeMetricResolver object>",
("model",): "<MagicMock name='model_one.to_dict()'>",
("post_hooks",): "[]",
("pre_hooks",): "[]",
("ref",): "<dbt.context.providers.RuntimeRefResolver object>",
(
"render",
): "<bound method ProviderContext.render of <dbt.context.providers.ModelContext object>>",
("schema",): "analytics",
("source",): "<dbt.context.providers.RuntimeSourceResolver object>",
("sql",): "<MagicMock name='model_one.compiled_code'>",
("sql_now",): "<MagicMock name='get_adapter().date_function()'>",
(
"store_raw_result",
): "<bound method ProviderContext.store_raw_result of <dbt.context.providers.ModelContext object>>",
(
"store_result",
): "<bound method ProviderContext.store_result of <dbt.context.providers.ModelContext object>>",
(
"submit_python_job",
): "<bound method ProviderContext.submit_python_job of <dbt.context.providers.ModelContext object>>",
("this",): "<MagicMock name='get_adapter().Relation.create_from()'>",
(
"try_or_compiler_error",
): "<bound method ProviderContext.try_or_compiler_error of <dbt.context.providers.ModelContext object>>",
(
"write",
): "<bound method ProviderContext.write of <dbt.context.providers.ModelContext object>>",
}
MODEL_RUNTIME_BUILTINS = {
**MODEL_BUILTINS,
}
MODEL_EXCEPTIONS = {
("exceptions", "warn"): "<function warn>",
("exceptions", "missing_config"): "<function missing_config>",
("exceptions", "missing_materialization"): "<function missing_materialization>",
("exceptions", "missing_relation"): "<function missing_relation>",
("exceptions", "raise_ambiguous_alias"): "<function raise_ambiguous_alias>",
("exceptions", "raise_ambiguous_catalog_match"): "<function raise_ambiguous_catalog_match>",
("exceptions", "raise_cache_inconsistent"): "<function raise_cache_inconsistent>",
("exceptions", "raise_dataclass_not_dict"): "<function raise_dataclass_not_dict>",
("exceptions", "raise_compiler_error"): "<function raise_compiler_error>",
("exceptions", "raise_database_error"): "<function raise_database_error>",
("exceptions", "raise_dep_not_found"): "<function raise_dep_not_found>",
("exceptions", "raise_dependency_error"): "<function raise_dependency_error>",
("exceptions", "raise_duplicate_patch_name"): "<function raise_duplicate_patch_name>",
("exceptions", "raise_duplicate_resource_name"): "<function raise_duplicate_resource_name>",
(
"exceptions",
"raise_invalid_property_yml_version",
): "<function raise_invalid_property_yml_version>",
("exceptions", "raise_not_implemented"): "<function raise_not_implemented>",
("exceptions", "relation_wrong_type"): "<function relation_wrong_type>",
("exceptions", "raise_contract_error"): "<function raise_contract_error>",
("exceptions", "column_type_missing"): "<function column_type_missing>",
("exceptions", "raise_fail_fast_error"): "<function raise_fail_fast_error>",
(
"exceptions",
"warn_snapshot_timestamp_data_types",
): "<function warn_snapshot_timestamp_data_types>",
}
MODEL_MACROS = {
("macro_a",): "<dbt.clients.jinja.MacroGenerator object>",
("macro_b",): "<dbt.clients.jinja.MacroGenerator object>",
}
EXPECTED_MODEL_RUNTIME_CONTEXT = deepcopy(
{
**COMMON_RUNTIME_CONTEXT,
**MODEL_RUNTIME_BUILTINS,
**add_prefix(MODEL_RUNTIME_BUILTINS, ("builtins",)),
**MODEL_MACROS,
**add_prefix(MODEL_MACROS, ("root",)),
**add_prefix(
{(k.lower(),): v for k, v in COMMON_FLAGS_INVOCATION_ARGS.items()},
("invocation_args_dict",),
),
("invocation_args_dict", "profile_dir"): "/dev/null",
("invocation_args_dict", "warn_error_options", "include"): "[]",
("invocation_args_dict", "warn_error_options", "exclude"): "[]",
**MODEL_EXCEPTIONS,
("api", "Column"): "<MagicMock name='get_adapter().Column'>",
("api", "Relation"): "<dbt.context.providers.RelationProxy object>",
("validation", "any"): "<function ProviderContext.validation.<locals>.validate_any>",
}
)
EXPECTED_MODEL_RUNTIME_CONTEXT = deepcopy(
{
**COMMON_RUNTIME_CONTEXT,
**MODEL_RUNTIME_BUILTINS,
**add_prefix(MODEL_RUNTIME_BUILTINS, ("builtins",)),
**MODEL_MACROS,
**add_prefix(MODEL_MACROS, ("root",)),
**add_prefix(
{(k.lower(),): v for k, v in COMMON_FLAGS_INVOCATION_ARGS.items()},
("invocation_args_dict",),
),
("invocation_args_dict", "profile_dir"): "/dev/null",
("invocation_args_dict", "warn_error_options", "include"): "[]",
("invocation_args_dict", "warn_error_options", "exclude"): "[]",
**MODEL_EXCEPTIONS,
("api", "Column"): "<MagicMock name='get_adapter().Column'>",
("api", "Relation"): "<dbt.context.providers.RelationProxy object>",
("validation", "any"): "<function ProviderContext.validation.<locals>.validate_any>",
}
)
DOCS_BUILTINS = {
("doc",): "<bound method DocsRuntimeContext.doc of "
"<dbt.context.docs.DocsRuntimeContext object>>",
("env_var",): "<bound method SchemaYamlContext.env_var of "
"<dbt.context.docs.DocsRuntimeContext object>>",
}
EXPECTED_DOCS_RUNTIME_CONTEXT = deepcopy(
{
**COMMON_RUNTIME_CONTEXT,
**DOCS_BUILTINS,
**add_prefix(DOCS_BUILTINS, ("builtins",)),
}
)
def model():
return ModelNode(
alias="model_one",
@@ -475,7 +828,8 @@ def test_model_parse_context(config_postgres, manifest_fx, get_adapter, get_incl
manifest=manifest_fx,
context_config=mock.MagicMock(),
)
assert_has_keys(REQUIRED_MODEL_KEYS, MAYBE_KEYS, ctx)
actual_model_context = {k: v for (k, v) in walk_dict(ctx)}
assert actual_model_context == EXPECTED_MODEL_RUNTIME_CONTEXT
def test_model_runtime_context(config_postgres, manifest_fx, get_adapter, get_include_paths):
@@ -484,12 +838,14 @@ def test_model_runtime_context(config_postgres, manifest_fx, get_adapter, get_in
config=config_postgres,
manifest=manifest_fx,
)
assert_has_keys(REQUIRED_MODEL_KEYS, MAYBE_KEYS, ctx)
actual_model_context = {k: v for (k, v) in walk_dict(ctx)}
assert actual_model_context == EXPECTED_MODEL_RUNTIME_CONTEXT
def test_docs_runtime_context(config_postgres):
ctx = docs.generate_runtime_docs_context(config_postgres, mock_model(), [], "root")
assert_has_keys(REQUIRED_DOCS_KEYS, MAYBE_KEYS, ctx)
actual_docs_runtime_context = {k: v for (k, v) in walk_dict(ctx)}
assert actual_docs_runtime_context == EXPECTED_DOCS_RUNTIME_CONTEXT
def test_macro_namespace_duplicates(config_postgres, manifest_fx):