forked from repo-mirrors/dbt-core
406 lines
11 KiB
Python
406 lines
11 KiB
Python
"""Unit test utility functions.

Note that all imports should be inside the functions to avoid import/mocking
issues.
"""
|
|
|
|
import os
|
|
import string
|
|
from unittest import TestCase, mock
|
|
|
|
import agate
|
|
import pytest
|
|
|
|
from dbt.config.project import PartialProject
|
|
from dbt.contracts.graph.manifest import Manifest
|
|
from dbt_common.dataclass_schema import ValidationError
|
|
|
|
|
|
def normalize(path):
    """Return ``path`` run through ``os.path.normpath`` and then ``os.path.normcase``.

    On Windows neither call is sufficient by itself:

    >>> normcase('C:\\documents/ALL CAPS/subdir\\..')
    'c:\\documents\\all caps\\subdir\\..'
    >>> normpath('C:\\documents/ALL CAPS/subdir\\..')
    'C:\\documents\\ALL CAPS'
    >>> normpath(normcase('C:\\documents/ALL CAPS/subdir\\..'))
    'c:\\documents\\all caps'
    """
    # Collapse redundant separators and ".." segments first, then fold case.
    collapsed = os.path.normpath(path)
    return os.path.normcase(collapsed)
|
|
|
|
|
|
class Obj:
    """Minimal attribute bag; ``config_from_parts_or_dicts`` uses it as its ``args`` object."""

    # Class-level defaults; callers attach further attributes to instances.
    which = "blah"
    single_threaded = False
|
|
|
|
|
|
def mock_connection(name, state="open"):
    """Return a ``MagicMock`` whose ``name`` and ``state`` attributes are preset.

    :param name: value stored on the mock's ``name`` attribute.
    :param state: value stored on the mock's ``state`` attribute (default ``"open"``).
    """
    connection = mock.MagicMock()
    # ``name`` is assigned after construction so it is an ordinary attribute
    # of the mock rather than a constructor argument.
    connection.name, connection.state = name, state
    return connection
|
|
|
|
|
|
def profile_from_dict(profile, profile_name, cli_vars="{}"):
    """Build a ``Profile`` from raw profile data via ``Profile.from_raw_profile_info``.

    :param profile: raw profile info handed to ``Profile.from_raw_profile_info``.
    :param profile_name: profile name handed to ``Profile.from_raw_profile_info``.
    :param cli_vars: a dict used as-is, or any other value, which is first
        passed through ``parse_cli_vars``.
    :return: whatever ``Profile.from_raw_profile_info`` returns.
    """
    from dbt.config import Profile
    from dbt.config.renderer import ProfileRenderer
    from dbt.config.utils import parse_cli_vars

    # Only non-dict values need parsing; dicts are passed through untouched.
    parsed_vars = cli_vars if isinstance(cli_vars, dict) else parse_cli_vars(cli_vars)

    profile_renderer = ProfileRenderer(parsed_vars)
    return Profile.from_raw_profile_info(profile, profile_name, profile_renderer)
|
|
|
|
|
|
def project_from_dict(project, profile, packages=None, selectors=None, cli_vars="{}"):
    """Render a project from plain dictionaries.

    :param project: project dictionary. NOTE: its ``"project-root"`` key, if
        present, is popped from this dict in place.
    :param profile: profile passed to ``DbtProjectYamlRenderer``.
    :param packages: forwarded as ``packages_dict`` to ``PartialProject.from_dicts``.
    :param selectors: forwarded as ``selectors_dict`` to ``PartialProject.from_dicts``.
    :param cli_vars: a dict used as-is, or any other value, which is first
        passed through ``parse_cli_vars``.
    :return: the result of rendering the partial project.
    """
    from dbt.config.renderer import DbtProjectYamlRenderer
    from dbt.config.utils import parse_cli_vars

    parsed_vars = cli_vars if isinstance(cli_vars, dict) else parse_cli_vars(cli_vars)
    yaml_renderer = DbtProjectYamlRenderer(profile, parsed_vars)

    # Fall back to the current working directory when no root is supplied.
    root_dir = project.pop("project-root", os.getcwd())

    unrendered = PartialProject.from_dicts(
        project_root=root_dir,
        project_dict=project,
        packages_dict=packages,
        selectors_dict=selectors,
    )
    return unrendered.render(yaml_renderer)
|
|
|
|
|
|
def config_from_parts_or_dicts(project, profile, packages=None, selectors=None, cli_vars=None):
    """Assemble a ``RuntimeConfig`` from ready-made objects or plain dicts.

    :param project: a ``Project`` instance, or a project dict (deep-copied
        before being handed to ``project_from_dict``).
    :param profile: a ``Profile`` instance, or a profile dict (deep-copied
        before being handed to ``profile_from_dict``).
    :param packages: forwarded to ``project_from_dict`` when ``project`` is a dict.
    :param selectors: forwarded to ``project_from_dict`` when ``project`` is a dict.
    :param cli_vars: CLI variables; ``None`` (the default) means a fresh empty
        dict. The value is also stored on ``args.vars``.
    :return: the result of ``RuntimeConfig.from_parts``.
    """
    from copy import deepcopy

    from dbt.config import Profile, Project, RuntimeConfig

    # A literal ``{}`` default would be one dict object shared by every call;
    # since it ends up on ``args.vars`` and in the renderers, build a new one
    # per call instead.
    if cli_vars is None:
        cli_vars = {}

    if isinstance(project, Project):
        profile_name = project.profile_name
    else:
        profile_name = project.get("profile")

    if not isinstance(profile, Profile):
        # deepcopy keeps the caller's dict untouched by profile construction.
        profile = profile_from_dict(
            deepcopy(profile),
            profile_name,
            cli_vars,
        )

    if not isinstance(project, Project):
        # deepcopy matters here: project_from_dict pops "project-root".
        project = project_from_dict(
            deepcopy(project),
            profile,
            packages,
            selectors,
            cli_vars,
        )

    args = Obj()
    args.vars = cli_vars
    args.profile_dir = "/dev/null"
    return RuntimeConfig.from_parts(project=project, profile=profile, args=args)
|
|
|
|
|
|
def inject_plugin(plugin):
    """Store ``plugin`` in ``FACTORY.plugins`` keyed by ``plugin.adapter.type()``."""
    from dbt.adapters.factory import FACTORY

    adapter_type = plugin.adapter.type()
    FACTORY.plugins[adapter_type] = plugin
|
|
|
|
|
|
def inject_plugin_for(config):
    """Load the plugin for ``config.credentials.type`` and return the factory's adapter.

    :param config: object exposing ``credentials.type``; also passed to
        ``FACTORY.get_adapter``.
    :return: the adapter returned by ``FACTORY.get_adapter(config)``.
    """
    from dbt.adapters.factory import FACTORY

    FACTORY.load_plugin(config.credentials.type)
    return FACTORY.get_adapter(config)
|
|
|
|
|
|
def inject_adapter(value, plugin):
    """Register a hand-built adapter (and its plugin) with the adapter factory.

    The plugin is registered through ``inject_plugin`` and ``value`` is then
    stored in ``FACTORY.adapters`` under ``value.type()``, so it will be
    available from get_adapter() as if dbt loaded it.
    """
    from dbt.adapters.factory import FACTORY

    inject_plugin(plugin)

    adapter_type = value.type()
    FACTORY.adapters[adapter_type] = value
|
|
|
|
|
|
def clear_plugin(plugin):
    """Drop the entries for ``plugin.adapter.type()`` from the adapter factory.

    Both ``FACTORY.plugins`` and ``FACTORY.adapters`` are cleaned; a missing
    entry is ignored.
    """
    from dbt.adapters.factory import FACTORY

    adapter_type = plugin.adapter.type()
    for registry in (FACTORY.plugins, FACTORY.adapters):
        registry.pop(adapter_type, None)
|
|
|
|
|
|
class ContractTestCase(TestCase):
    """``TestCase`` base with helpers for checking dict round-trips of a contract class.

    Subclasses set ``ContractType`` to the class under test; every helper also
    accepts an explicit ``cls`` that takes precedence over it.
    """

    # Class used by the helpers whenever no explicit ``cls`` is passed.
    ContractType = None

    def setUp(self):
        # Show complete diffs on assertion failures.
        self.maxDiff = None
        super().setUp()

    def assert_to_dict(self, obj, dct):
        """Assert that ``obj.to_dict(omit_none=True)`` equals ``dct``."""
        serialized = obj.to_dict(omit_none=True)
        self.assertEqual(serialized, dct)

    def assert_from_dict(self, obj, dct, cls=None):
        """Validate ``dct`` and assert that ``cls.from_dict(dct)`` equals ``obj``."""
        contract = self.ContractType if cls is None else cls
        contract.validate(dct)
        rebuilt = contract.from_dict(dct)
        self.assertEqual(rebuilt, obj)

    def assert_symmetric(self, obj, dct, cls=None):
        """Check both directions: ``obj`` -> ``dct`` and ``dct`` -> ``obj``."""
        self.assert_to_dict(obj, dct)
        self.assert_from_dict(obj, dct, cls)

    def assert_fails_validation(self, dct, cls=None):
        """Assert that validating / loading ``dct`` raises ``ValidationError``."""
        contract = self.ContractType if cls is None else cls

        with self.assertRaises(ValidationError):
            contract.validate(dct)
            contract.from_dict(dct)
|
|
|
|
|
|
def compare_dicts(dict1, dict2):
    """Print a human-readable comparison of two dictionaries.

    Reports the keys present in only one of the dicts, then, for every shared
    key whose values differ, the value from each side, and finally a one-line
    summary. Nothing is returned; all output goes through ``print``.
    """
    first_set = set(dict1.keys())
    second_set = set(dict2.keys())
    only_in_first = first_set.difference(second_set)
    only_in_second = second_set.difference(first_set)
    print(f"--- Difference between first and second keys: {only_in_first}")
    print(f"--- Difference between second and first keys: {only_in_second}")

    shared = set(first_set).intersection(set(second_set))
    mismatches = 0
    for key in shared:
        left, right = dict1[key], dict2[key]
        if left == right:
            continue
        print(f"--- --- first dict: {key}: {left!s}")
        print(f"--- --- second dict: {key}: {right!s}")
        mismatches += 1

    if mismatches:
        print("--- Found differences in dictionaries")
    else:
        print("--- Found no differences in dictionaries")
|
|
|
|
|
|
def assert_from_dict(obj, dct, cls=None):
    """Assert that loading ``dct`` yields an object equal to ``obj``.

    :param obj: expected object. If it has a ``created_at`` attribute, that
        attribute is overwritten with ``1`` on ``obj`` itself.
    :param dct: dictionary that is validated and then loaded with ``from_dict``.
    :param cls: class providing ``validate`` / ``from_dict``; defaults to
        ``obj.__class__``.
    """
    contract = obj.__class__ if cls is None else cls
    contract.validate(dct)

    rebuilt = contract.from_dict(dct)

    if hasattr(obj, "created_at"):
        # Pin the timestamp on both sides so it cannot affect the comparison.
        for candidate in (rebuilt, obj):
            candidate.created_at = 1

    assert rebuilt == obj
|
|
|
|
|
|
def assert_to_dict(obj, dct):
    """Assert that ``obj.to_dict(omit_none=True)`` equals ``dct``.

    A ``"created_at"`` entry, where present, is set to ``1`` on both sides
    before comparing. NOTE: this writes into the caller's ``dct`` in place.
    On a mismatch the differences are printed through ``compare_dicts``
    before the assertion fails.
    """
    actual = obj.to_dict(omit_none=True)
    for mapping in (actual, dct):
        if "created_at" in mapping:
            mapping["created_at"] = 1

    if actual != dct:
        compare_dicts(actual, dct)
    assert actual == dct
|
|
|
|
|
|
def assert_symmetric(obj, dct, cls=None):
    """Check the round trip in both directions.

    First ``obj`` must serialize to ``dct`` (``assert_to_dict``), then ``dct``
    must load back into an object equal to ``obj`` (``assert_from_dict``,
    using ``cls`` when given).
    """
    for check, extra in ((assert_to_dict, ()), (assert_from_dict, (cls,))):
        check(obj, dct, *extra)
|
|
|
|
|
|
def assert_fails_validation(dct, cls):
    """Assert that validating / loading ``dct`` with ``cls`` raises ``ValidationError``.

    ``cls.validate`` runs first; ``cls.from_dict`` is only reached when
    validation did not raise.
    """
    expect_failure = pytest.raises(ValidationError)
    with expect_failure:
        cls.validate(dct)
        cls.from_dict(dct)
|
|
|
|
|
|
def generate_name_macros(package):
    """Yield ``generate_database_name``, ``generate_schema_name`` and
    ``generate_alias_name`` macros belonging to ``package``.

    Each macro's SQL renders ``value`` when it is truthy; otherwise it renders
    ``target.database`` / ``target.schema`` for the first two macros and
    ``node.name`` for the alias macro.

    :param package: used as ``package_name`` and inside each ``unique_id``.
    """
    from dbt.contracts.graph.nodes import Macro
    from dbt.node_types import NodeType

    for component in ("database", "schema", "alias"):
        fallback = "node.name" if component == "alias" else f"target.{component}"
        macro_name = f"generate_{component}_name"
        macro_sql = f"{{% macro {macro_name}(value, node) %}} {{% if value %}} {{{{ value }}}} {{% else %}} {{{{ {fallback} }}}} {{% endif %}} {{% endmacro %}}"

        yield Macro(
            name=macro_name,
            resource_type=NodeType.Macro,
            unique_id=f"macro.{package}.{macro_name}",
            package_name=package,
            original_file_path=normalize("macros/macro.sql"),
            path=normalize("macros/macro.sql"),
            macro_sql=macro_sql,
        )
|
|
|
|
|
|
class TestAdapterConversions(TestCase):
    """``TestCase`` base offering helpers that build ``agate`` tables with chosen column types."""

    def _get_tester_for(self, column_type):
        """Return a data-type instance matching ``column_type``.

        :raises ValueError: when no entry of the default type tester matches.
        """
        from dbt_common.clients import agate_helper

        # TimeDelta gets a fresh instance rather than a lookup in the default
        # tester -- dbt never makes this!
        if column_type is agate.TimeDelta:
            return agate.TimeDelta()

        candidates = agate_helper.DEFAULT_TYPE_TESTER._possible_types
        for candidate in candidates:
            # isinstance (not an exact type match) so child types are included.
            if isinstance(candidate, column_type):
                return candidate

        raise ValueError(f"no tester for {column_type}")

    def _make_table_of(self, rows, column_types):
        """Build an ``agate.Table`` from ``rows``.

        Columns are named with consecutive ASCII letters, one per entry of the
        first row. ``column_types`` is either a single type applied to every
        column or an iterable with one type per column.
        """
        names = list(string.ascii_letters[: len(rows[0])])
        if isinstance(column_types, type):
            wanted = [column_types] * len(names)
        else:
            wanted = column_types
        testers = [self._get_tester_for(kind) for kind in wanted]
        return agate.Table(rows, column_names=names, column_types=testers)
|
|
|
|
|
|
def MockMacro(package, name="my_macro", **kwargs):
    """Return a ``MagicMock`` specced on ``Macro``.

    :param package: used for ``package_name`` and inside ``unique_id``.
    :param name: macro name; assigned to the mock's ``name`` attribute after
        construction.
    :param kwargs: extra attributes; they override the defaults built here.
    """
    from dbt.contracts.graph.nodes import Macro
    from dbt.node_types import NodeType

    attributes = {
        "resource_type": NodeType.Macro,
        "package_name": package,
        "unique_id": f"macro.{package}.{name}",
        "original_file_path": "/dev/null",
        # Caller-supplied values win over the defaults above.
        **kwargs,
    }

    fake_macro = mock.MagicMock(spec=Macro, **attributes)
    fake_macro.name = name
    return fake_macro
|
|
|
|
|
|
def MockMaterialization(package, name="my_materialization", adapter_type=None, **kwargs):
    """Return a ``MockMacro`` named ``materialization_<name>_<adapter_type>``.

    :param adapter_type: ``None`` is replaced by ``"default"``; the resulting
        value is also passed on to ``MockMacro`` as ``adapter_type``.
    """
    resolved_type = "default" if adapter_type is None else adapter_type
    kwargs["adapter_type"] = resolved_type
    macro_name = f"materialization_{name}_{resolved_type}"
    return MockMacro(package, macro_name, **kwargs)
|
|
|
|
|
|
def MockGenerateMacro(package, component="some_component", **kwargs):
    """Return a ``MockMacro`` named ``generate_<component>_name``."""
    return MockMacro(package, name=f"generate_{component}_name", **kwargs)
|
|
|
|
|
|
def MockSource(package, source_name, name, **kwargs):
    """Return a ``MagicMock`` whose ``__class__`` is ``SourceDefinition``.

    ``unique_id`` is ``source.<package>.<source_name>.<name>`` and
    ``search_name`` is ``<source_name>.<name>``. Extra ``kwargs`` are passed
    to the ``MagicMock`` constructor; ``name`` is assigned afterwards.
    """
    from dbt.contracts.graph.nodes import SourceDefinition
    from dbt.node_types import NodeType

    preset = dict(
        __class__=SourceDefinition,
        resource_type=NodeType.Source,
        source_name=source_name,
        package_name=package,
        unique_id=f"source.{package}.{source_name}.{name}",
        search_name=f"{source_name}.{name}",
    )
    # Two separate unpackings: a key repeated in kwargs still raises TypeError.
    fake_source = mock.MagicMock(**preset, **kwargs)
    fake_source.name = name
    return fake_source
|
|
|
|
|
|
def MockNode(package, name, resource_type=None, **kwargs):
    """Return a ``MagicMock`` posing as a ``ModelNode`` or ``SeedNode``.

    :param package: used for ``package_name`` and inside ``unique_id``.
    :param name: node name; assigned to the mock's ``name`` attribute after
        construction.
    :param resource_type: ``NodeType.Model`` (also used when ``None``) or
        ``NodeType.Seed``.
    :param kwargs: passed to the ``MagicMock`` constructor; a ``version``
        entry additionally adds a ``.v<version>`` suffix to ``search_name``
        (and therefore to ``unique_id``).
    :raises ValueError: for any other ``resource_type``.
    """
    from dbt.contracts.graph.nodes import ModelNode, SeedNode
    from dbt.node_types import NodeType

    if resource_type is None:
        resource_type = NodeType.Model

    if resource_type == NodeType.Model:
        node_class = ModelNode
    elif resource_type == NodeType.Seed:
        node_class = SeedNode
    else:
        raise ValueError(f"I do not know how to handle {resource_type}")

    version = kwargs.get("version")
    if version is None:
        search_name = name
    else:
        search_name = f"{name}.v{version}"

    preset = dict(
        __class__=node_class,
        resource_type=resource_type,
        package_name=package,
        unique_id=f"{resource_type!s}.{package}.{search_name}",
        search_name=search_name,
    )
    # Two separate unpackings: a key repeated in kwargs still raises TypeError.
    fake_node = mock.MagicMock(**preset, **kwargs)
    fake_node.name = name
    # Only models with an explicit version count as versioned.
    fake_node.is_versioned = resource_type is NodeType.Model and version is not None
    return fake_node
|
|
|
|
|
|
def MockDocumentation(package, name, **kwargs):
    """Return a ``MagicMock`` whose ``__class__`` is ``Documentation``.

    ``search_name`` is ``name`` and ``unique_id`` is ``<package>.<name>``.
    Extra ``kwargs`` are passed to the ``MagicMock`` constructor; ``name`` is
    assigned afterwards.
    """
    from dbt.contracts.graph.nodes import Documentation
    from dbt.node_types import NodeType

    preset = dict(
        __class__=Documentation,
        resource_type=NodeType.Documentation,
        package_name=package,
        search_name=name,
        unique_id=f"{package}.{name}",
    )
    # Two separate unpackings: a key repeated in kwargs still raises TypeError.
    fake_doc = mock.MagicMock(**preset, **kwargs)
    fake_doc.name = name
    return fake_doc
|
|
|
|
|
|
def load_internal_manifest_macros(config, macro_hook=lambda m: None):
    """Return ``ManifestLoader.load_macros(config, macro_hook)``.

    :param config: passed straight to ``ManifestLoader.load_macros``.
    :param macro_hook: callable passed to the loader; the default accepts one
        argument and does nothing.
    """
    from dbt.parser.manifest import ManifestLoader

    loaded = ManifestLoader.load_macros(config, macro_hook)
    return loaded
|
|
|
|
|
|
def dict_replace(dct, **kwargs):
    """Return a shallow copy of ``dct`` updated with ``kwargs``.

    The input dictionary itself is left unmodified.
    """
    updated = dct.copy()
    updated.update(kwargs)
    return updated
|
|
|
|
|
|
def replace_config(n, **kwargs):
    """Return a copy of ``n`` with ``kwargs`` applied to its configuration.

    Both views of the configuration are updated: ``config`` through its own
    ``replace`` method and ``unrendered_config`` through ``dict_replace``.
    The copy itself is made with ``dataclasses.replace``.
    """
    from dataclasses import replace

    new_config = n.config.replace(**kwargs)
    new_unrendered = dict_replace(n.unrendered_config, **kwargs)
    return replace(n, config=new_config, unrendered_config=new_unrendered)
|
|
|
|
|
|
def make_manifest(nodes=None, sources=None, macros=None, docs=None) -> Manifest:
    """Build a ``Manifest`` from iterables of objects exposing ``unique_id``.

    :param nodes: iterable of nodes, or ``None`` for none.
    :param sources: iterable of sources, or ``None`` for none.
    :param macros: iterable of macros, or ``None`` for none.
    :param docs: iterable of docs, or ``None`` for none.
    :return: a ``Manifest`` whose ``nodes`` / ``macros`` / ``sources`` /
        ``docs`` mappings are keyed by ``unique_id``; ``disabled``, ``files``,
        ``exposures``, ``metrics`` and ``selectors`` are empty dicts.
    """

    def _by_unique_id(items):
        # ``None`` defaults replace the former mutable ``[]`` defaults.
        if items is None:
            return {}
        return {item.unique_id: item for item in items}

    return Manifest(
        nodes=_by_unique_id(nodes),
        macros=_by_unique_id(macros),
        sources=_by_unique_id(sources),
        docs=_by_unique_id(docs),
        disabled={},
        files={},
        exposures={},
        metrics={},
        selectors={},
    )
|