Stop ignoring test directory for precommit (#7201)

* reformat test directory to pass formatting checks

* remove test comment
This commit is contained in:
Emily Rockman
2023-03-22 08:04:13 -05:00
committed by GitHub
parent 8225a009b5
commit ca23148908
46 changed files with 7912 additions and 7152 deletions

View File

@@ -1,8 +1,7 @@
# Configuration for pre-commit hooks (see https://pre-commit.com/).
# Eventually the hooks described here will be run as tests before merging each PR.
# TODO: remove global exclusion of tests when testing overhaul is complete
exclude: ^(test/|core/dbt/docs/build/)
exclude: ^(core/dbt/docs/build/)
# Force all unspecified python hooks to run python 3.8
default_language_version:

View File

@@ -5,7 +5,5 @@ mypy_path = "third-party-stubs/"
namespace_packages = true
[tool.black]
# TODO: remove global exclusion of tests when testing overhaul is complete
force-exclude = 'test/'
line-length = 99
target-version = ['py38']

View File

@@ -6,11 +6,11 @@ from contextlib import contextmanager
def adapter_factory():
class MockAdapter(BaseAdapter):
ConnectionManager = mock.MagicMock(TYPE='mock')
ConnectionManager = mock.MagicMock(TYPE="mock")
responder = mock.MagicMock()
# some convenient defaults
responder.quote.side_effect = lambda identifier: '"{}"'.format(identifier)
responder.date_function.side_effect = lambda: 'unitdate()'
responder.date_function.side_effect = lambda: "unitdate()"
responder.is_cancelable.side_effect = lambda: False
@contextmanager

View File

@@ -4,59 +4,57 @@ from unittest import mock
from dbt.adapters.factory import AdapterContainer
from dbt.adapters.base.plugin import AdapterPlugin
from dbt.include.global_project import (
PACKAGE_PATH as GLOBAL_PROJECT_PATH,
PROJECT_NAME as GLOBAL_PROJECT_NAME,
)
class TestGetPackageNames(unittest.TestCase):
def setUp(self):
with mock.patch('dbt.adapters.base.plugin.project_name_from_path') as get_name:
get_name.return_value = 'root'
with mock.patch("dbt.adapters.base.plugin.project_name_from_path") as get_name:
get_name.return_value = "root"
self.root_plugin = AdapterPlugin(
adapter=mock.MagicMock(),
credentials=mock.MagicMock(),
include_path='/path/to/root/plugin',
dependencies=['childa', 'childb'],
include_path="/path/to/root/plugin",
dependencies=["childa", "childb"],
)
get_name.return_value = 'pkg_childa'
get_name.return_value = "pkg_childa"
self.childa = AdapterPlugin(
adapter=mock.MagicMock(),
credentials=mock.MagicMock(),
include_path='/path/to/childa',
include_path="/path/to/childa",
)
get_name.return_value = 'pkg_childb'
get_name.return_value = "pkg_childb"
self.childb = AdapterPlugin(
adapter=mock.MagicMock(),
credentials=mock.MagicMock(),
include_path='/path/to/childb',
dependencies=['childc']
include_path="/path/to/childb",
dependencies=["childc"],
)
get_name.return_value = 'pkg_childc'
get_name.return_value = "pkg_childc"
self.childc = AdapterPlugin(
adapter=mock.MagicMock(),
credentials=mock.MagicMock(),
include_path='/path/to/childc',
include_path="/path/to/childc",
)
self._mock_modules = {
'root': self.root_plugin,
'childa': self.childa,
'childb': self.childb,
'childc': self.childc,
"root": self.root_plugin,
"childa": self.childa,
"childb": self.childb,
"childc": self.childc,
}
self.factory = AdapterContainer()
self.load_patch = mock.patch.object(AdapterContainer, 'load_plugin')
self.load_patch = mock.patch.object(AdapterContainer, "load_plugin")
self.mock_load = self.load_patch.start()
def mock_load_plugin(name: str):
try:
plugin = self._mock_modules[name]
except KeyError:
raise RuntimeError(f'test could not find adapter type {name}!')
raise RuntimeError(f"test could not find adapter type {name}!")
self.factory.plugins[name] = plugin
self.factory.packages[plugin.project_name] = Path(plugin.include_path)
for dep in plugin.dependencies:
@@ -71,13 +69,26 @@ class TestGetPackageNames(unittest.TestCase):
assert self.factory.get_adapter_package_names(None) == [GLOBAL_PROJECT_NAME]
def test_one_package(self):
self.factory.load_plugin('childc')
assert self.factory.get_adapter_package_names('childc') == ['pkg_childc', GLOBAL_PROJECT_NAME]
self.factory.load_plugin("childc")
assert self.factory.get_adapter_package_names("childc") == [
"pkg_childc",
GLOBAL_PROJECT_NAME,
]
def test_simple_child_packages(self):
self.factory.load_plugin('childb')
assert self.factory.get_adapter_package_names('childb') == ['pkg_childb', 'pkg_childc', GLOBAL_PROJECT_NAME]
self.factory.load_plugin("childb")
assert self.factory.get_adapter_package_names("childb") == [
"pkg_childb",
"pkg_childc",
GLOBAL_PROJECT_NAME,
]
def test_layered_child_packages(self):
self.factory.load_plugin('root')
assert self.factory.get_adapter_package_names('root') == ['root', 'pkg_childa', 'pkg_childb', 'pkg_childc', GLOBAL_PROJECT_NAME]
self.factory.load_plugin("root")
assert self.factory.get_adapter_package_names("root") == [
"root",
"pkg_childa",
"pkg_childb",
"pkg_childc",
GLOBAL_PROJECT_NAME,
]

View File

@@ -14,26 +14,34 @@ SAMPLE_CSV_DATA = """a,b,c,d,e,f,g
1,n,test,3.2,20180806T11:33:29.320Z,True,NULL
2,y,asdf,900,20180806T11:35:29.320Z,False,a string"""
SAMPLE_CSV_BOM_DATA = u'\ufeff' + SAMPLE_CSV_DATA
SAMPLE_CSV_BOM_DATA = "\ufeff" + SAMPLE_CSV_DATA
EXPECTED = [
[
1, 'n', 'test', Decimal('3.2'),
1,
"n",
"test",
Decimal("3.2"),
datetime(2018, 8, 6, 11, 33, 29, 320000, tzinfo=tzinfo.Utc()),
True, None,
True,
None,
],
[
2, 'y', 'asdf', 900,
2,
"y",
"asdf",
900,
datetime(2018, 8, 6, 11, 35, 29, 320000, tzinfo=tzinfo.Utc()),
False, 'a string',
False,
"a string",
],
]
EXPECTED_STRINGS = [
['1', 'n', 'test', '3.2', '20180806T11:33:29.320Z', 'True', None],
['2', 'y', 'asdf', '900', '20180806T11:35:29.320Z', 'False', 'a string'],
["1", "n", "test", "3.2", "20180806T11:33:29.320Z", "True", None],
["2", "y", "asdf", "900", "20180806T11:35:29.320Z", "False", "a string"],
]
@@ -45,39 +53,53 @@ class TestAgateHelper(unittest.TestCase):
rmtree(self.tempdir)
def test_from_csv(self):
path = os.path.join(self.tempdir, 'input.csv')
with open(path, 'wb') as fp:
fp.write(SAMPLE_CSV_DATA.encode('utf-8'))
path = os.path.join(self.tempdir, "input.csv")
with open(path, "wb") as fp:
fp.write(SAMPLE_CSV_DATA.encode("utf-8"))
tbl = agate_helper.from_csv(path, ())
self.assertEqual(len(tbl), len(EXPECTED))
for idx, row in enumerate(tbl):
self.assertEqual(list(row), EXPECTED[idx])
def test_bom_from_csv(self):
path = os.path.join(self.tempdir, 'input.csv')
with open(path, 'wb') as fp:
fp.write(SAMPLE_CSV_BOM_DATA.encode('utf-8'))
path = os.path.join(self.tempdir, "input.csv")
with open(path, "wb") as fp:
fp.write(SAMPLE_CSV_BOM_DATA.encode("utf-8"))
tbl = agate_helper.from_csv(path, ())
self.assertEqual(len(tbl), len(EXPECTED))
for idx, row in enumerate(tbl):
self.assertEqual(list(row), EXPECTED[idx])
def test_from_csv_all_reserved(self):
path = os.path.join(self.tempdir, 'input.csv')
with open(path, 'wb') as fp:
fp.write(SAMPLE_CSV_DATA.encode('utf-8'))
tbl = agate_helper.from_csv(path, tuple('abcdefg'))
path = os.path.join(self.tempdir, "input.csv")
with open(path, "wb") as fp:
fp.write(SAMPLE_CSV_DATA.encode("utf-8"))
tbl = agate_helper.from_csv(path, tuple("abcdefg"))
self.assertEqual(len(tbl), len(EXPECTED_STRINGS))
for expected, row in zip(EXPECTED_STRINGS, tbl):
self.assertEqual(list(row), expected)
def test_from_data(self):
column_names = ['a', 'b', 'c', 'd', 'e', 'f', 'g']
column_names = ["a", "b", "c", "d", "e", "f", "g"]
data = [
{'a': '1', 'b': 'n', 'c': 'test', 'd': '3.2',
'e': '20180806T11:33:29.320Z', 'f': 'True', 'g': 'NULL'},
{'a': '2', 'b': 'y', 'c': 'asdf', 'd': '900',
'e': '20180806T11:35:29.320Z', 'f': 'False', 'g': 'a string'}
{
"a": "1",
"b": "n",
"c": "test",
"d": "3.2",
"e": "20180806T11:33:29.320Z",
"f": "True",
"g": "NULL",
},
{
"a": "2",
"b": "y",
"c": "asdf",
"d": "900",
"e": "20180806T11:35:29.320Z",
"f": "False",
"g": "a string",
},
]
tbl = agate_helper.table_from_data(data, column_names)
self.assertEqual(len(tbl), len(EXPECTED))
@@ -85,50 +107,50 @@ class TestAgateHelper(unittest.TestCase):
self.assertEqual(list(row), EXPECTED[idx])
def test_datetime_formats(self):
path = os.path.join(self.tempdir, 'input.csv')
path = os.path.join(self.tempdir, "input.csv")
datetimes = [
'20180806T11:33:29.000Z',
'20180806T11:33:29Z',
'20180806T113329Z',
"20180806T11:33:29.000Z",
"20180806T11:33:29Z",
"20180806T113329Z",
]
expected = datetime(2018, 8, 6, 11, 33, 29, 0, tzinfo=tzinfo.Utc())
for dt in datetimes:
with open(path, 'wb') as fp:
fp.write('a\n{}'.format(dt).encode('utf-8'))
with open(path, "wb") as fp:
fp.write("a\n{}".format(dt).encode("utf-8"))
tbl = agate_helper.from_csv(path, ())
self.assertEqual(tbl[0][0], expected)
def test_merge_allnull(self):
t1 = agate.Table([(1, 'a', None), (2, 'b', None)], ('a', 'b', 'c'))
t2 = agate.Table([(3, 'c', None), (4, 'd', None)], ('a', 'b', 'c'))
t1 = agate.Table([(1, "a", None), (2, "b", None)], ("a", "b", "c"))
t2 = agate.Table([(3, "c", None), (4, "d", None)], ("a", "b", "c"))
result = agate_helper.merge_tables([t1, t2])
self.assertEqual(result.column_names, ('a', 'b', 'c'))
self.assertEqual(result.column_names, ("a", "b", "c"))
assert isinstance(result.column_types[0], agate.data_types.Number)
assert isinstance(result.column_types[1], agate.data_types.Text)
assert isinstance(result.column_types[2], agate.data_types.Number)
self.assertEqual(len(result), 4)
def test_merge_mixed(self):
t1 = agate.Table([(1, 'a', None), (2, 'b', None)], ('a', 'b', 'c'))
t2 = agate.Table([(3, 'c', 'dog'), (4, 'd', 'cat')], ('a', 'b', 'c'))
t3 = agate.Table([(3, 'c', None), (4, 'd', None)], ('a', 'b', 'c'))
t1 = agate.Table([(1, "a", None), (2, "b", None)], ("a", "b", "c"))
t2 = agate.Table([(3, "c", "dog"), (4, "d", "cat")], ("a", "b", "c"))
t3 = agate.Table([(3, "c", None), (4, "d", None)], ("a", "b", "c"))
result = agate_helper.merge_tables([t1, t2])
self.assertEqual(result.column_names, ('a', 'b', 'c'))
self.assertEqual(result.column_names, ("a", "b", "c"))
assert isinstance(result.column_types[0], agate.data_types.Number)
assert isinstance(result.column_types[1], agate.data_types.Text)
assert isinstance(result.column_types[2], agate.data_types.Text)
self.assertEqual(len(result), 4)
result = agate_helper.merge_tables([t2, t3])
self.assertEqual(result.column_names, ('a', 'b', 'c'))
self.assertEqual(result.column_names, ("a", "b", "c"))
assert isinstance(result.column_types[0], agate.data_types.Number)
assert isinstance(result.column_types[1], agate.data_types.Text)
assert isinstance(result.column_types[2], agate.data_types.Text)
self.assertEqual(len(result), 4)
result = agate_helper.merge_tables([t1, t2, t3])
self.assertEqual(result.column_names, ('a', 'b', 'c'))
self.assertEqual(result.column_names, ("a", "b", "c"))
assert isinstance(result.column_types[0], agate.data_types.Number)
assert isinstance(result.column_types[1], agate.data_types.Text)
assert isinstance(result.column_types[2], agate.data_types.Text)
@@ -138,18 +160,18 @@ class TestAgateHelper(unittest.TestCase):
# String fields should not be coerced into a representative type
# See: https://github.com/dbt-labs/dbt-core/issues/2984
column_names = ['a', 'b', 'c', 'd', 'e']
column_names = ["a", "b", "c", "d", "e"]
result_set = [
{'a': '0005', 'b': '01T00000aabbccdd', 'c': 'true', 'd': 10, 'e': False},
{'a': '0006', 'b': '01T00000aabbccde', 'c': 'false', 'd': 11, 'e': True},
{"a": "0005", "b": "01T00000aabbccdd", "c": "true", "d": 10, "e": False},
{"a": "0006", "b": "01T00000aabbccde", "c": "false", "d": 11, "e": True},
]
tbl = agate_helper.table_from_data_flat(data=result_set, column_names=column_names)
self.assertEqual(len(tbl), len(result_set))
expected = [
['0005', '01T00000aabbccdd', 'true', Decimal(10), False],
['0006', '01T00000aabbccde', 'false', Decimal(11), True],
["0005", "01T00000aabbccdd", "true", Decimal(10), False],
["0006", "01T00000aabbccde", "false", Decimal(11), True],
]
for i, row in enumerate(tbl):
@@ -159,10 +181,10 @@ class TestAgateHelper(unittest.TestCase):
# True and False values should not be cast to 1 and 0, and vice versa
# See: https://github.com/dbt-labs/dbt-core/issues/4511
column_names = ['a', 'b']
column_names = ["a", "b"]
result_set = [
{'a': True, 'b': 1},
{'a': False, 'b': 0},
{"a": True, "b": 1},
{"a": False, "b": 0},
]
tbl = agate_helper.table_from_data_flat(data=result_set, column_names=column_names)
@@ -178,4 +200,3 @@ class TestAgateHelper(unittest.TestCase):
for i, row in enumerate(tbl):
self.assertEqual(list(row), expected[i])

View File

@@ -4,34 +4,27 @@ import decimal
from dbt.adapters.base import Column
class TestStringType(unittest.TestCase):
def test__character_type(self):
col = Column(
'fieldname',
'character',
char_size=10
)
col = Column("fieldname", "character", char_size=10)
self.assertEqual(col.data_type, 'character varying(10)')
self.assertEqual(col.data_type, "character varying(10)")
class TestNumericType(unittest.TestCase):
def test__numeric_type(self):
col = Column(
'fieldname',
'numeric',
numeric_precision=decimal.Decimal('12'),
numeric_scale=decimal.Decimal('2'))
"fieldname",
"numeric",
numeric_precision=decimal.Decimal("12"),
numeric_scale=decimal.Decimal("2"),
)
self.assertEqual(col.data_type, 'numeric(12,2)')
self.assertEqual(col.data_type, "numeric(12,2)")
def test__numeric_type_with_no_precision(self):
# PostgreSQL, at least, will allow empty numeric precision
col = Column(
'fieldname',
'numeric',
numeric_precision=None)
col = Column("fieldname", "numeric", numeric_precision=None)
self.assertEqual(col.data_type, 'numeric')
self.assertEqual(col.data_type, "numeric")

View File

@@ -9,12 +9,12 @@ import time
def make_relation(database, schema, identifier):
return BaseRelation.create(database=database, schema=schema,
identifier=identifier)
return BaseRelation.create(database=database, schema=schema, identifier=identifier)
def make_mock_relationship(database, schema, identifier):
return BaseRelation.create(
database=database, schema=schema, identifier=identifier, type='view'
database=database, schema=schema, identifier=identifier, type="view"
)
@@ -26,13 +26,11 @@ class TestCache(TestCase):
relations = self.cache.get_relations(database, schema)
for identifier, expect in identifiers.items():
found = any(
(r.identifier == identifier and \
r.schema == schema and \
r.database == database)
(r.identifier == identifier and r.schema == schema and r.database == database)
for r in relations
)
msg = '{}.{}.{} was{} found in the cache!'.format(
database, schema, identifier, '' if found else ' not'
msg = "{}.{}.{} was{} found in the cache!".format(
database, schema, identifier, "" if found else " not"
)
self.assertEqual(expect, found, msg)
@@ -46,221 +44,229 @@ class TestCache(TestCase):
class TestEmpty(TestCache):
def test_empty(self):
self.assertEqual(len(self.cache.relations), 0)
self.assertEqual(len(self.cache.get_relations('dbt', 'test')), 0)
self.assertEqual(len(self.cache.get_relations("dbt", "test")), 0)
class TestDrop(TestCache):
def setUp(self):
super().setUp()
self.cache.add(make_relation('dbt', 'foo', 'bar'))
self.cache.add(make_relation("dbt", "foo", "bar"))
def test_missing_identifier_ignored(self):
self.cache.drop(make_relation('dbt', 'foo', 'bar1'))
self.assert_relations_exist('dbt', 'foo', 'bar')
self.cache.drop(make_relation("dbt", "foo", "bar1"))
self.assert_relations_exist("dbt", "foo", "bar")
self.assertEqual(len(self.cache.relations), 1)
def test_missing_schema_ignored(self):
self.cache.drop(make_relation('dbt', 'foo1', 'bar'))
self.assert_relations_exist('dbt', 'foo', 'bar')
self.cache.drop(make_relation("dbt", "foo1", "bar"))
self.assert_relations_exist("dbt", "foo", "bar")
self.assertEqual(len(self.cache.relations), 1)
def test_missing_db_ignored(self):
self.cache.drop(make_relation('dbt1', 'foo', 'bar'))
self.assert_relations_exist('dbt', 'foo', 'bar')
self.cache.drop(make_relation("dbt1", "foo", "bar"))
self.assert_relations_exist("dbt", "foo", "bar")
self.assertEqual(len(self.cache.relations), 1)
def test_drop(self):
self.cache.drop(make_relation('dbt', 'foo', 'bar'))
self.assert_relations_do_not_exist('dbt', 'foo', 'bar')
self.cache.drop(make_relation("dbt", "foo", "bar"))
self.assert_relations_do_not_exist("dbt", "foo", "bar")
self.assertEqual(len(self.cache.relations), 0)
class TestAddLink(TestCache):
def setUp(self):
super().setUp()
self.cache.add(make_relation('dbt', 'schema', 'foo'))
self.cache.add(make_relation('dbt_2', 'schema', 'bar'))
self.cache.add(make_relation('dbt', 'schema_2', 'bar'))
self.cache.add(make_relation("dbt", "schema", "foo"))
self.cache.add(make_relation("dbt_2", "schema", "bar"))
self.cache.add(make_relation("dbt", "schema_2", "bar"))
def test_no_src(self):
self.assert_relations_exist('dbt', 'schema', 'foo')
self.assert_relations_do_not_exist('dbt', 'schema', 'bar')
self.assert_relations_exist("dbt", "schema", "foo")
self.assert_relations_do_not_exist("dbt", "schema", "bar")
self.cache.add_link(make_relation('dbt', 'schema', 'bar'),
make_relation('dbt', 'schema', 'foo'))
self.cache.add_link(
make_relation("dbt", "schema", "bar"), make_relation("dbt", "schema", "foo")
)
self.assert_relations_exist('dbt', 'schema', 'foo', 'bar')
self.assert_relations_exist("dbt", "schema", "foo", "bar")
def test_no_dst(self):
self.assert_relations_exist('dbt', 'schema', 'foo')
self.assert_relations_do_not_exist('dbt', 'schema', 'bar')
self.assert_relations_exist("dbt", "schema", "foo")
self.assert_relations_do_not_exist("dbt", "schema", "bar")
self.cache.add_link(make_relation('dbt', 'schema', 'foo'),
make_relation('dbt', 'schema', 'bar'))
self.cache.add_link(
make_relation("dbt", "schema", "foo"), make_relation("dbt", "schema", "bar")
)
self.assert_relations_exist('dbt', 'schema', 'foo', 'bar')
self.assert_relations_exist("dbt", "schema", "foo", "bar")
class TestRename(TestCache):
def setUp(self):
super().setUp()
self.cache.add(make_relation('DBT', 'schema', 'foo'))
self.assert_relations_exist('DBT', 'schema', 'foo')
self.assertEqual(self.cache.schemas, {('dbt', 'schema')})
self.cache.add(make_relation("DBT", "schema", "foo"))
self.assert_relations_exist("DBT", "schema", "foo")
self.assertEqual(self.cache.schemas, {("dbt", "schema")})
def test_no_source_error(self):
# dest should be created anyway (it's probably a temp table)
self.cache.rename(make_relation('DBT', 'schema', 'bar'),
make_relation('DBT', 'schema', 'baz'))
self.cache.rename(
make_relation("DBT", "schema", "bar"), make_relation("DBT", "schema", "baz")
)
self.assertEqual(len(self.cache.relations), 2)
self.assert_relations_exist('DBT', 'schema', 'foo', 'baz')
self.assert_relations_exist("DBT", "schema", "foo", "baz")
def test_dest_exists_error(self):
foo = make_relation('DBT', 'schema', 'foo')
bar = make_relation('DBT', 'schema', 'bar')
foo = make_relation("DBT", "schema", "foo")
bar = make_relation("DBT", "schema", "bar")
self.cache.add(bar)
self.assert_relations_exist('DBT', 'schema', 'foo', 'bar')
self.assert_relations_exist("DBT", "schema", "foo", "bar")
with self.assertRaises(dbt.exceptions.DbtInternalError):
self.cache.rename(foo, bar)
self.assert_relations_exist('DBT', 'schema', 'foo', 'bar')
self.assert_relations_exist("DBT", "schema", "foo", "bar")
def test_dest_different_db(self):
self.cache.rename(make_relation('DBT', 'schema', 'foo'),
make_relation('DBT_2', 'schema', 'foo'))
self.assert_relations_exist('DBT_2', 'schema', 'foo')
self.assert_relations_do_not_exist('DBT', 'schema', 'foo')
self.cache.rename(
make_relation("DBT", "schema", "foo"), make_relation("DBT_2", "schema", "foo")
)
self.assert_relations_exist("DBT_2", "schema", "foo")
self.assert_relations_do_not_exist("DBT", "schema", "foo")
# we know about both schemas: dbt has nothing, dbt_2 has something.
self.assertEqual(self.cache.schemas, {('dbt_2', 'schema'), ('dbt', 'schema')})
self.assertEqual(self.cache.schemas, {("dbt_2", "schema"), ("dbt", "schema")})
self.assertEqual(len(self.cache.relations), 1)
def test_rename_identifier(self):
self.cache.rename(make_relation('DBT', 'schema', 'foo'),
make_relation('DBT', 'schema', 'bar'))
self.cache.rename(
make_relation("DBT", "schema", "foo"), make_relation("DBT", "schema", "bar")
)
self.assert_relations_exist('DBT', 'schema', 'bar')
self.assert_relations_do_not_exist('DBT', 'schema', 'foo')
self.assertEqual(self.cache.schemas, {('dbt', 'schema')})
self.assert_relations_exist("DBT", "schema", "bar")
self.assert_relations_do_not_exist("DBT", "schema", "foo")
self.assertEqual(self.cache.schemas, {("dbt", "schema")})
relation = self.cache.relations[('dbt', 'schema', 'bar')]
self.assertEqual(relation.inner.schema, 'schema')
self.assertEqual(relation.inner.identifier, 'bar')
self.assertEqual(relation.schema, 'schema')
self.assertEqual(relation.identifier, 'bar')
relation = self.cache.relations[("dbt", "schema", "bar")]
self.assertEqual(relation.inner.schema, "schema")
self.assertEqual(relation.inner.identifier, "bar")
self.assertEqual(relation.schema, "schema")
self.assertEqual(relation.identifier, "bar")
def test_rename_db(self):
self.cache.rename(make_relation('DBT', 'schema', 'foo'),
make_relation('DBT_2', 'schema', 'foo'))
self.cache.rename(
make_relation("DBT", "schema", "foo"), make_relation("DBT_2", "schema", "foo")
)
self.assertEqual(len(self.cache.get_relations('DBT', 'schema')), 0)
self.assertEqual(len(self.cache.get_relations('DBT_2', 'schema')), 1)
self.assert_relations_exist('DBT_2', 'schema', 'foo')
self.assert_relations_do_not_exist('DBT', 'schema', 'foo')
self.assertEqual(len(self.cache.get_relations("DBT", "schema")), 0)
self.assertEqual(len(self.cache.get_relations("DBT_2", "schema")), 1)
self.assert_relations_exist("DBT_2", "schema", "foo")
self.assert_relations_do_not_exist("DBT", "schema", "foo")
# we know about both schemas: dbt has nothing, dbt_2 has something.
self.assertEqual(self.cache.schemas, {('dbt_2', 'schema'), ('dbt', 'schema')})
self.assertEqual(self.cache.schemas, {("dbt_2", "schema"), ("dbt", "schema")})
relation = self.cache.relations[('dbt_2', 'schema', 'foo')]
self.assertEqual(relation.inner.database, 'DBT_2')
self.assertEqual(relation.inner.schema, 'schema')
self.assertEqual(relation.inner.identifier, 'foo')
self.assertEqual(relation.database, 'dbt_2')
self.assertEqual(relation.schema, 'schema')
self.assertEqual(relation.identifier, 'foo')
relation = self.cache.relations[("dbt_2", "schema", "foo")]
self.assertEqual(relation.inner.database, "DBT_2")
self.assertEqual(relation.inner.schema, "schema")
self.assertEqual(relation.inner.identifier, "foo")
self.assertEqual(relation.database, "dbt_2")
self.assertEqual(relation.schema, "schema")
self.assertEqual(relation.identifier, "foo")
def test_rename_schema(self):
self.cache.rename(make_relation('DBT', 'schema', 'foo'),
make_relation('DBT', 'schema_2', 'foo'))
self.cache.rename(
make_relation("DBT", "schema", "foo"), make_relation("DBT", "schema_2", "foo")
)
self.assertEqual(len(self.cache.get_relations('DBT', 'schema')), 0)
self.assertEqual(len(self.cache.get_relations('DBT', 'schema_2')), 1)
self.assert_relations_exist('DBT', 'schema_2', 'foo')
self.assert_relations_do_not_exist('DBT', 'schema', 'foo')
self.assertEqual(len(self.cache.get_relations("DBT", "schema")), 0)
self.assertEqual(len(self.cache.get_relations("DBT", "schema_2")), 1)
self.assert_relations_exist("DBT", "schema_2", "foo")
self.assert_relations_do_not_exist("DBT", "schema", "foo")
# we know about both schemas: schema has nothing, schema_2 has something.
self.assertEqual(self.cache.schemas, {('dbt', 'schema_2'), ('dbt', 'schema')})
self.assertEqual(self.cache.schemas, {("dbt", "schema_2"), ("dbt", "schema")})
relation = self.cache.relations[('dbt', 'schema_2', 'foo')]
self.assertEqual(relation.inner.database, 'DBT')
self.assertEqual(relation.inner.schema, 'schema_2')
self.assertEqual(relation.inner.identifier, 'foo')
self.assertEqual(relation.database, 'dbt')
self.assertEqual(relation.schema, 'schema_2')
self.assertEqual(relation.identifier, 'foo')
relation = self.cache.relations[("dbt", "schema_2", "foo")]
self.assertEqual(relation.inner.database, "DBT")
self.assertEqual(relation.inner.schema, "schema_2")
self.assertEqual(relation.inner.identifier, "foo")
self.assertEqual(relation.database, "dbt")
self.assertEqual(relation.schema, "schema_2")
self.assertEqual(relation.identifier, "foo")
class TestGetRelations(TestCache):
def setUp(self):
super().setUp()
self.relation = make_relation('dbt', 'foo', 'bar')
self.relation = make_relation("dbt", "foo", "bar")
self.cache.add(self.relation)
def test_get_by_name(self):
relations = self.cache.get_relations('dbt', 'foo')
relations = self.cache.get_relations("dbt", "foo")
self.assertEqual(len(relations), 1)
self.assertIs(relations[0], self.relation)
def test_get_by_uppercase_schema(self):
relations = self.cache.get_relations('dbt', 'FOO')
relations = self.cache.get_relations("dbt", "FOO")
self.assertEqual(len(relations), 1)
self.assertIs(relations[0], self.relation)
def test_get_by_uppercase_db(self):
relations = self.cache.get_relations('DBT', 'foo')
relations = self.cache.get_relations("DBT", "foo")
self.assertEqual(len(relations), 1)
self.assertIs(relations[0], self.relation)
def test_get_by_uppercase_schema_and_db(self):
relations = self.cache.get_relations('DBT', 'FOO')
relations = self.cache.get_relations("DBT", "FOO")
self.assertEqual(len(relations), 1)
self.assertIs(relations[0], self.relation)
def test_get_by_wrong_db(self):
relations = self.cache.get_relations('dbt_2', 'foo')
relations = self.cache.get_relations("dbt_2", "foo")
self.assertEqual(len(relations), 0)
def test_get_by_wrong_schema(self):
relations = self.cache.get_relations('dbt', 'foo_2')
relations = self.cache.get_relations("dbt", "foo_2")
self.assertEqual(len(relations), 0)
class TestAdd(TestCache):
def setUp(self):
super().setUp()
self.relation = make_relation('dbt', 'foo', 'bar')
self.relation = make_relation("dbt", "foo", "bar")
self.cache.add(self.relation)
def test_add(self):
relations = self.cache.get_relations('dbt', 'foo')
relations = self.cache.get_relations("dbt", "foo")
self.assertEqual(len(relations), 1)
self.assertEqual(len(self.cache.relations), 1)
self.assertIs(relations[0], self.relation)
def test_add_twice(self):
# add a new relation with same name
self.cache.add(make_relation('dbt', 'foo', 'bar'))
self.cache.add(make_relation("dbt", "foo", "bar"))
self.assertEqual(len(self.cache.relations), 1)
self.assertEqual(self.cache.schemas, {('dbt', 'foo')})
self.assert_relations_exist('dbt', 'foo', 'bar')
self.assertEqual(self.cache.schemas, {("dbt", "foo")})
self.assert_relations_exist("dbt", "foo", "bar")
def add_uppercase_schema(self):
self.cache.add(make_relation('dbt', 'FOO', 'baz'))
self.cache.add(make_relation("dbt", "FOO", "baz"))
self.assertEqual(len(self.cache.relations), 2)
relations = self.cache.get_relations('dbt', 'foo')
relations = self.cache.get_relations("dbt", "foo")
self.assertEqual(len(relations), 2)
self.assertEqual(self.cache.schemas, {('dbt', 'foo')})
self.assertIsNot(self.cache.relations[('dbt', 'foo', 'bar')].inner, None)
self.assertIsNot(self.cache.relations[('dbt', 'foo', 'baz')].inner, None)
self.assertEqual(self.cache.schemas, {("dbt", "foo")})
self.assertIsNot(self.cache.relations[("dbt", "foo", "bar")].inner, None)
self.assertIsNot(self.cache.relations[("dbt", "foo", "baz")].inner, None)
def add_different_db(self):
self.cache.add(make_relation('dbt_2', 'foo', 'bar'))
self.cache.add(make_relation("dbt_2", "foo", "bar"))
self.assertEqual(len(self.cache.relations), 2)
self.assertEqual(len(self.cache.get_relations('dbt_2', 'foo')), 1)
self.assertEqual(len(self.cache.get_relations('dbt', 'foo')), 1)
self.assertEqual(self.cache.schemas, {('dbt', 'foo'), ('dbt_2', 'foo')})
self.assertIsNot(self.cache.relations[('dbt', 'foo', 'bar')].inner, None)
self.assertIsNot(self.cache.relations[('dbt_2', 'foo', 'bar')].inner, None)
self.assertEqual(len(self.cache.get_relations("dbt_2", "foo")), 1)
self.assertEqual(len(self.cache.get_relations("dbt", "foo")), 1)
self.assertEqual(self.cache.schemas, {("dbt", "foo"), ("dbt_2", "foo")})
self.assertIsNot(self.cache.relations[("dbt", "foo", "bar")].inner, None)
self.assertIsNot(self.cache.relations[("dbt_2", "foo", "bar")].inner, None)
class TestLikeDbt(TestCase):
@@ -269,57 +275,70 @@ class TestLikeDbt(TestCase):
self._sleep = True
# add a bunch of cache entries
for ident in 'abcdef':
self.cache.add(make_relation('dbt', 'schema', ident))
for ident in "abcdef":
self.cache.add(make_relation("dbt", "schema", ident))
# 'b' references 'a'
self.cache.add_link(make_relation('dbt', 'schema', 'a'),
make_relation('dbt', 'schema', 'b'))
self.cache.add_link(
make_relation("dbt", "schema", "a"), make_relation("dbt", "schema", "b")
)
# and 'c' references 'b'
self.cache.add_link(make_relation('dbt', 'schema', 'b'),
make_relation('dbt', 'schema', 'c'))
self.cache.add_link(
make_relation("dbt", "schema", "b"), make_relation("dbt", "schema", "c")
)
# and 'd' references 'b'
self.cache.add_link(make_relation('dbt', 'schema', 'b'),
make_relation('dbt', 'schema', 'd'))
self.cache.add_link(
make_relation("dbt", "schema", "b"), make_relation("dbt", "schema", "d")
)
# and 'e' references 'a'
self.cache.add_link(make_relation('dbt', 'schema', 'a'),
make_relation('dbt', 'schema', 'e'))
self.cache.add_link(
make_relation("dbt", "schema", "a"), make_relation("dbt", "schema", "e")
)
# and 'f' references 'd'
self.cache.add_link(make_relation('dbt', 'schema', 'd'),
make_relation('dbt', 'schema', 'f'))
self.cache.add_link(
make_relation("dbt", "schema", "d"), make_relation("dbt", "schema", "f")
)
# so drop propagation goes (a -> (b -> (c (d -> f))) e)
def assert_has_relations(self, expected):
current = set(r.identifier for r in self.cache.get_relations('dbt', 'schema'))
current = set(r.identifier for r in self.cache.get_relations("dbt", "schema"))
self.assertEqual(current, expected)
def test_drop_inner(self):
self.assert_has_relations(set('abcdef'))
self.cache.drop(make_relation('dbt', 'schema', 'b'))
self.assert_has_relations({'a', 'e'})
self.assert_has_relations(set("abcdef"))
self.cache.drop(make_relation("dbt", "schema", "b"))
self.assert_has_relations({"a", "e"})
def test_rename_and_drop(self):
self.assert_has_relations(set('abcdef'))
self.assert_has_relations(set("abcdef"))
# drop the backup/tmp
self.cache.drop(make_relation('dbt', 'schema', 'b__backup'))
self.cache.drop(make_relation('dbt', 'schema', 'b__tmp'))
self.assert_has_relations(set('abcdef'))
self.cache.drop(make_relation("dbt", "schema", "b__backup"))
self.cache.drop(make_relation("dbt", "schema", "b__tmp"))
self.assert_has_relations(set("abcdef"))
# create a new b__tmp
self.cache.add(make_relation('dbt', 'schema', 'b__tmp',))
self.assert_has_relations(set('abcdef') | {'b__tmp'})
self.cache.add(
make_relation(
"dbt",
"schema",
"b__tmp",
)
)
self.assert_has_relations(set("abcdef") | {"b__tmp"})
# rename b -> b__backup
self.cache.rename(make_relation('dbt', 'schema', 'b'),
make_relation('dbt', 'schema', 'b__backup'))
self.assert_has_relations(set('acdef') | {'b__tmp', 'b__backup'})
self.cache.rename(
make_relation("dbt", "schema", "b"), make_relation("dbt", "schema", "b__backup")
)
self.assert_has_relations(set("acdef") | {"b__tmp", "b__backup"})
# rename temp to b
self.cache.rename(make_relation('dbt', 'schema', 'b__tmp'),
make_relation('dbt', 'schema', 'b'))
self.assert_has_relations(set('abcdef') | {'b__backup'})
self.cache.rename(
make_relation("dbt", "schema", "b__tmp"), make_relation("dbt", "schema", "b")
)
self.assert_has_relations(set("abcdef") | {"b__backup"})
# drop backup, everything that used to depend on b should be gone, but
# b itself should still exist
self.cache.drop(make_relation('dbt', 'schema', 'b__backup'))
self.assert_has_relations(set('abe'))
relation = self.cache.relations[('dbt', 'schema', 'a')]
self.cache.drop(make_relation("dbt", "schema", "b__backup"))
self.assert_has_relations(set("abe"))
relation = self.cache.relations[("dbt", "schema", "a")]
self.assertEqual(len(relation.referenced_by), 1)
def _rand_sleep(self):
@@ -329,63 +348,74 @@ class TestLikeDbt(TestCase):
def _target(self, ident):
self._rand_sleep()
self.cache.rename(make_relation('dbt', 'schema', ident),
make_relation('dbt', 'schema', ident+'__backup'))
self._rand_sleep()
self.cache.add(make_relation('dbt', 'schema', ident+'__tmp')
self.cache.rename(
make_relation("dbt", "schema", ident),
make_relation("dbt", "schema", ident + "__backup"),
)
self._rand_sleep()
self.cache.rename(make_relation('dbt', 'schema', ident+'__tmp'),
make_relation('dbt', 'schema', ident))
self.cache.add(make_relation("dbt", "schema", ident + "__tmp"))
self._rand_sleep()
self.cache.drop(make_relation('dbt', 'schema', ident+'__backup'))
return ident, self.cache.get_relations('dbt', 'schema')
self.cache.rename(
make_relation("dbt", "schema", ident + "__tmp"), make_relation("dbt", "schema", ident)
)
self._rand_sleep()
self.cache.drop(make_relation("dbt", "schema", ident + "__backup"))
return ident, self.cache.get_relations("dbt", "schema")
def test_threaded(self):
# add three more short subchains for threads to test on
for ident in 'ghijklmno':
obj = make_mock_relationship('test_db', 'schema', ident)
self.cache.add(make_relation('dbt', 'schema', ident))
for ident in "ghijklmno":
make_mock_relationship("test_db", "schema", ident)
self.cache.add(make_relation("dbt", "schema", ident))
self.cache.add_link(make_relation('dbt', 'schema', 'a'),
make_relation('dbt', 'schema', 'g'))
self.cache.add_link(make_relation('dbt', 'schema', 'g'),
make_relation('dbt', 'schema', 'h'))
self.cache.add_link(make_relation('dbt', 'schema', 'h'),
make_relation('dbt', 'schema', 'i'))
self.cache.add_link(
make_relation("dbt", "schema", "a"), make_relation("dbt", "schema", "g")
)
self.cache.add_link(
make_relation("dbt", "schema", "g"), make_relation("dbt", "schema", "h")
)
self.cache.add_link(
make_relation("dbt", "schema", "h"), make_relation("dbt", "schema", "i")
)
self.cache.add_link(make_relation('dbt', 'schema', 'a'),
make_relation('dbt', 'schema', 'j'))
self.cache.add_link(make_relation('dbt', 'schema', 'j'),
make_relation('dbt', 'schema', 'k'))
self.cache.add_link(make_relation('dbt', 'schema', 'k'),
make_relation('dbt', 'schema', 'l'))
self.cache.add_link(
make_relation("dbt", "schema", "a"), make_relation("dbt", "schema", "j")
)
self.cache.add_link(
make_relation("dbt", "schema", "j"), make_relation("dbt", "schema", "k")
)
self.cache.add_link(
make_relation("dbt", "schema", "k"), make_relation("dbt", "schema", "l")
)
self.cache.add_link(make_relation('dbt', 'schema', 'a'),
make_relation('dbt', 'schema', 'm'))
self.cache.add_link(make_relation('dbt', 'schema', 'm'),
make_relation('dbt', 'schema', 'n'))
self.cache.add_link(make_relation('dbt', 'schema', 'n'),
make_relation('dbt', 'schema', 'o'))
self.cache.add_link(
make_relation("dbt", "schema", "a"), make_relation("dbt", "schema", "m")
)
self.cache.add_link(
make_relation("dbt", "schema", "m"), make_relation("dbt", "schema", "n")
)
self.cache.add_link(
make_relation("dbt", "schema", "n"), make_relation("dbt", "schema", "o")
)
pool = ThreadPool(4)
results = list(pool.imap_unordered(self._target, ('b', 'g', 'j', 'm')))
results = list(pool.imap_unordered(self._target, ("b", "g", "j", "m")))
pool.close()
pool.join()
# at a minimum, we expect each table to "see" itself, its parent ('a'),
# and the unrelated table ('a')
min_expect = {
'b': {'a', 'b', 'e'},
'g': {'a', 'g', 'e'},
'j': {'a', 'j', 'e'},
'm': {'a', 'm', 'e'},
"b": {"a", "b", "e"},
"g": {"a", "g", "e"},
"j": {"a", "j", "e"},
"m": {"a", "m", "e"},
}
for ident, relations in results:
seen = set(r.identifier for r in relations)
self.assertTrue(min_expect[ident].issubset(seen))
self.assert_has_relations(set('abgjme'))
self.assert_has_relations(set("abgjme"))
def test_threaded_repeated(self):
for _ in range(10):
@@ -398,13 +428,13 @@ class TestComplexCache(TestCase):
def setUp(self):
self.cache = RelationsCache()
inputs = [
('dbt', 'foo', 'table1'),
('dbt', 'foo', 'table3'),
('dbt', 'foo', 'table4'),
('dbt', 'bar', 'table2'),
('dbt', 'bar', 'table3'),
('dbt_2', 'foo', 'table1'),
('dbt_2', 'foo', 'table2'),
("dbt", "foo", "table1"),
("dbt", "foo", "table3"),
("dbt", "foo", "table4"),
("dbt", "bar", "table2"),
("dbt", "bar", "table3"),
("dbt_2", "foo", "table1"),
("dbt_2", "foo", "table2"),
]
self.inputs = [make_relation(d, s, i) for d, s, i in inputs]
for relation in self.inputs:
@@ -413,79 +443,78 @@ class TestComplexCache(TestCase):
# dbt.foo.table3 references dbt.foo.table1
# (create view dbt.foo.table3 as (select * from dbt.foo.table1...))
self.cache.add_link(
make_relation('dbt', 'foo', 'table1'),
make_relation('dbt', 'foo', 'table3')
make_relation("dbt", "foo", "table1"), make_relation("dbt", "foo", "table3")
)
# dbt.bar.table3 references dbt.foo.table3
# (create view dbt.bar.table5 as (select * from dbt.foo.table3...))
self.cache.add_link(
make_relation('dbt', 'foo', 'table3'),
make_relation('dbt', 'bar', 'table3')
make_relation("dbt", "foo", "table3"), make_relation("dbt", "bar", "table3")
)
# dbt.foo.table4 also references dbt.foo.table1
self.cache.add_link(
make_relation('dbt', 'foo', 'table1'),
make_relation('dbt', 'foo', 'table4')
make_relation("dbt", "foo", "table1"), make_relation("dbt", "foo", "table4")
)
# and dbt_2.foo.table1 references dbt.foo.table1
self.cache.add_link(
make_relation('dbt', 'foo', 'table1'),
make_relation('dbt_2', 'foo', 'table1'),
make_relation("dbt", "foo", "table1"),
make_relation("dbt_2", "foo", "table1"),
)
def test_get_relations(self):
self.assertEqual(len(self.cache.get_relations('dbt', 'foo')), 3)
self.assertEqual(len(self.cache.get_relations('dbt', 'bar')), 2)
self.assertEqual(len(self.cache.get_relations('dbt_2', 'foo')), 2)
self.assertEqual(len(self.cache.get_relations("dbt", "foo")), 3)
self.assertEqual(len(self.cache.get_relations("dbt", "bar")), 2)
self.assertEqual(len(self.cache.get_relations("dbt_2", "foo")), 2)
self.assertEqual(len(self.cache.relations), 7)
def test_drop_one(self):
# dropping dbt.bar.table2 should only drop itself
self.cache.drop(make_relation('dbt', 'bar', 'table2'))
self.assertEqual(len(self.cache.get_relations('dbt', 'foo')), 3)
self.assertEqual(len(self.cache.get_relations('dbt', 'bar')), 1)
self.assertEqual(len(self.cache.get_relations('dbt_2', 'foo')), 2)
self.cache.drop(make_relation("dbt", "bar", "table2"))
self.assertEqual(len(self.cache.get_relations("dbt", "foo")), 3)
self.assertEqual(len(self.cache.get_relations("dbt", "bar")), 1)
self.assertEqual(len(self.cache.get_relations("dbt_2", "foo")), 2)
self.assertEqual(len(self.cache.relations), 6)
def test_drop_many(self):
# dropping dbt.foo.table1 should drop everything but dbt.bar.table2 and
# dbt_2.foo.table2
self.cache.drop(make_relation('dbt', 'foo', 'table1'))
self.assertEqual(len(self.cache.get_relations('dbt', 'foo')), 0)
self.assertEqual(len(self.cache.get_relations('dbt', 'bar')), 1)
self.assertEqual(len(self.cache.get_relations('dbt_2', 'foo')), 1)
self.cache.drop(make_relation("dbt", "foo", "table1"))
self.assertEqual(len(self.cache.get_relations("dbt", "foo")), 0)
self.assertEqual(len(self.cache.get_relations("dbt", "bar")), 1)
self.assertEqual(len(self.cache.get_relations("dbt_2", "foo")), 1)
self.assertEqual(len(self.cache.relations), 2)
def test_rename_root(self):
self.cache.rename(make_relation('dbt', 'foo', 'table1'),
make_relation('dbt', 'bar', 'table1'))
retrieved = self.cache.relations[('dbt', 'bar', 'table1')].inner
self.assertEqual(retrieved.schema, 'bar')
self.assertEqual(retrieved.identifier, 'table1')
self.assertEqual(len(self.cache.get_relations('dbt', 'foo')), 2)
self.assertEqual(len(self.cache.get_relations('dbt', 'bar')), 3)
self.assertEqual(len(self.cache.get_relations('dbt_2', 'foo')), 2)
self.cache.rename(
make_relation("dbt", "foo", "table1"), make_relation("dbt", "bar", "table1")
)
retrieved = self.cache.relations[("dbt", "bar", "table1")].inner
self.assertEqual(retrieved.schema, "bar")
self.assertEqual(retrieved.identifier, "table1")
self.assertEqual(len(self.cache.get_relations("dbt", "foo")), 2)
self.assertEqual(len(self.cache.get_relations("dbt", "bar")), 3)
self.assertEqual(len(self.cache.get_relations("dbt_2", "foo")), 2)
self.assertEqual(len(self.cache.relations), 7)
# make sure drops still cascade from the renamed table
self.cache.drop(make_relation('dbt', 'bar', 'table1'))
self.assertEqual(len(self.cache.get_relations('dbt', 'foo')), 0)
self.assertEqual(len(self.cache.get_relations('dbt', 'bar')), 1)
self.assertEqual(len(self.cache.get_relations('dbt_2', 'foo')), 1)
self.cache.drop(make_relation("dbt", "bar", "table1"))
self.assertEqual(len(self.cache.get_relations("dbt", "foo")), 0)
self.assertEqual(len(self.cache.get_relations("dbt", "bar")), 1)
self.assertEqual(len(self.cache.get_relations("dbt_2", "foo")), 1)
self.assertEqual(len(self.cache.relations), 2)
def test_rename_branch(self):
self.cache.rename(make_relation('dbt', 'foo', 'table3'),
make_relation('dbt', 'foo', 'table2'))
self.assertEqual(len(self.cache.get_relations('dbt', 'foo')), 3)
self.assertEqual(len(self.cache.get_relations('dbt', 'bar')), 2)
self.assertEqual(len(self.cache.get_relations('dbt_2', 'foo')), 2)
self.cache.rename(
make_relation("dbt", "foo", "table3"), make_relation("dbt", "foo", "table2")
)
self.assertEqual(len(self.cache.get_relations("dbt", "foo")), 3)
self.assertEqual(len(self.cache.get_relations("dbt", "bar")), 2)
self.assertEqual(len(self.cache.get_relations("dbt_2", "foo")), 2)
# make sure drops still cascade through the renamed table
self.cache.drop(make_relation('dbt', 'foo', 'table1'))
self.assertEqual(len(self.cache.get_relations('dbt', 'foo')), 0)
self.assertEqual(len(self.cache.get_relations('dbt', 'bar')), 1)
self.assertEqual(len(self.cache.get_relations('dbt_2', 'foo')), 1)
self.cache.drop(make_relation("dbt", "foo", "table1"))
self.assertEqual(len(self.cache.get_relations("dbt", "foo")), 0)
self.assertEqual(len(self.cache.get_relations("dbt", "bar")), 1)
self.assertEqual(len(self.cache.get_relations("dbt_2", "foo")), 1)
self.assertEqual(len(self.cache.relations), 2)

File diff suppressed because it is too large Load Diff

View File

@@ -8,7 +8,6 @@ import pytest
from dbt.adapters import postgres
from dbt.adapters import factory
from dbt.adapters.base import AdapterConfig
from dbt.clients.jinja import MacroStack
from dbt.contracts.graph.nodes import (
ModelNode,
@@ -17,13 +16,12 @@ from dbt.contracts.graph.nodes import (
Macro,
)
from dbt.config.project import VarProvider
from dbt.context import base, target, configured, providers, docs, manifest, macros
from dbt.context import base, providers, docs, manifest, macros
from dbt.contracts.files import FileHash
from dbt.events.functions import reset_metadata_vars
from dbt.node_types import NodeType
import dbt.exceptions
from .utils import (
profile_from_dict,
config_from_parts_or_dicts,
inject_adapter,
clear_plugin,
@@ -61,7 +59,7 @@ class TestVar(unittest.TestCase):
),
tags=[],
path="model_one.sql",
language='sql',
language="sql",
raw_code="",
description="",
columns={},
@@ -201,14 +199,16 @@ REQUIRED_BASE_KEYS = frozenset(
"flags",
"print",
"diff_of_two_dicts",
"local_md5"
"local_md5",
}
)
REQUIRED_TARGET_KEYS = REQUIRED_BASE_KEYS | {"target"}
REQUIRED_DOCS_KEYS = REQUIRED_TARGET_KEYS | {"project_name"} | {"doc"}
MACROS = frozenset({"macro_a", "macro_b", "root", "dbt"})
REQUIRED_QUERY_HEADER_KEYS = REQUIRED_TARGET_KEYS | {"project_name", "context_macro_stack"} | MACROS
REQUIRED_QUERY_HEADER_KEYS = (
REQUIRED_TARGET_KEYS | {"project_name", "context_macro_stack"} | MACROS
)
REQUIRED_MACRO_KEYS = REQUIRED_QUERY_HEADER_KEYS | {
"_sql_results",
"load_result",
@@ -241,7 +241,7 @@ REQUIRED_MACRO_KEYS = REQUIRED_QUERY_HEADER_KEYS | {
"selected_resources",
"invocation_args_dict",
"submit_python_job",
"dbt_metadata_envs"
"dbt_metadata_envs",
}
REQUIRED_MODEL_KEYS = REQUIRED_MACRO_KEYS | {"this", "compiled_code"}
MAYBE_KEYS = frozenset({"debug"})
@@ -301,7 +301,7 @@ def model():
),
tags=[],
path="model_one.sql",
language='sql',
language="sql",
raw_code="",
description="",
columns={},
@@ -363,7 +363,7 @@ def mock_model():
),
tags=[],
path="model_one.sql",
language='sql',
language="sql",
raw_code="",
description="",
columns={},
@@ -419,6 +419,7 @@ def test_macro_runtime_context(config_postgres, manifest_fx, get_adapter, get_in
)
assert_has_keys(REQUIRED_MACRO_KEYS, MAYBE_KEYS, ctx)
def test_invocation_args_to_dict_in_macro_runtime_context(
config_postgres, manifest_fx, get_adapter, get_include_paths
):
@@ -435,6 +436,7 @@ def test_invocation_args_to_dict_in_macro_runtime_context(
# Comes from unit/utils.py config_from_parts_or_dicts method
assert ctx["invocation_args_dict"]["profile_dir"] == "/dev/null"
def test_model_parse_context(config_postgres, manifest_fx, get_adapter, get_include_paths):
ctx = providers.generate_parser_model_context(
model=mock_model(),
@@ -500,25 +502,28 @@ def test_macro_namespace(config_postgres, manifest_fx):
assert result["root"]["some_macro"].macro is package_macro
assert result["some_macro"].macro is package_macro
def test_dbt_metadata_envs(monkeypatch, config_postgres, manifest_fx, get_adapter, get_include_paths):
def test_dbt_metadata_envs(
monkeypatch, config_postgres, manifest_fx, get_adapter, get_include_paths
):
reset_metadata_vars()
envs = {
"DBT_ENV_CUSTOM_ENV_RUN_ID": 1234,
"DBT_ENV_CUSTOM_ENV_JOB_ID": 5678,
"DBT_ENV_RUN_ID": 91011,
"RANDOM_ENV": 121314
"RANDOM_ENV": 121314,
}
monkeypatch.setattr(os, 'environ', envs)
monkeypatch.setattr(os, "environ", envs)
ctx = providers.generate_runtime_macro_context(
macro=manifest_fx.macros["macro.root.macro_a"],
config=config_postgres,
manifest=manifest_fx,
package_name="root",
)
)
assert ctx["dbt_metadata_envs"] == {'JOB_ID': 5678, 'RUN_ID': 1234}
assert ctx["dbt_metadata_envs"] == {"JOB_ID": 5678, "RUN_ID": 1234}
# cleanup
reset_metadata_vars()

View File

@@ -2,216 +2,222 @@ import pickle
import pytest
from dbt.contracts.files import FileHash
from dbt.contracts.graph.nodes import (
ModelNode, InjectedCTE, GenericTestNode
)
from dbt.contracts.graph.nodes import (
DependsOn, NodeConfig, TestConfig, TestMetadata, ColumnInfo
)
from dbt.contracts.graph.nodes import ModelNode, InjectedCTE, GenericTestNode
from dbt.contracts.graph.nodes import DependsOn, NodeConfig, TestConfig, TestMetadata, ColumnInfo
from dbt.node_types import NodeType
from .utils import (
assert_symmetric,
assert_from_dict,
assert_fails_validation,
dict_replace,
replace_config,
compare_dicts,
)
@pytest.fixture
def basic_uncompiled_model():
return ModelNode(
package_name='test',
path='/root/models/foo.sql',
original_file_path='models/foo.sql',
language='sql',
package_name="test",
path="/root/models/foo.sql",
original_file_path="models/foo.sql",
language="sql",
raw_code='select * from {{ ref("other") }}',
name='foo',
name="foo",
resource_type=NodeType.Model,
unique_id='model.test.foo',
fqn=['test', 'models', 'foo'],
unique_id="model.test.foo",
fqn=["test", "models", "foo"],
refs=[],
sources=[],
metrics=[],
depends_on=DependsOn(),
deferred=False,
description='',
database='test_db',
schema='test_schema',
alias='bar',
description="",
database="test_db",
schema="test_schema",
alias="bar",
tags=[],
config=NodeConfig(),
meta={},
compiled=False,
extra_ctes=[],
extra_ctes_injected=False,
checksum=FileHash.from_contents(''),
unrendered_config={}
checksum=FileHash.from_contents(""),
unrendered_config={},
)
@pytest.fixture
def basic_compiled_model():
return ModelNode(
package_name='test',
path='/root/models/foo.sql',
original_file_path='models/foo.sql',
language='sql',
package_name="test",
path="/root/models/foo.sql",
original_file_path="models/foo.sql",
language="sql",
raw_code='select * from {{ ref("other") }}',
name='foo',
name="foo",
resource_type=NodeType.Model,
unique_id='model.test.foo',
fqn=['test', 'models', 'foo'],
unique_id="model.test.foo",
fqn=["test", "models", "foo"],
refs=[],
sources=[],
metrics=[],
depends_on=DependsOn(),
deferred=True,
description='',
database='test_db',
schema='test_schema',
alias='bar',
description="",
database="test_db",
schema="test_schema",
alias="bar",
tags=[],
config=NodeConfig(),
contract=False,
meta={},
compiled=True,
extra_ctes=[InjectedCTE('whatever', 'select * from other')],
extra_ctes=[InjectedCTE("whatever", "select * from other")],
extra_ctes_injected=True,
compiled_code='with whatever as (select * from other) select * from whatever',
checksum=FileHash.from_contents(''),
unrendered_config={}
compiled_code="with whatever as (select * from other) select * from whatever",
checksum=FileHash.from_contents(""),
unrendered_config={},
)
@pytest.fixture
def minimal_uncompiled_dict():
return {
'name': 'foo',
'created_at': 1,
'resource_type': str(NodeType.Model),
'path': '/root/models/foo.sql',
'original_file_path': 'models/foo.sql',
'package_name': 'test',
'language': 'sql',
'raw_code': 'select * from {{ ref("other") }}',
'unique_id': 'model.test.foo',
'fqn': ['test', 'models', 'foo'],
'database': 'test_db',
'schema': 'test_schema',
'alias': 'bar',
'compiled': False,
'checksum': {'name': 'sha256', 'checksum': 'e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855'},
'unrendered_config': {}
"name": "foo",
"created_at": 1,
"resource_type": str(NodeType.Model),
"path": "/root/models/foo.sql",
"original_file_path": "models/foo.sql",
"package_name": "test",
"language": "sql",
"raw_code": 'select * from {{ ref("other") }}',
"unique_id": "model.test.foo",
"fqn": ["test", "models", "foo"],
"database": "test_db",
"schema": "test_schema",
"alias": "bar",
"compiled": False,
"checksum": {
"name": "sha256",
"checksum": "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855",
},
"unrendered_config": {},
}
@pytest.fixture
def basic_uncompiled_dict():
return {
'name': 'foo',
'created_at': 1,
'resource_type': str(NodeType.Model),
'path': '/root/models/foo.sql',
'original_file_path': 'models/foo.sql',
'package_name': 'test',
'language': 'sql',
'raw_code': 'select * from {{ ref("other") }}',
'unique_id': 'model.test.foo',
'fqn': ['test', 'models', 'foo'],
'refs': [],
'sources': [],
'metrics': [],
'depends_on': {'macros': [], 'nodes': []},
'database': 'test_db',
'deferred': False,
'description': '',
'schema': 'test_schema',
'alias': 'bar',
'tags': [],
'config': {
'column_types': {},
'enabled': True,
'materialized': 'view',
'persist_docs': {},
'post-hook': [],
'pre-hook': [],
'quoting': {},
'tags': [],
'on_schema_change': 'ignore',
'meta': {},
'grants': {},
'packages': [],
"name": "foo",
"created_at": 1,
"resource_type": str(NodeType.Model),
"path": "/root/models/foo.sql",
"original_file_path": "models/foo.sql",
"package_name": "test",
"language": "sql",
"raw_code": 'select * from {{ ref("other") }}',
"unique_id": "model.test.foo",
"fqn": ["test", "models", "foo"],
"refs": [],
"sources": [],
"metrics": [],
"depends_on": {"macros": [], "nodes": []},
"database": "test_db",
"deferred": False,
"description": "",
"schema": "test_schema",
"alias": "bar",
"tags": [],
"config": {
"column_types": {},
"enabled": True,
"materialized": "view",
"persist_docs": {},
"post-hook": [],
"pre-hook": [],
"quoting": {},
"tags": [],
"on_schema_change": "ignore",
"meta": {},
"grants": {},
"packages": [],
},
'docs': {'show': True},
'columns': {},
'meta': {},
'compiled': False,
'extra_ctes': [],
'extra_ctes_injected': False,
'checksum': {'name': 'sha256', 'checksum': 'e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855'},
'unrendered_config': {},
'config_call_dict': {},
"docs": {"show": True},
"columns": {},
"meta": {},
"compiled": False,
"extra_ctes": [],
"extra_ctes_injected": False,
"checksum": {
"name": "sha256",
"checksum": "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855",
},
"unrendered_config": {},
"config_call_dict": {},
}
@pytest.fixture
def basic_compiled_dict():
return {
'name': 'foo',
'created_at': 1,
'resource_type': str(NodeType.Model),
'path': '/root/models/foo.sql',
'original_file_path': 'models/foo.sql',
'package_name': 'test',
'language':'sql',
'raw_code': 'select * from {{ ref("other") }}',
'unique_id': 'model.test.foo',
'fqn': ['test', 'models', 'foo'],
'refs': [],
'sources': [],
'metrics': [],
'depends_on': {'macros': [], 'nodes': []},
'database': 'test_db',
'deferred': True,
'description': '',
'schema': 'test_schema',
'alias': 'bar',
'tags': [],
'config': {
'column_types': {},
'enabled': True,
'materialized': 'view',
'persist_docs': {},
'post-hook': [],
'pre-hook': [],
'quoting': {},
'tags': [],
'on_schema_change': 'ignore',
'meta': {},
'grants': {},
'packages': [],
'contract': False,
'docs': {'show': True},
"name": "foo",
"created_at": 1,
"resource_type": str(NodeType.Model),
"path": "/root/models/foo.sql",
"original_file_path": "models/foo.sql",
"package_name": "test",
"language": "sql",
"raw_code": 'select * from {{ ref("other") }}',
"unique_id": "model.test.foo",
"fqn": ["test", "models", "foo"],
"refs": [],
"sources": [],
"metrics": [],
"depends_on": {"macros": [], "nodes": []},
"database": "test_db",
"deferred": True,
"description": "",
"schema": "test_schema",
"alias": "bar",
"tags": [],
"config": {
"column_types": {},
"enabled": True,
"materialized": "view",
"persist_docs": {},
"post-hook": [],
"pre-hook": [],
"quoting": {},
"tags": [],
"on_schema_change": "ignore",
"meta": {},
"grants": {},
"packages": [],
"contract": False,
"docs": {"show": True},
},
'docs': {'show': True},
'columns': {},
'contract': False,
'meta': {},
'compiled': True,
'extra_ctes': [{'id': 'whatever', 'sql': 'select * from other'}],
'extra_ctes_injected': True,
'compiled_code': 'with whatever as (select * from other) select * from whatever',
'checksum': {'name': 'sha256', 'checksum': 'e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855'},
'unrendered_config': {},
'config_call_dict': {},
'access': 'protected',
"docs": {"show": True},
"columns": {},
"contract": False,
"meta": {},
"compiled": True,
"extra_ctes": [{"id": "whatever", "sql": "select * from other"}],
"extra_ctes_injected": True,
"compiled_code": "with whatever as (select * from other) select * from whatever",
"checksum": {
"name": "sha256",
"checksum": "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855",
},
"unrendered_config": {},
"config_call_dict": {},
"access": "protected",
}
@pytest.mark.skip("Haven't found where we would use uncompiled node")
def test_basic_uncompiled_model(minimal_uncompiled_dict, basic_uncompiled_dict, basic_uncompiled_model):
def test_basic_uncompiled_model(
minimal_uncompiled_dict, basic_uncompiled_dict, basic_uncompiled_model
):
node_dict = basic_uncompiled_dict
node = basic_uncompiled_model
assert_symmetric(node, node_dict, ModelNode)
@@ -234,90 +240,114 @@ def test_basic_compiled_model(basic_compiled_dict, basic_compiled_model):
def test_invalid_extra_fields_model(minimal_uncompiled_dict):
bad_extra = minimal_uncompiled_dict
bad_extra['notvalid'] = 'nope'
bad_extra["notvalid"] = "nope"
assert_fails_validation(bad_extra, ModelNode)
def test_invalid_bad_type_model(minimal_uncompiled_dict):
bad_type = minimal_uncompiled_dict
bad_type['resource_type'] = str(NodeType.Macro)
bad_type["resource_type"] = str(NodeType.Macro)
assert_fails_validation(bad_type, ModelNode)
unchanged_compiled_models = [
lambda u: (u, u.replace(description='a description')),
lambda u: (u, u.replace(tags=['mytag'])),
lambda u: (u, u.replace(meta={'cool_key': 'cool value'})),
lambda u: (u, u.replace(description="a description")),
lambda u: (u, u.replace(tags=["mytag"])),
lambda u: (u, u.replace(meta={"cool_key": "cool value"})),
# changing the final alias/schema/datbase isn't a change - could just be target changing!
lambda u: (u, u.replace(database='nope')),
lambda u: (u, u.replace(schema='nope')),
lambda u: (u, u.replace(alias='nope')),
lambda u: (u, u.replace(database="nope")),
lambda u: (u, u.replace(schema="nope")),
lambda u: (u, u.replace(alias="nope")),
# None -> False is a config change even though it's pretty much the same
lambda u: (u.replace(config=u.config.replace(persist_docs={'relation': False})), u.replace(
config=u.config.replace(persist_docs={'relation': False}))),
lambda u: (u.replace(config=u.config.replace(persist_docs={'columns': False})), u.replace(
config=u.config.replace(persist_docs={'columns': False}))),
lambda u: (
u.replace(config=u.config.replace(persist_docs={"relation": False})),
u.replace(config=u.config.replace(persist_docs={"relation": False})),
),
lambda u: (
u.replace(config=u.config.replace(persist_docs={"columns": False})),
u.replace(config=u.config.replace(persist_docs={"columns": False})),
),
# True -> True
lambda u: (u.replace(config=u.config.replace(persist_docs={'relation': True})), u.replace(
config=u.config.replace(persist_docs={'relation': True}))),
lambda u: (u.replace(config=u.config.replace(persist_docs={'columns': True})), u.replace(
config=u.config.replace(persist_docs={'columns': True}))),
lambda u: (
u.replace(config=u.config.replace(persist_docs={"relation": True})),
u.replace(config=u.config.replace(persist_docs={"relation": True})),
),
lambda u: (
u.replace(config=u.config.replace(persist_docs={"columns": True})),
u.replace(config=u.config.replace(persist_docs={"columns": True})),
),
# only columns docs enabled, but description changed
lambda u: (u.replace(config=u.config.replace(persist_docs={'columns': True})), u.replace(
config=u.config.replace(persist_docs={'columns': True}), description='a model description')),
lambda u: (
u.replace(config=u.config.replace(persist_docs={"columns": True})),
u.replace(
config=u.config.replace(persist_docs={"columns": True}),
description="a model description",
),
),
# only relation docs eanbled, but columns changed
lambda u: (u.replace(config=u.config.replace(persist_docs={'relation': True})), u.replace(config=u.config.replace(
persist_docs={'relation': True}), columns={'a': ColumnInfo(name='a', description='a column description')}))
lambda u: (
u.replace(config=u.config.replace(persist_docs={"relation": True})),
u.replace(
config=u.config.replace(persist_docs={"relation": True}),
columns={"a": ColumnInfo(name="a", description="a column description")},
),
),
]
changed_compiled_models = [
lambda u: (u, None),
lambda u: (u, u.replace(raw_code='select * from wherever')),
lambda u: (u, u.replace(fqn=['test', 'models', 'subdir', 'foo'],
original_file_path='models/subdir/foo.sql', path='/root/models/subdir/foo.sql')),
lambda u: (u, u.replace(raw_code="select * from wherever")),
lambda u: (
u,
u.replace(
fqn=["test", "models", "subdir", "foo"],
original_file_path="models/subdir/foo.sql",
path="/root/models/subdir/foo.sql",
),
),
lambda u: (u, replace_config(u, full_refresh=True)),
lambda u: (u, replace_config(u, post_hook=['select 1 as id'])),
lambda u: (u, replace_config(u, pre_hook=['select 1 as id'])),
lambda u: (u, replace_config(
u, quoting={'database': True, 'schema': False, 'identifier': False})),
lambda u: (u, replace_config(u, post_hook=["select 1 as id"])),
lambda u: (u, replace_config(u, pre_hook=["select 1 as id"])),
lambda u: (
u,
replace_config(u, quoting={"database": True, "schema": False, "identifier": False}),
),
# we changed persist_docs values
lambda u: (u, replace_config(u, persist_docs={'relation': True})),
lambda u: (u, replace_config(u, persist_docs={'columns': True})),
lambda u: (u, replace_config(u, persist_docs={
'columns': True, 'relation': True})),
lambda u: (u, replace_config(u, persist_docs={"relation": True})),
lambda u: (u, replace_config(u, persist_docs={"columns": True})),
lambda u: (u, replace_config(u, persist_docs={"columns": True, "relation": True})),
# None -> False is a config change even though it's pretty much the same
lambda u: (u, replace_config(u, persist_docs={'relation': False})),
lambda u: (u, replace_config(u, persist_docs={'columns': False})),
lambda u: (u, replace_config(u, persist_docs={"relation": False})),
lambda u: (u, replace_config(u, persist_docs={"columns": False})),
# persist docs was true for the relation and we changed the model description
lambda u: (
replace_config(u, persist_docs={'relation': True}),
replace_config(u, persist_docs={
'relation': True}, description='a model description'),
replace_config(u, persist_docs={"relation": True}),
replace_config(u, persist_docs={"relation": True}, description="a model description"),
),
# persist docs was true for columns and we changed the model description
lambda u: (
replace_config(u, persist_docs={'columns': True}),
replace_config(u, persist_docs={'columns': True}, columns={
'a': ColumnInfo(name='a', description='a column description')})
replace_config(u, persist_docs={"columns": True}),
replace_config(
u,
persist_docs={"columns": True},
columns={"a": ColumnInfo(name="a", description="a column description")},
),
),
# changing alias/schema/database on the config level is a change
lambda u: (u, replace_config(u, database='nope')),
lambda u: (u, replace_config(u, schema='nope')),
lambda u: (u, replace_config(u, alias='nope')),
lambda u: (u, replace_config(u, database="nope")),
lambda u: (u, replace_config(u, schema="nope")),
lambda u: (u, replace_config(u, alias="nope")),
]
@pytest.mark.parametrize('func', unchanged_compiled_models)
@pytest.mark.parametrize("func", unchanged_compiled_models)
def test_compare_unchanged_model(func, basic_uncompiled_model):
node, compare = func(basic_uncompiled_model)
assert node.same_contents(compare)
@pytest.mark.parametrize('func', changed_compiled_models)
@pytest.mark.parametrize("func", changed_compiled_models)
def test_compare_changed_model(func, basic_uncompiled_model):
node, compare = func(basic_uncompiled_model)
assert not node.same_contents(compare)
@@ -326,205 +356,217 @@ def test_compare_changed_model(func, basic_uncompiled_model):
@pytest.fixture
def minimal_schema_test_dict():
return {
'name': 'foo',
'created_at': 1,
'resource_type': str(NodeType.Test),
'path': '/root/x/path.sql',
'original_file_path': '/root/path.sql',
'package_name': 'test',
'language': 'sql',
'raw_code': 'select * from {{ ref("other") }}',
'unique_id': 'model.test.foo',
'fqn': ['test', 'models', 'foo'],
'database': 'test_db',
'schema': 'dbt_test__audit',
'alias': 'bar',
'test_metadata': {
'name': 'foo',
'kwargs': {},
"name": "foo",
"created_at": 1,
"resource_type": str(NodeType.Test),
"path": "/root/x/path.sql",
"original_file_path": "/root/path.sql",
"package_name": "test",
"language": "sql",
"raw_code": 'select * from {{ ref("other") }}',
"unique_id": "model.test.foo",
"fqn": ["test", "models", "foo"],
"database": "test_db",
"schema": "dbt_test__audit",
"alias": "bar",
"test_metadata": {
"name": "foo",
"kwargs": {},
},
"compiled": False,
"checksum": {
"name": "sha256",
"checksum": "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855",
},
'compiled': False,
'checksum': {'name': 'sha256', 'checksum': 'e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855'},
}
@pytest.fixture
def basic_uncompiled_schema_test_node():
return GenericTestNode(
package_name='test',
path='/root/x/path.sql',
original_file_path='/root/path.sql',
language='sql',
package_name="test",
path="/root/x/path.sql",
original_file_path="/root/path.sql",
language="sql",
raw_code='select * from {{ ref("other") }}',
name='foo',
name="foo",
resource_type=NodeType.Test,
unique_id='model.test.foo',
fqn=['test', 'models', 'foo'],
unique_id="model.test.foo",
fqn=["test", "models", "foo"],
refs=[],
sources=[],
metrics=[],
deferred=False,
depends_on=DependsOn(),
description='',
database='test_db',
schema='dbt_test__audit',
alias='bar',
description="",
database="test_db",
schema="dbt_test__audit",
alias="bar",
tags=[],
config=TestConfig(),
meta={},
compiled=False,
extra_ctes=[],
extra_ctes_injected=False,
test_metadata=TestMetadata(namespace=None, name='foo', kwargs={}),
checksum=FileHash.from_contents(''),
unrendered_config={}
test_metadata=TestMetadata(namespace=None, name="foo", kwargs={}),
checksum=FileHash.from_contents(""),
unrendered_config={},
)
@pytest.fixture
def basic_compiled_schema_test_node():
return GenericTestNode(
package_name='test',
path='/root/x/path.sql',
original_file_path='/root/path.sql',
language='sql',
package_name="test",
path="/root/x/path.sql",
original_file_path="/root/path.sql",
language="sql",
raw_code='select * from {{ ref("other") }}',
name='foo',
name="foo",
resource_type=NodeType.Test,
unique_id='model.test.foo',
fqn=['test', 'models', 'foo'],
unique_id="model.test.foo",
fqn=["test", "models", "foo"],
refs=[],
sources=[],
metrics=[],
depends_on=DependsOn(),
deferred=False,
description='',
database='test_db',
schema='dbt_test__audit',
alias='bar',
description="",
database="test_db",
schema="dbt_test__audit",
alias="bar",
tags=[],
config=TestConfig(severity='warn'),
config=TestConfig(severity="warn"),
contract=False,
meta={},
compiled=True,
extra_ctes=[InjectedCTE('whatever', 'select * from other')],
extra_ctes=[InjectedCTE("whatever", "select * from other")],
extra_ctes_injected=True,
compiled_code='with whatever as (select * from other) select * from whatever',
column_name='id',
test_metadata=TestMetadata(namespace=None, name='foo', kwargs={}),
checksum=FileHash.from_contents(''),
compiled_code="with whatever as (select * from other) select * from whatever",
column_name="id",
test_metadata=TestMetadata(namespace=None, name="foo", kwargs={}),
checksum=FileHash.from_contents(""),
unrendered_config={
'severity': 'warn',
}
"severity": "warn",
},
)
@pytest.fixture
def basic_uncompiled_schema_test_dict():
return {
'name': 'foo',
'created_at': 1,
'resource_type': str(NodeType.Test),
'path': '/root/x/path.sql',
'original_file_path': '/root/path.sql',
'package_name': 'test',
'language': 'sql',
'raw_code': 'select * from {{ ref("other") }}',
'unique_id': 'model.test.foo',
'fqn': ['test', 'models', 'foo'],
'refs': [],
'sources': [],
'metrics': [],
'depends_on': {'macros': [], 'nodes': []},
'database': 'test_db',
'description': '',
'schema': 'dbt_test__audit',
'alias': 'bar',
'tags': [],
'config': {
'enabled': True,
'materialized': 'test',
'tags': [],
'severity': 'ERROR',
'schema': 'dbt_test__audit',
'warn_if': '!= 0',
'error_if': '!= 0',
'fail_calc': 'count(*)',
'meta': {},
"name": "foo",
"created_at": 1,
"resource_type": str(NodeType.Test),
"path": "/root/x/path.sql",
"original_file_path": "/root/path.sql",
"package_name": "test",
"language": "sql",
"raw_code": 'select * from {{ ref("other") }}',
"unique_id": "model.test.foo",
"fqn": ["test", "models", "foo"],
"refs": [],
"sources": [],
"metrics": [],
"depends_on": {"macros": [], "nodes": []},
"database": "test_db",
"description": "",
"schema": "dbt_test__audit",
"alias": "bar",
"tags": [],
"config": {
"enabled": True,
"materialized": "test",
"tags": [],
"severity": "ERROR",
"schema": "dbt_test__audit",
"warn_if": "!= 0",
"error_if": "!= 0",
"fail_calc": "count(*)",
"meta": {},
},
'deferred': False,
'docs': {'show': True},
'columns': {},
'meta': {},
'compiled': False,
'extra_ctes': [],
'extra_ctes_injected': False,
'test_metadata': {
'name': 'foo',
'kwargs': {},
"deferred": False,
"docs": {"show": True},
"columns": {},
"meta": {},
"compiled": False,
"extra_ctes": [],
"extra_ctes_injected": False,
"test_metadata": {
"name": "foo",
"kwargs": {},
},
'checksum': {'name': 'sha256', 'checksum': 'e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855'},
'unrendered_config': {},
'config_call_dict': {},
"checksum": {
"name": "sha256",
"checksum": "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855",
},
"unrendered_config": {},
"config_call_dict": {},
}
@pytest.fixture
def basic_compiled_schema_test_dict():
return {
'name': 'foo',
'created_at': 1,
'resource_type': str(NodeType.Test),
'path': '/root/x/path.sql',
'original_file_path': '/root/path.sql',
'package_name': 'test',
'language': 'sql',
'raw_code': 'select * from {{ ref("other") }}',
'unique_id': 'model.test.foo',
'fqn': ['test', 'models', 'foo'],
'refs': [],
'sources': [],
'metrics': [],
'depends_on': {'macros': [], 'nodes': []},
'deferred': False,
'database': 'test_db',
'description': '',
'schema': 'dbt_test__audit',
'alias': 'bar',
'tags': [],
'config': {
'enabled': True,
'materialized': 'test',
'tags': [],
'severity': 'warn',
'schema': 'dbt_test__audit',
'warn_if': '!= 0',
'error_if': '!= 0',
'fail_calc': 'count(*)',
'meta': {},
"name": "foo",
"created_at": 1,
"resource_type": str(NodeType.Test),
"path": "/root/x/path.sql",
"original_file_path": "/root/path.sql",
"package_name": "test",
"language": "sql",
"raw_code": 'select * from {{ ref("other") }}',
"unique_id": "model.test.foo",
"fqn": ["test", "models", "foo"],
"refs": [],
"sources": [],
"metrics": [],
"depends_on": {"macros": [], "nodes": []},
"deferred": False,
"database": "test_db",
"description": "",
"schema": "dbt_test__audit",
"alias": "bar",
"tags": [],
"config": {
"enabled": True,
"materialized": "test",
"tags": [],
"severity": "warn",
"schema": "dbt_test__audit",
"warn_if": "!= 0",
"error_if": "!= 0",
"fail_calc": "count(*)",
"meta": {},
},
'docs': {'show': True},
'columns': {},
'contract': False,
'meta': {},
'compiled': True,
'extra_ctes': [{'id': 'whatever', 'sql': 'select * from other'}],
'extra_ctes_injected': True,
'compiled_code': 'with whatever as (select * from other) select * from whatever',
'column_name': 'id',
'test_metadata': {
'name': 'foo',
'kwargs': {},
"docs": {"show": True},
"columns": {},
"contract": False,
"meta": {},
"compiled": True,
"extra_ctes": [{"id": "whatever", "sql": "select * from other"}],
"extra_ctes_injected": True,
"compiled_code": "with whatever as (select * from other) select * from whatever",
"column_name": "id",
"test_metadata": {
"name": "foo",
"kwargs": {},
},
'checksum': {'name': 'sha256', 'checksum': 'e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855'},
'unrendered_config': {
'severity': 'warn',
"checksum": {
"name": "sha256",
"checksum": "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855",
},
'config_call_dict': {},
"unrendered_config": {
"severity": "warn",
},
"config_call_dict": {},
}
@pytest.mark.skip("Haven't found where we would use uncompiled node")
def test_basic_uncompiled_schema_test(basic_uncompiled_schema_test_node, basic_uncompiled_schema_test_dict, minimal_schema_test_dict):
def test_basic_uncompiled_schema_test(
basic_uncompiled_schema_test_node, basic_uncompiled_schema_test_dict, minimal_schema_test_dict
):
node = basic_uncompiled_schema_test_node
node_dict = basic_uncompiled_schema_test_dict
minimum = minimal_schema_test_dict
@@ -536,7 +578,9 @@ def test_basic_uncompiled_schema_test(basic_uncompiled_schema_test_node, basic_u
assert_from_dict(node, minimum, GenericTestNode)
def test_basic_compiled_schema_test(basic_compiled_schema_test_node, basic_compiled_schema_test_dict):
def test_basic_compiled_schema_test(
basic_compiled_schema_test_node, basic_compiled_schema_test_dict
):
node = basic_compiled_schema_test_node
node_dict = basic_compiled_schema_test_dict
@@ -548,42 +592,44 @@ def test_basic_compiled_schema_test(basic_compiled_schema_test_node, basic_compi
def test_invalid_extra_schema_test_fields(minimal_schema_test_dict):
bad_extra = minimal_schema_test_dict
bad_extra['extra'] = 'extra value'
bad_extra["extra"] = "extra value"
assert_fails_validation(bad_extra, GenericTestNode)
def test_invalid_resource_type_schema_test(minimal_schema_test_dict):
bad_type = minimal_schema_test_dict
bad_type['resource_type'] = str(NodeType.Model)
bad_type["resource_type"] = str(NodeType.Model)
assert_fails_validation(bad_type, GenericTestNode)
unchanged_schema_tests = [
# for tests, raw_code isn't a change (because it's always the same for a given test macro)
lambda u: u.replace(raw_code='select * from wherever'),
lambda u: u.replace(description='a description'),
lambda u: u.replace(tags=['mytag']),
lambda u: u.replace(meta={'cool_key': 'cool value'}),
lambda u: u.replace(raw_code="select * from wherever"),
lambda u: u.replace(description="a description"),
lambda u: u.replace(tags=["mytag"]),
lambda u: u.replace(meta={"cool_key": "cool value"}),
# these values don't even mean anything on schema tests!
lambda u: replace_config(u, alias='nope'),
lambda u: replace_config(u, database='nope'),
lambda u: replace_config(u, schema='nope'),
lambda u: u.replace(database='other_db'),
lambda u: u.replace(schema='other_schema'),
lambda u: u.replace(alias='foo'),
lambda u: replace_config(u, alias="nope"),
lambda u: replace_config(u, database="nope"),
lambda u: replace_config(u, schema="nope"),
lambda u: u.replace(database="other_db"),
lambda u: u.replace(schema="other_schema"),
lambda u: u.replace(alias="foo"),
lambda u: replace_config(u, full_refresh=True),
lambda u: replace_config(u, post_hook=['select 1 as id']),
lambda u: replace_config(u, pre_hook=['select 1 as id']),
lambda u: replace_config(
u, quoting={'database': True, 'schema': False, 'identifier': False}),
lambda u: replace_config(u, post_hook=["select 1 as id"]),
lambda u: replace_config(u, pre_hook=["select 1 as id"]),
lambda u: replace_config(u, quoting={"database": True, "schema": False, "identifier": False}),
]
changed_schema_tests = [
lambda u: None,
lambda u: u.replace(fqn=['test', 'models', 'subdir', 'foo'],
original_file_path='models/subdir/foo.sql', path='/root/models/subdir/foo.sql'),
lambda u: replace_config(u, severity='warn'),
lambda u: u.replace(
fqn=["test", "models", "subdir", "foo"],
original_file_path="models/subdir/foo.sql",
path="/root/models/subdir/foo.sql",
),
lambda u: replace_config(u, severity="warn"),
# If we checked test metadata, these would caount. But we don't, because these changes would all change the unique ID, so it's irrelevant.
# lambda u: u.replace(test_metadata=u.test_metadata.replace(namespace='something')),
# lambda u: u.replace(test_metadata=u.test_metadata.replace(name='bar')),
@@ -591,13 +637,13 @@ changed_schema_tests = [
]
@pytest.mark.parametrize('func', unchanged_schema_tests)
@pytest.mark.parametrize("func", unchanged_schema_tests)
def test_compare_unchanged_schema_test(func, basic_uncompiled_schema_test_node):
value = func(basic_uncompiled_schema_test_node)
assert basic_uncompiled_schema_test_node.same_contents(value)
@pytest.mark.parametrize('func', changed_schema_tests)
@pytest.mark.parametrize("func", changed_schema_tests)
def test_compare_changed_schema_test(func, basic_uncompiled_schema_test_node):
value = func(basic_uncompiled_schema_test_node)
assert not basic_uncompiled_schema_test_node.same_contents(value)
@@ -610,5 +656,6 @@ def test_compare_to_compiled(basic_uncompiled_schema_test_node, basic_compiled_s
assert not uncompiled.same_contents(compiled)
fixed_config = compiled.config.replace(severity=uncompiled.config.severity)
fixed_compiled = compiled.replace(
config=fixed_config, unrendered_config=uncompiled.unrendered_config)
config=fixed_config, unrendered_config=uncompiled.unrendered_config
)
assert uncompiled.same_contents(fixed_compiled)

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -1,4 +1,3 @@
from .utils import ContractTestCase
from dbt.dataclass_schema import ValidationError
@@ -11,28 +10,28 @@ class TestProject(ContractTestCase):
def test_minimal(self):
dct = {
'name': 'test',
'version': '1.0',
'profile': 'test',
'project-root': '/usr/src/app',
'config-version': 2,
"name": "test",
"version": "1.0",
"profile": "test",
"project-root": "/usr/src/app",
"config-version": 2,
}
project = self.ContractType(
name='test',
version='1.0',
profile='test',
project_root='/usr/src/app',
name="test",
version="1.0",
profile="test",
project_root="/usr/src/app",
config_version=2,
)
self.assert_from_dict(project, dct)
def test_invalid_name(self):
dct = {
'name': 'log',
'version': '1.0',
'profile': 'test',
'project-root': '/usr/src/app',
'config-version': 2,
"name": "log",
"version": "1.0",
"profile": "test",
"project-root": "/usr/src/app",
"config-version": 2,
}
with self.assertRaises(ValidationError):
self.ContractType.validate(dct)

View File

@@ -15,42 +15,48 @@ class TestCoreDbtUtils(unittest.TestCase):
def test_connection_exception_retry_success_requests_exception(self):
Counter._reset()
connection_exception_retry(lambda: Counter._add_with_requests_exception(), 5)
self.assertEqual(2, counter) # 2 = original attempt returned None, plus 1 retry
self.assertEqual(2, counter) # 2 = original attempt returned None, plus 1 retry
def test_connection_exception_retry_max(self):
Counter._reset()
with self.assertRaises(ConnectionError):
connection_exception_retry(lambda: Counter._add_with_exception(), 5)
self.assertEqual(6, counter) # 6 = original attempt plus 5 retries
self.assertEqual(6, counter) # 6 = original attempt plus 5 retries
def test_connection_exception_retry_success_failed_untar(self):
Counter._reset()
connection_exception_retry(lambda: Counter._add_with_untar_exception(), 5)
self.assertEqual(2, counter) # 2 = original attempt returned ReadError, plus 1 retry
self.assertEqual(2, counter) # 2 = original attempt returned ReadError, plus 1 retry
counter:int = 0
class Counter():
counter: int = 0
class Counter:
def _add():
global counter
counter+=1
counter += 1
# All exceptions that Requests explicitly raises inherit from
# requests.exceptions.RequestException so we want to make sure that raises plus one exception
# that inherit from it for sanity
def _add_with_requests_exception():
global counter
counter+=1
counter += 1
if counter < 2:
raise requests.exceptions.RequestException
def _add_with_exception():
global counter
counter+=1
counter += 1
raise requests.exceptions.ConnectionError
def _add_with_untar_exception():
global counter
counter+=1
counter += 1
if counter < 2:
raise tarfile.ReadError
def _reset():
global counter
counter = 0

View File

@@ -26,13 +26,13 @@ from dbt.dataclass_schema import ValidationError
class TestLocalPackage(unittest.TestCase):
def test_init(self):
a_contract = LocalPackage.from_dict({'local': '/path/to/package'})
self.assertEqual(a_contract.local, '/path/to/package')
a_contract = LocalPackage.from_dict({"local": "/path/to/package"})
self.assertEqual(a_contract.local, "/path/to/package")
a = LocalUnpinnedPackage.from_contract(a_contract)
self.assertEqual(a.local, '/path/to/package')
self.assertEqual(a.local, "/path/to/package")
a_pinned = a.resolved()
self.assertEqual(a_pinned.local, '/path/to/package')
self.assertEqual(str(a_pinned), '/path/to/package')
self.assertEqual(a_pinned.local, "/path/to/package")
self.assertEqual(str(a_pinned), "/path/to/package")
class TestTarballPackage(unittest.TestCase):
@@ -40,73 +40,70 @@ class TestTarballPackage(unittest.TestCase):
from dbt.contracts.project import RegistryPackageMetadata
from mashumaro.exceptions import MissingField
dict_well_formed_contract = (
{'tarball': 'http://example.com',
'name': 'my_cool_package'})
dict_well_formed_contract = {"tarball": "http://example.com", "name": "my_cool_package"}
a_contract = (
TarballPackage.from_dict(dict_well_formed_contract))
a_contract = TarballPackage.from_dict(dict_well_formed_contract)
# check contract and resolver
self.assertEqual(a_contract.tarball, 'http://example.com')
self.assertEqual(a_contract.name, 'my_cool_package')
self.assertEqual(a_contract.tarball, "http://example.com")
self.assertEqual(a_contract.name, "my_cool_package")
a = TarballUnpinnedPackage.from_contract(a_contract)
self.assertEqual(a.tarball, 'http://example.com')
self.assertEqual(a.package, 'my_cool_package')
self.assertEqual(a.tarball, "http://example.com")
self.assertEqual(a.package, "my_cool_package")
a_pinned = a.resolved()
self.assertEqual(a_pinned.source_type(), 'tarball')
self.assertEqual(a_pinned.source_type(), "tarball")
# check bad contract (no name) fails
dict_missing_name_should_fail_on_contract = (
{'tarball': 'http://example.com'})
dict_missing_name_should_fail_on_contract = {"tarball": "http://example.com"}
with self.assertRaises(MissingField):
TarballPackage.from_dict(dict_missing_name_should_fail_on_contract)
# check RegistryPackageMetadata - it is used in TarballUnpinnedPackage
dct = {'name' : a.package,
'packages': [], # note: required by RegistryPackageMetadata
'downloads' : {'tarball' : a_pinned.tarball}}
dct = {
"name": a.package,
"packages": [], # note: required by RegistryPackageMetadata
"downloads": {"tarball": a_pinned.tarball},
}
metastore = RegistryPackageMetadata.from_dict(dct)
self.assertEqual(metastore.downloads.tarball, 'http://example.com')
self.assertEqual(metastore.downloads.tarball, "http://example.com")
class TestGitPackage(unittest.TestCase):
def test_init(self):
a_contract = GitPackage.from_dict(
{'git': 'http://example.com', 'revision': '0.0.1'},
{"git": "http://example.com", "revision": "0.0.1"},
)
self.assertEqual(a_contract.git, 'http://example.com')
self.assertEqual(a_contract.revision, '0.0.1')
self.assertEqual(a_contract.git, "http://example.com")
self.assertEqual(a_contract.revision, "0.0.1")
self.assertIs(a_contract.warn_unpinned, None)
a = GitUnpinnedPackage.from_contract(a_contract)
self.assertEqual(a.git, 'http://example.com')
self.assertEqual(a.revisions, ['0.0.1'])
self.assertEqual(a.git, "http://example.com")
self.assertEqual(a.revisions, ["0.0.1"])
self.assertIs(a.warn_unpinned, True)
a_pinned = a.resolved()
self.assertEqual(a_pinned.name, 'http://example.com')
self.assertEqual(a_pinned.get_version(), '0.0.1')
self.assertEqual(a_pinned.source_type(), 'git')
self.assertEqual(a_pinned.name, "http://example.com")
self.assertEqual(a_pinned.get_version(), "0.0.1")
self.assertEqual(a_pinned.source_type(), "git")
self.assertIs(a_pinned.warn_unpinned, True)
def test_invalid(self):
with self.assertRaises(ValidationError):
GitPackage.validate(
{'git': 'http://example.com', 'version': '0.0.1'},
{"git": "http://example.com", "version": "0.0.1"},
)
def test_resolve_ok(self):
a_contract = GitPackage.from_dict(
{'git': 'http://example.com', 'revision': '0.0.1'},
{"git": "http://example.com", "revision": "0.0.1"},
)
b_contract = GitPackage.from_dict(
{'git': 'http://example.com', 'revision': '0.0.1',
'warn-unpinned': False},
{"git": "http://example.com", "revision": "0.0.1", "warn-unpinned": False},
)
a = GitUnpinnedPackage.from_contract(a_contract)
b = GitUnpinnedPackage.from_contract(b_contract)
@@ -115,71 +112,69 @@ class TestGitPackage(unittest.TestCase):
c = a.incorporate(b)
c_pinned = c.resolved()
self.assertEqual(c_pinned.name, 'http://example.com')
self.assertEqual(c_pinned.get_version(), '0.0.1')
self.assertEqual(c_pinned.source_type(), 'git')
self.assertEqual(c_pinned.name, "http://example.com")
self.assertEqual(c_pinned.get_version(), "0.0.1")
self.assertEqual(c_pinned.source_type(), "git")
self.assertFalse(c_pinned.warn_unpinned)
def test_resolve_fail(self):
a_contract = GitPackage.from_dict(
{'git': 'http://example.com', 'revision': '0.0.1'},
{"git": "http://example.com", "revision": "0.0.1"},
)
b_contract = GitPackage.from_dict(
{'git': 'http://example.com', 'revision': '0.0.2'},
{"git": "http://example.com", "revision": "0.0.2"},
)
a = GitUnpinnedPackage.from_contract(a_contract)
b = GitUnpinnedPackage.from_contract(b_contract)
c = a.incorporate(b)
self.assertEqual(c.git, 'http://example.com')
self.assertEqual(c.revisions, ['0.0.1', '0.0.2'])
self.assertEqual(c.git, "http://example.com")
self.assertEqual(c.revisions, ["0.0.1", "0.0.2"])
with self.assertRaises(dbt.exceptions.DependencyError):
c.resolved()
def test_default_revision(self):
a_contract = GitPackage.from_dict({'git': 'http://example.com'})
a_contract = GitPackage.from_dict({"git": "http://example.com"})
self.assertEqual(a_contract.revision, None)
self.assertIs(a_contract.warn_unpinned, None)
a = GitUnpinnedPackage.from_contract(a_contract)
self.assertEqual(a.git, 'http://example.com')
self.assertEqual(a.git, "http://example.com")
self.assertEqual(a.revisions, [])
self.assertIs(a.warn_unpinned, True)
a_pinned = a.resolved()
self.assertEqual(a_pinned.name, 'http://example.com')
self.assertEqual(a_pinned.get_version(), 'HEAD')
self.assertEqual(a_pinned.source_type(), 'git')
self.assertEqual(a_pinned.name, "http://example.com")
self.assertEqual(a_pinned.get_version(), "HEAD")
self.assertEqual(a_pinned.source_type(), "git")
self.assertIs(a_pinned.warn_unpinned, True)
class TestHubPackage(unittest.TestCase):
def setUp(self):
self.patcher = mock.patch('dbt.deps.registry.registry')
self.patcher = mock.patch("dbt.deps.registry.registry")
self.registry = self.patcher.start()
self.index_cached = self.registry.index_cached
self.get_compatible_versions = self.registry.get_compatible_versions
self.package_version = self.registry.package_version
self.index_cached.return_value = [
'dbt-labs-test/a',
]
self.get_compatible_versions.return_value = [
'0.1.2', '0.1.3', '0.1.4a1'
"dbt-labs-test/a",
]
self.get_compatible_versions.return_value = ["0.1.2", "0.1.3", "0.1.4a1"]
self.package_version.return_value = {
'id': 'dbt-labs-test/a/0.1.2',
'name': 'a',
'version': '0.1.2',
'packages': [],
'_source': {
'blahblah': 'asdfas',
"id": "dbt-labs-test/a/0.1.2",
"name": "a",
"version": "0.1.2",
"packages": [],
"_source": {
"blahblah": "asdfas",
},
'downloads': {
'tarball': 'https://example.com/invalid-url!',
'extra': 'field',
"downloads": {
"tarball": "https://example.com/invalid-url!",
"extra": "field",
},
'newfield': ['another', 'value'],
"newfield": ["another", "value"],
}
def tearDown(self):
@@ -187,94 +182,83 @@ class TestHubPackage(unittest.TestCase):
def test_init(self):
a_contract = RegistryPackage(
package='dbt-labs-test/a',
version='0.1.2',
package="dbt-labs-test/a",
version="0.1.2",
)
self.assertEqual(a_contract.package, 'dbt-labs-test/a')
self.assertEqual(a_contract.version, '0.1.2')
self.assertEqual(a_contract.package, "dbt-labs-test/a")
self.assertEqual(a_contract.version, "0.1.2")
a = RegistryUnpinnedPackage.from_contract(a_contract)
self.assertEqual(a.package, 'dbt-labs-test/a')
self.assertEqual(a.package, "dbt-labs-test/a")
self.assertEqual(
a.versions,
[VersionSpecifier(
build=None,
major='0',
matcher='=',
minor='1',
patch='2',
prerelease=None
)]
[
VersionSpecifier(
build=None, major="0", matcher="=", minor="1", patch="2", prerelease=None
)
],
)
a_pinned = a.resolved()
self.assertEqual(a_contract.package, 'dbt-labs-test/a')
self.assertEqual(a_contract.version, '0.1.2')
self.assertEqual(a_pinned.source_type(), 'hub')
self.assertEqual(a_contract.package, "dbt-labs-test/a")
self.assertEqual(a_contract.version, "0.1.2")
self.assertEqual(a_pinned.source_type(), "hub")
def test_invalid(self):
with self.assertRaises(ValidationError):
RegistryPackage.validate(
{'package': 'namespace/name', 'key': 'invalid'},
{"package": "namespace/name", "key": "invalid"},
)
def test_resolve_ok(self):
a_contract = RegistryPackage(
package='dbt-labs-test/a',
version='0.1.2'
)
b_contract = RegistryPackage(
package='dbt-labs-test/a',
version='0.1.2'
)
a_contract = RegistryPackage(package="dbt-labs-test/a", version="0.1.2")
b_contract = RegistryPackage(package="dbt-labs-test/a", version="0.1.2")
a = RegistryUnpinnedPackage.from_contract(a_contract)
b = RegistryUnpinnedPackage.from_contract(b_contract)
c = a.incorporate(b)
self.assertEqual(c.package, 'dbt-labs-test/a')
self.assertEqual(c.package, "dbt-labs-test/a")
self.assertEqual(
c.versions,
[
VersionSpecifier(
build=None,
major='0',
matcher='=',
minor='1',
patch='2',
major="0",
matcher="=",
minor="1",
patch="2",
prerelease=None,
),
VersionSpecifier(
build=None,
major='0',
matcher='=',
minor='1',
patch='2',
major="0",
matcher="=",
minor="1",
patch="2",
prerelease=None,
),
]
],
)
c_pinned = c.resolved()
self.assertEqual(c_pinned.package, 'dbt-labs-test/a')
self.assertEqual(c_pinned.version, '0.1.2')
self.assertEqual(c_pinned.source_type(), 'hub')
self.assertEqual(c_pinned.package, "dbt-labs-test/a")
self.assertEqual(c_pinned.version, "0.1.2")
self.assertEqual(c_pinned.source_type(), "hub")
def test_resolve_missing_package(self):
a = RegistryUnpinnedPackage.from_contract(RegistryPackage(
package='dbt-labs-test/b',
version='0.1.2'
))
a = RegistryUnpinnedPackage.from_contract(
RegistryPackage(package="dbt-labs-test/b", version="0.1.2")
)
with self.assertRaises(dbt.exceptions.DependencyError) as exc:
a.resolved()
msg = 'Package dbt-labs-test/b was not found in the package index'
msg = "Package dbt-labs-test/b was not found in the package index"
self.assertEqual(msg, str(exc.exception))
def test_resolve_missing_version(self):
a = RegistryUnpinnedPackage.from_contract(RegistryPackage(
package='dbt-labs-test/a',
version='0.1.4'
))
a = RegistryUnpinnedPackage.from_contract(
RegistryPackage(package="dbt-labs-test/a", version="0.1.4")
)
with self.assertRaises(dbt.exceptions.DependencyError) as exc:
a.resolved()
@@ -286,14 +270,8 @@ class TestHubPackage(unittest.TestCase):
assert msg in str(exc.exception)
def test_resolve_conflict(self):
a_contract = RegistryPackage(
package='dbt-labs-test/a',
version='0.1.2'
)
b_contract = RegistryPackage(
package='dbt-labs-test/a',
version='0.1.3'
)
a_contract = RegistryPackage(package="dbt-labs-test/a", version="0.1.2")
b_contract = RegistryPackage(package="dbt-labs-test/a", version="0.1.3")
a = RegistryUnpinnedPackage.from_contract(a_contract)
b = RegistryUnpinnedPackage.from_contract(b_contract)
c = a.incorporate(b)
@@ -307,244 +285,216 @@ class TestHubPackage(unittest.TestCase):
self.assertEqual(msg, str(exc.exception))
def test_resolve_ranges(self):
a_contract = RegistryPackage(
package='dbt-labs-test/a',
version='0.1.2'
)
b_contract = RegistryPackage(
package='dbt-labs-test/a',
version='<0.1.4'
)
a_contract = RegistryPackage(package="dbt-labs-test/a", version="0.1.2")
b_contract = RegistryPackage(package="dbt-labs-test/a", version="<0.1.4")
a = RegistryUnpinnedPackage.from_contract(a_contract)
b = RegistryUnpinnedPackage.from_contract(b_contract)
c = a.incorporate(b)
self.assertEqual(c.package, 'dbt-labs-test/a')
self.assertEqual(c.package, "dbt-labs-test/a")
self.assertEqual(
c.versions,
[
VersionSpecifier(
build=None,
major='0',
matcher='=',
minor='1',
patch='2',
major="0",
matcher="=",
minor="1",
patch="2",
prerelease=None,
),
VersionSpecifier(
build=None,
major='0',
matcher='<',
minor='1',
patch='4',
major="0",
matcher="<",
minor="1",
patch="4",
prerelease=None,
),
]
],
)
c_pinned = c.resolved()
self.assertEqual(c_pinned.package, 'dbt-labs-test/a')
self.assertEqual(c_pinned.version, '0.1.2')
self.assertEqual(c_pinned.source_type(), 'hub')
self.assertEqual(c_pinned.package, "dbt-labs-test/a")
self.assertEqual(c_pinned.version, "0.1.2")
self.assertEqual(c_pinned.source_type(), "hub")
def test_resolve_ranges_install_prerelease_default_false(self):
a_contract = RegistryPackage(
package='dbt-labs-test/a',
version='>0.1.2'
)
b_contract = RegistryPackage(
package='dbt-labs-test/a',
version='<0.1.5'
)
a_contract = RegistryPackage(package="dbt-labs-test/a", version=">0.1.2")
b_contract = RegistryPackage(package="dbt-labs-test/a", version="<0.1.5")
a = RegistryUnpinnedPackage.from_contract(a_contract)
b = RegistryUnpinnedPackage.from_contract(b_contract)
c = a.incorporate(b)
self.assertEqual(c.package, 'dbt-labs-test/a')
self.assertEqual(c.package, "dbt-labs-test/a")
self.assertEqual(
c.versions,
[
VersionSpecifier(
build=None,
major='0',
matcher='>',
minor='1',
patch='2',
major="0",
matcher=">",
minor="1",
patch="2",
prerelease=None,
),
VersionSpecifier(
build=None,
major='0',
matcher='<',
minor='1',
patch='5',
major="0",
matcher="<",
minor="1",
patch="5",
prerelease=None,
),
]
],
)
c_pinned = c.resolved()
self.assertEqual(c_pinned.package, 'dbt-labs-test/a')
self.assertEqual(c_pinned.version, '0.1.3')
self.assertEqual(c_pinned.source_type(), 'hub')
self.assertEqual(c_pinned.package, "dbt-labs-test/a")
self.assertEqual(c_pinned.version, "0.1.3")
self.assertEqual(c_pinned.source_type(), "hub")
def test_resolve_ranges_install_prerelease_true(self):
a_contract = RegistryPackage(
package='dbt-labs-test/a',
version='>0.1.2',
install_prerelease=True
)
b_contract = RegistryPackage(
package='dbt-labs-test/a',
version='<0.1.5'
package="dbt-labs-test/a", version=">0.1.2", install_prerelease=True
)
b_contract = RegistryPackage(package="dbt-labs-test/a", version="<0.1.5")
a = RegistryUnpinnedPackage.from_contract(a_contract)
b = RegistryUnpinnedPackage.from_contract(b_contract)
c = a.incorporate(b)
self.assertEqual(c.package, 'dbt-labs-test/a')
self.assertEqual(c.package, "dbt-labs-test/a")
self.assertEqual(
c.versions,
[
VersionSpecifier(
build=None,
major='0',
matcher='>',
minor='1',
patch='2',
major="0",
matcher=">",
minor="1",
patch="2",
prerelease=None,
),
VersionSpecifier(
build=None,
major='0',
matcher='<',
minor='1',
patch='5',
major="0",
matcher="<",
minor="1",
patch="5",
prerelease=None,
),
]
],
)
c_pinned = c.resolved()
self.assertEqual(c_pinned.package, 'dbt-labs-test/a')
self.assertEqual(c_pinned.version, '0.1.4a1')
self.assertEqual(c_pinned.source_type(), 'hub')
self.assertEqual(c_pinned.package, "dbt-labs-test/a")
self.assertEqual(c_pinned.version, "0.1.4a1")
self.assertEqual(c_pinned.source_type(), "hub")
def test_get_version_latest_prelease_true(self):
a_contract = RegistryPackage(
package='dbt-labs-test/a',
version='>0.1.0',
install_prerelease=True
)
b_contract = RegistryPackage(
package='dbt-labs-test/a',
version='<0.1.4'
package="dbt-labs-test/a", version=">0.1.0", install_prerelease=True
)
b_contract = RegistryPackage(package="dbt-labs-test/a", version="<0.1.4")
a = RegistryUnpinnedPackage.from_contract(a_contract)
b = RegistryUnpinnedPackage.from_contract(b_contract)
c = a.incorporate(b)
self.assertEqual(c.package, 'dbt-labs-test/a')
self.assertEqual(c.package, "dbt-labs-test/a")
self.assertEqual(
c.versions,
[
VersionSpecifier(
build=None,
major='0',
matcher='>',
minor='1',
patch='0',
major="0",
matcher=">",
minor="1",
patch="0",
prerelease=None,
),
VersionSpecifier(
build=None,
major='0',
matcher='<',
minor='1',
patch='4',
major="0",
matcher="<",
minor="1",
patch="4",
prerelease=None,
),
]
],
)
c_pinned = c.resolved()
self.assertEqual(c_pinned.package, 'dbt-labs-test/a')
self.assertEqual(c_pinned.version, '0.1.3')
self.assertEqual(c_pinned.get_version_latest(), '0.1.4a1')
self.assertEqual(c_pinned.source_type(), 'hub')
self.assertEqual(c_pinned.package, "dbt-labs-test/a")
self.assertEqual(c_pinned.version, "0.1.3")
self.assertEqual(c_pinned.get_version_latest(), "0.1.4a1")
self.assertEqual(c_pinned.source_type(), "hub")
def test_get_version_latest_prelease_false(self):
a_contract = RegistryPackage(
package='dbt-labs-test/a',
version='>0.1.0',
install_prerelease=False
)
b_contract = RegistryPackage(
package='dbt-labs-test/a',
version='<0.1.4'
package="dbt-labs-test/a", version=">0.1.0", install_prerelease=False
)
b_contract = RegistryPackage(package="dbt-labs-test/a", version="<0.1.4")
a = RegistryUnpinnedPackage.from_contract(a_contract)
b = RegistryUnpinnedPackage.from_contract(b_contract)
c = a.incorporate(b)
self.assertEqual(c.package, 'dbt-labs-test/a')
self.assertEqual(c.package, "dbt-labs-test/a")
self.assertEqual(
c.versions,
[
VersionSpecifier(
build=None,
major='0',
matcher='>',
minor='1',
patch='0',
major="0",
matcher=">",
minor="1",
patch="0",
prerelease=None,
),
VersionSpecifier(
build=None,
major='0',
matcher='<',
minor='1',
patch='4',
major="0",
matcher="<",
minor="1",
patch="4",
prerelease=None,
),
]
],
)
c_pinned = c.resolved()
self.assertEqual(c_pinned.package, 'dbt-labs-test/a')
self.assertEqual(c_pinned.version, '0.1.3')
self.assertEqual(c_pinned.get_version_latest(), '0.1.3')
self.assertEqual(c_pinned.source_type(), 'hub')
self.assertEqual(c_pinned.package, "dbt-labs-test/a")
self.assertEqual(c_pinned.version, "0.1.3")
self.assertEqual(c_pinned.get_version_latest(), "0.1.3")
self.assertEqual(c_pinned.source_type(), "hub")
def test_get_version_prerelease_explicitly_requested(self):
a_contract = RegistryPackage(
package='dbt-labs-test/a',
version='0.1.4a1',
install_prerelease=None
package="dbt-labs-test/a", version="0.1.4a1", install_prerelease=None
)
a = RegistryUnpinnedPackage.from_contract(a_contract)
self.assertEqual(a.package, 'dbt-labs-test/a')
self.assertEqual(a.package, "dbt-labs-test/a")
self.assertEqual(
a.versions,
[
VersionSpecifier(
build=None,
major='0',
matcher='=',
minor='1',
patch='4',
prerelease='a1',
major="0",
matcher="=",
minor="1",
patch="4",
prerelease="a1",
),
]
],
)
a_pinned = a.resolved()
self.assertEqual(a_pinned.package, 'dbt-labs-test/a')
self.assertEqual(a_pinned.version, '0.1.4a1')
self.assertEqual(a_pinned.get_version_latest(), '0.1.4a1')
self.assertEqual(a_pinned.source_type(), 'hub')
self.assertEqual(a_pinned.package, "dbt-labs-test/a")
self.assertEqual(a_pinned.version, "0.1.4a1")
self.assertEqual(a_pinned.get_version_latest(), "0.1.4a1")
self.assertEqual(a_pinned.source_type(), "hub")
class MockRegistry:
def __init__(self, packages):
@@ -563,7 +513,8 @@ class MockRegistry:
def get_compatible_versions(self, package_name, dbt_version, should_version_check):
packages = self.package(package_name)
return [
pkg_version for pkg_version, info in packages.items()
pkg_version
for pkg_version, info in packages.items()
if is_compatible_version(info, dbt_version)
]
@@ -575,146 +526,158 @@ class MockRegistry:
class TestPackageSpec(unittest.TestCase):
def setUp(self):
def setUp(self):
dbt_version = get_installed_version()
next_version = deepcopy(dbt_version)
next_version.minor = str(int(next_version.minor) + 1)
next_version.prerelease = None
require_next_version = ">" + next_version.to_version_string()
self.patcher = mock.patch('dbt.deps.registry.registry')
self.patcher = mock.patch("dbt.deps.registry.registry")
self.registry = self.patcher.start()
self.mock_registry = MockRegistry(packages={
'dbt-labs-test/a': {
'0.1.2': {
'id': 'dbt-labs-test/a/0.1.2',
'name': 'a',
'version': '0.1.2',
'packages': [],
'_source': {
'blahblah': 'asdfas',
self.mock_registry = MockRegistry(
packages={
"dbt-labs-test/a": {
"0.1.2": {
"id": "dbt-labs-test/a/0.1.2",
"name": "a",
"version": "0.1.2",
"packages": [],
"_source": {
"blahblah": "asdfas",
},
"downloads": {
"tarball": "https://example.com/invalid-url!",
"extra": "field",
},
"newfield": ["another", "value"],
},
'downloads': {
'tarball': 'https://example.com/invalid-url!',
'extra': 'field',
"0.1.3": {
"id": "dbt-labs-test/a/0.1.3",
"name": "a",
"version": "0.1.3",
"packages": [],
"_source": {
"blahblah": "asdfas",
},
"downloads": {
"tarball": "https://example.com/invalid-url!",
"extra": "field",
},
"newfield": ["another", "value"],
},
"0.1.4a1": {
"id": "dbt-labs-test/a/0.1.3a1",
"name": "a",
"version": "0.1.4a1",
"packages": [],
"_source": {
"blahblah": "asdfas",
},
"downloads": {
"tarball": "https://example.com/invalid-url!",
"extra": "field",
},
"newfield": ["another", "value"],
},
"0.2.0": {
"id": "dbt-labs-test/a/0.2.0",
"name": "a",
"version": "0.2.0",
"packages": [],
"_source": {
"blahblah": "asdfas",
},
# this one shouldn't be picked!
"require_dbt_version": require_next_version,
"downloads": {
"tarball": "https://example.com/invalid-url!",
"extra": "field",
},
"newfield": ["another", "value"],
},
'newfield': ['another', 'value'],
},
'0.1.3': {
'id': 'dbt-labs-test/a/0.1.3',
'name': 'a',
'version': '0.1.3',
'packages': [],
'_source': {
'blahblah': 'asdfas',
"dbt-labs-test/b": {
"0.2.1": {
"id": "dbt-labs-test/b/0.2.1",
"name": "b",
"version": "0.2.1",
"packages": [{"package": "dbt-labs-test/a", "version": ">=0.1.3"}],
"_source": {
"blahblah": "asdfas",
},
"downloads": {
"tarball": "https://example.com/invalid-url!",
"extra": "field",
},
"newfield": ["another", "value"],
},
'downloads': {
'tarball': 'https://example.com/invalid-url!',
'extra': 'field',
},
'newfield': ['another', 'value'],
},
'0.1.4a1': {
'id': 'dbt-labs-test/a/0.1.3a1',
'name': 'a',
'version': '0.1.4a1',
'packages': [],
'_source': {
'blahblah': 'asdfas',
},
'downloads': {
'tarball': 'https://example.com/invalid-url!',
'extra': 'field',
},
'newfield': ['another', 'value'],
},
'0.2.0': {
'id': 'dbt-labs-test/a/0.2.0',
'name': 'a',
'version': '0.2.0',
'packages': [],
'_source': {
'blahblah': 'asdfas',
},
# this one shouldn't be picked!
'require_dbt_version': require_next_version,
'downloads': {
'tarball': 'https://example.com/invalid-url!',
'extra': 'field',
},
'newfield': ['another', 'value'],
}
},
'dbt-labs-test/b': {
'0.2.1': {
'id': 'dbt-labs-test/b/0.2.1',
'name': 'b',
'version': '0.2.1',
'packages': [{'package': 'dbt-labs-test/a', 'version': '>=0.1.3'}],
'_source': {
'blahblah': 'asdfas',
},
'downloads': {
'tarball': 'https://example.com/invalid-url!',
'extra': 'field',
},
'newfield': ['another', 'value'],
},
}
})
)
self.registry.index_cached.side_effect = self.mock_registry.index_cached
self.registry.get_compatible_versions.side_effect = self.mock_registry.get_compatible_versions
self.registry.get_compatible_versions.side_effect = (
self.mock_registry.get_compatible_versions
)
self.registry.package_version.side_effect = self.mock_registry.package_version
def tearDown(self):
self.patcher.stop()
def test_dependency_resolution(self):
package_config = PackageConfig.from_dict({
'packages': [
{'package': 'dbt-labs-test/a', 'version': '>0.1.2'},
{'package': 'dbt-labs-test/b', 'version': '0.2.1'},
],
})
resolved = resolve_packages(package_config.packages, mock.MagicMock(project_name='test'), {})
package_config = PackageConfig.from_dict(
{
"packages": [
{"package": "dbt-labs-test/a", "version": ">0.1.2"},
{"package": "dbt-labs-test/b", "version": "0.2.1"},
],
}
)
resolved = resolve_packages(
package_config.packages, mock.MagicMock(project_name="test"), {}
)
self.assertEqual(len(resolved), 2)
self.assertEqual(resolved[0].name, 'dbt-labs-test/a')
self.assertEqual(resolved[0].version, '0.1.3')
self.assertEqual(resolved[1].name, 'dbt-labs-test/b')
self.assertEqual(resolved[1].version, '0.2.1')
self.assertEqual(resolved[0].name, "dbt-labs-test/a")
self.assertEqual(resolved[0].version, "0.1.3")
self.assertEqual(resolved[1].name, "dbt-labs-test/b")
self.assertEqual(resolved[1].version, "0.2.1")
def test_dependency_resolution_allow_prerelease(self):
package_config = PackageConfig.from_dict({
'packages': [
{'package': 'dbt-labs-test/a', 'version': '>0.1.2', 'install_prerelease': True},
{'package': 'dbt-labs-test/b', 'version': '0.2.1'},
],
})
resolved = resolve_packages(package_config.packages, mock.MagicMock(project_name='test'), {})
self.assertEqual(resolved[0].name, 'dbt-labs-test/a')
self.assertEqual(resolved[0].version, '0.1.4a1')
package_config = PackageConfig.from_dict(
{
"packages": [
{
"package": "dbt-labs-test/a",
"version": ">0.1.2",
"install_prerelease": True,
},
{"package": "dbt-labs-test/b", "version": "0.2.1"},
],
}
)
resolved = resolve_packages(
package_config.packages, mock.MagicMock(project_name="test"), {}
)
self.assertEqual(resolved[0].name, "dbt-labs-test/a")
self.assertEqual(resolved[0].version, "0.1.4a1")
def test_validation_error_when_version_is_missing_from_package_config(self):
packages_data = {"packages": [{'package': 'dbt-labs-test/b', 'version': None}]}
with self.assertRaises(ValidationError) as exc:
a = PackageConfig.validate(data=packages_data)
msg = (
"dbt-labs-test/b is missing the version. When installing from the Hub package index, version is a required property"
)
packages_data = {"packages": [{"package": "dbt-labs-test/b", "version": None}]}
with self.assertRaises(ValidationError) as exc:
PackageConfig.validate(data=packages_data)
msg = "dbt-labs-test/b is missing the version. When installing from the Hub package index, version is a required property"
assert msg in str(exc.exception)
def test_validation_error_when_namespace_is_missing_from_package_config(self):
packages_data = {"packages": [{'package': 'dbt-labs', 'version': '1.0.0'}]}
with self.assertRaises(ValidationError) as exc:
a = PackageConfig.validate(data=packages_data)
msg = (
"dbt-labs was not found in the package index. Packages on the index require a namespace, e.g dbt-labs/dbt_utils"
)
packages_data = {"packages": [{"package": "dbt-labs", "version": "1.0.0"}]}
with self.assertRaises(ValidationError) as exc:
PackageConfig.validate(data=packages_data)
msg = "dbt-labs was not found in the package index. Packages on the index require a namespace, e.g dbt-labs/dbt_utils"
assert msg in str(exc.exception)

View File

@@ -11,7 +11,7 @@ from dbt.parser.search import FileBlock
from .utils import config_from_parts_or_dicts
SNOWPLOW_SESSIONS_DOCS = r'''
SNOWPLOW_SESSIONS_DOCS = r"""
This table contains one record for every session recorded by Snowplow.
A session is itself comprised of pageviews that all occur within 30 minutes
of each other. If more than 30 minutes elapse between pageviews, then a
@@ -29,42 +29,42 @@ The following sessions will be created:
| ---------- | ---------------- | --------------- |
| abc | 123 | 2 |
| abc | 789 | 1 |
'''
"""
SNOWPLOW_SESSIONS_SESSION_ID_DOCS = r'''
SNOWPLOW_SESSIONS_SESSION_ID_DOCS = r"""
This column is the unique identifier for a Snowplow session. It is generated by
a cookie then expires after 30 minutes of inactivity.
'''
"""
SNOWPLOW_SESSIONS_BLOCK = r'''
SNOWPLOW_SESSIONS_BLOCK = r"""
{{% docs snowplow_sessions %}}
{snowplow_sessions_docs}
{{% enddocs %}}
'''.format(
snowplow_sessions_docs=SNOWPLOW_SESSIONS_DOCS
""".format(
snowplow_sessions_docs=SNOWPLOW_SESSIONS_DOCS
).strip()
SNOWPLOW_SESSIONS_SESSION_ID_BLOCK = r'''
SNOWPLOW_SESSIONS_SESSION_ID_BLOCK = r"""
{{% docs snowplow_sessions__session_id %}}
{snowplow_sessions_session_id_docs}
{{% enddocs %}}
'''.format(
""".format(
snowplow_sessions_session_id_docs=SNOWPLOW_SESSIONS_SESSION_ID_DOCS
).strip()
TEST_DOCUMENTATION_FILE = r'''
TEST_DOCUMENTATION_FILE = r"""
{sessions_block}
{session_id_block}
'''.format(
""".format(
sessions_block=SNOWPLOW_SESSIONS_BLOCK,
session_id_block=SNOWPLOW_SESSIONS_SESSION_ID_BLOCK,
)
MULTIPLE_RAW_BLOCKS = r'''
MULTIPLE_RAW_BLOCKS = r"""
{% docs some_doc %}
{% raw %}
```
@@ -80,49 +80,49 @@ MULTIPLE_RAW_BLOCKS = r'''
```
{% endraw %}
{% enddocs %}
'''
"""
class DocumentationParserTest(unittest.TestCase):
def setUp(self):
if os.name == 'nt':
self.root_path = 'C:\\test_root'
self.subdir_path = 'C:\\test_root\\test_subdir'
self.testfile_path = 'C:\\test_root\\test_subdir\\test_file.md'
if os.name == "nt":
self.root_path = "C:\\test_root"
self.subdir_path = "C:\\test_root\\test_subdir"
self.testfile_path = "C:\\test_root\\test_subdir\\test_file.md"
else:
self.root_path = '/test_root'
self.subdir_path = '/test_root/test_subdir'
self.testfile_path = '/test_root/test_subdir/test_file.md'
self.root_path = "/test_root"
self.subdir_path = "/test_root/test_subdir"
self.testfile_path = "/test_root/test_subdir/test_file.md"
profile_data = {
'outputs': {
'test': {
'type': 'postgres',
'host': 'localhost',
'schema': 'analytics',
'user': 'test',
'pass': 'test',
'dbname': 'test',
'port': 1,
"outputs": {
"test": {
"type": "postgres",
"host": "localhost",
"schema": "analytics",
"user": "test",
"pass": "test",
"dbname": "test",
"port": 1,
}
},
'target': 'test',
"target": "test",
}
root_project = {
'name': 'root',
'version': '0.1',
'profile': 'test',
'project-root': self.root_path,
'config-version': 2,
"name": "root",
"version": "0.1",
"profile": "test",
"project-root": self.root_path,
"config-version": 2,
}
subdir_project = {
'name': 'some_package',
'version': '0.1',
'profile': 'test',
'project-root': self.subdir_path,
'quoting': {},
'config-version': 2,
"name": "some_package",
"version": "0.1",
"profile": "test",
"project-root": self.subdir_path,
"quoting": {},
"config-version": 2,
}
self.root_project_config = config_from_parts_or_dicts(
project=root_project, profile=profile_data
@@ -149,23 +149,23 @@ class DocumentationParserTest(unittest.TestCase):
project=self.subdir_project_config,
)
file_block = self._build_file(TEST_DOCUMENTATION_FILE, 'test_file.md')
file_block = self._build_file(TEST_DOCUMENTATION_FILE, "test_file.md")
parser.parse_file(file_block)
docs_values = sorted(parser.manifest.docs.values(), key=lambda n: n.name)
self.assertEqual(len(docs_values), 2)
for result in docs_values:
self.assertIsInstance(result, Documentation)
self.assertEqual(result.package_name, 'some_package')
self.assertEqual(result.package_name, "some_package")
self.assertEqual(result.original_file_path, self.testfile_path)
self.assertEqual(result.resource_type, NodeType.Documentation)
self.assertEqual(result.path, 'test_file.md')
self.assertEqual(result.path, "test_file.md")
self.assertEqual(docs_values[0].name, 'snowplow_sessions')
self.assertEqual(docs_values[1].name, 'snowplow_sessions__session_id')
self.assertEqual(docs_values[0].name, "snowplow_sessions")
self.assertEqual(docs_values[1].name, "snowplow_sessions__session_id")
def test_load_file_extras(self):
TEST_DOCUMENTATION_FILE + '{% model foo %}select 1 as id{% endmodel %}'
TEST_DOCUMENTATION_FILE + "{% model foo %}select 1 as id{% endmodel %}"
parser = docs.DocumentationParser(
root_project=self.root_project_config,
@@ -173,15 +173,15 @@ class DocumentationParserTest(unittest.TestCase):
project=self.subdir_project_config,
)
file_block = self._build_file(TEST_DOCUMENTATION_FILE, 'test_file.md')
file_block = self._build_file(TEST_DOCUMENTATION_FILE, "test_file.md")
parser.parse_file(file_block)
docs_values = sorted(parser.manifest.docs.values(), key=lambda n: n.name)
self.assertEqual(len(docs_values), 2)
for result in docs_values:
self.assertIsInstance(result, Documentation)
self.assertEqual(docs_values[0].name, 'snowplow_sessions')
self.assertEqual(docs_values[1].name, 'snowplow_sessions__session_id')
self.assertEqual(docs_values[0].name, "snowplow_sessions")
self.assertEqual(docs_values[1].name, "snowplow_sessions__session_id")
def test_multiple_raw_blocks(self):
parser = docs.DocumentationParser(
@@ -190,19 +190,24 @@ class DocumentationParserTest(unittest.TestCase):
project=self.subdir_project_config,
)
file_block = self._build_file(MULTIPLE_RAW_BLOCKS, 'test_file.md')
file_block = self._build_file(MULTIPLE_RAW_BLOCKS, "test_file.md")
parser.parse_file(file_block)
docs_values = sorted(parser.manifest.docs.values(), key=lambda n: n.name)
self.assertEqual(len(docs_values), 2)
for result in docs_values:
self.assertIsInstance(result, Documentation)
self.assertEqual(result.package_name, 'some_package')
self.assertEqual(result.package_name, "some_package")
self.assertEqual(result.original_file_path, self.testfile_path)
self.assertEqual(result.resource_type, NodeType.Documentation)
self.assertEqual(result.path, 'test_file.md')
self.assertEqual(result.path, "test_file.md")
self.assertEqual(docs_values[0].name, 'other_doc')
self.assertEqual(docs_values[0].block_contents, '```\n {% docs %}other doc{% enddocs %}\n ```')
self.assertEqual(docs_values[1].name, 'some_doc')
self.assertEqual(docs_values[1].block_contents, '```\n {% docs %}some doc{% enddocs %}\n ```', )
self.assertEqual(docs_values[0].name, "other_doc")
self.assertEqual(
docs_values[0].block_contents, "```\n {% docs %}other doc{% enddocs %}\n ```"
)
self.assertEqual(docs_values[1].name, "some_doc")
self.assertEqual(
docs_values[1].block_contents,
"```\n {% docs %}some doc{% enddocs %}\n ```",
)

View File

@@ -1,9 +1,7 @@
from datetime import datetime
from decimal import Decimal
from unittest import mock
import unittest
import dbt.flags
from dbt.task import generate
@@ -11,17 +9,14 @@ class GenerateTest(unittest.TestCase):
def setUp(self):
self.maxDiff = None
self.manifest = mock.MagicMock()
self.patcher = mock.patch('dbt.task.generate.get_unique_id_mapping')
self.patcher = mock.patch("dbt.task.generate.get_unique_id_mapping")
self.mock_get_unique_id_mapping = self.patcher.start()
def tearDown(self):
self.patcher.stop()
def map_uids(self, effects):
results = {
generate.CatalogKey(db, sch, tbl): uid
for db, sch, tbl, uid in effects
}
results = {generate.CatalogKey(db, sch, tbl): uid for db, sch, tbl, uid in effects}
self.mock_get_unique_id_mapping.return_value = results, {}
def generate_catalog_dict(self, columns):
@@ -31,7 +26,7 @@ class GenerateTest(unittest.TestCase):
sources=sources,
errors=None,
)
return result.to_dict(omit_none=False)['nodes']
return result.to_dict(omit_none=False)["nodes"]
def test__unflatten_empty(self):
columns = {}
@@ -44,48 +39,45 @@ class GenerateTest(unittest.TestCase):
self.assertEqual(result, expected)
def test__unflatten_one_column(self):
columns = [{
'column_comment': None,
'column_index': Decimal('1'),
'column_name': 'id',
'column_type': 'integer',
'table_comment': None,
'table_name': 'test_table',
'table_schema': 'test_schema',
'table_type': 'BASE TABLE',
'table_database': 'test_database',
}]
columns = [
{
"column_comment": None,
"column_index": Decimal("1"),
"column_name": "id",
"column_type": "integer",
"table_comment": None,
"table_name": "test_table",
"table_schema": "test_schema",
"table_type": "BASE TABLE",
"table_database": "test_database",
}
]
expected = {
'test.model.test_table': {
'metadata': {
'owner': None,
'comment': None,
'name': 'test_table',
'type': 'BASE TABLE',
'schema': 'test_schema',
'database': 'test_database',
"test.model.test_table": {
"metadata": {
"owner": None,
"comment": None,
"name": "test_table",
"type": "BASE TABLE",
"schema": "test_schema",
"database": "test_database",
},
'columns': {
'id': {
'type': 'integer',
'comment': None,
'index': 1,
'name': 'id'
"columns": {
"id": {"type": "integer", "comment": None, "index": 1, "name": "id"},
},
"stats": {
"has_stats": {
"id": "has_stats",
"label": "Has Stats?",
"value": False,
"description": "Indicates whether there are statistics for this table",
"include": False,
},
},
'stats': {
'has_stats': {
'id': 'has_stats',
'label': 'Has Stats?',
'value': False,
'description': 'Indicates whether there are statistics for this table',
'include': False,
},
},
'unique_id': 'test.model.test_table',
"unique_id": "test.model.test_table",
},
}
self.map_uids([('test_database', 'test_schema', 'test_table', 'test.model.test_table')])
self.map_uids([("test_database", "test_schema", "test_table", "test.model.test_table")])
result = self.generate_catalog_dict(columns)
@@ -95,258 +87,243 @@ class GenerateTest(unittest.TestCase):
def test__unflatten_multiple_schemas_dbs(self):
columns = [
{
'column_comment': None,
'column_index': Decimal('1'),
'column_name': 'id',
'column_type': 'integer',
'table_comment': None,
'table_name': 'test_table',
'table_schema': 'test_schema',
'table_type': 'BASE TABLE',
'table_database': 'test_database',
'table_owner': None,
"column_comment": None,
"column_index": Decimal("1"),
"column_name": "id",
"column_type": "integer",
"table_comment": None,
"table_name": "test_table",
"table_schema": "test_schema",
"table_type": "BASE TABLE",
"table_database": "test_database",
"table_owner": None,
},
{
'column_comment': None,
'column_index': Decimal('2'),
'column_name': 'name',
'column_type': 'text',
'table_comment': None,
'table_name': 'test_table',
'table_schema': 'test_schema',
'table_type': 'BASE TABLE',
'table_database': 'test_database',
'table_owner': None,
"column_comment": None,
"column_index": Decimal("2"),
"column_name": "name",
"column_type": "text",
"table_comment": None,
"table_name": "test_table",
"table_schema": "test_schema",
"table_type": "BASE TABLE",
"table_database": "test_database",
"table_owner": None,
},
{
'column_comment': None,
'column_index': Decimal('1'),
'column_name': 'id',
'column_type': 'integer',
'table_comment': None,
'table_name': 'other_test_table',
'table_schema': 'test_schema',
'table_type': 'BASE TABLE',
'table_database': 'test_database',
'table_owner': None,
"column_comment": None,
"column_index": Decimal("1"),
"column_name": "id",
"column_type": "integer",
"table_comment": None,
"table_name": "other_test_table",
"table_schema": "test_schema",
"table_type": "BASE TABLE",
"table_database": "test_database",
"table_owner": None,
},
{
'column_comment': None,
'column_index': Decimal('2'),
'column_name': 'email',
'column_type': 'character varying',
'table_comment': None,
'table_name': 'other_test_table',
'table_schema': 'test_schema',
'table_type': 'BASE TABLE',
'table_database': 'test_database',
'table_owner': None,
"column_comment": None,
"column_index": Decimal("2"),
"column_name": "email",
"column_type": "character varying",
"table_comment": None,
"table_name": "other_test_table",
"table_schema": "test_schema",
"table_type": "BASE TABLE",
"table_database": "test_database",
"table_owner": None,
},
{
'column_comment': None,
'column_index': Decimal('1'),
'column_name': 'id',
'column_type': 'integer',
'table_comment': None,
'table_name': 'test_table',
'table_schema': 'other_test_schema',
'table_type': 'BASE TABLE',
'table_database': 'test_database',
'table_owner': None,
"column_comment": None,
"column_index": Decimal("1"),
"column_name": "id",
"column_type": "integer",
"table_comment": None,
"table_name": "test_table",
"table_schema": "other_test_schema",
"table_type": "BASE TABLE",
"table_database": "test_database",
"table_owner": None,
},
{
'column_comment': None,
'column_index': Decimal('2'),
'column_name': 'name',
'column_type': 'text',
'table_comment': None,
'table_name': 'test_table',
'table_schema': 'other_test_schema',
'table_type': 'BASE TABLE',
'table_database': 'test_database',
'table_owner': None,
"column_comment": None,
"column_index": Decimal("2"),
"column_name": "name",
"column_type": "text",
"table_comment": None,
"table_name": "test_table",
"table_schema": "other_test_schema",
"table_type": "BASE TABLE",
"table_database": "test_database",
"table_owner": None,
},
{
'column_comment': None,
'column_index': Decimal('1'),
'column_name': 'id',
'column_type': 'integer',
'table_comment': None,
'table_name': 'test_table',
'table_schema': 'test_schema',
'table_type': 'BASE TABLE',
'table_database': 'other_test_database',
'table_owner': None,
"column_comment": None,
"column_index": Decimal("1"),
"column_name": "id",
"column_type": "integer",
"table_comment": None,
"table_name": "test_table",
"table_schema": "test_schema",
"table_type": "BASE TABLE",
"table_database": "other_test_database",
"table_owner": None,
},
{
'column_comment': None,
'column_index': Decimal('2'),
'column_name': 'name',
'column_type': 'text',
'table_comment': None,
'table_name': 'test_table',
'table_schema': 'test_schema',
'table_type': 'BASE TABLE',
'table_database': 'other_test_database',
'table_owner': None,
"column_comment": None,
"column_index": Decimal("2"),
"column_name": "name",
"column_type": "text",
"table_comment": None,
"table_name": "test_table",
"table_schema": "test_schema",
"table_type": "BASE TABLE",
"table_database": "other_test_database",
"table_owner": None,
},
]
expected = {
'test.model.test_table': {
'metadata': {
'owner': None,
'comment': None,
'name': 'test_table',
'type': 'BASE TABLE',
'schema': 'test_schema',
'database': 'test_database',
"test.model.test_table": {
"metadata": {
"owner": None,
"comment": None,
"name": "test_table",
"type": "BASE TABLE",
"schema": "test_schema",
"database": "test_database",
},
'columns': {
'id': {
'type': 'integer',
'comment': None,
'index': 1,
'name': 'id'
},
'name': {
'type': 'text',
'comment': None,
'index': 2,
'name': 'name',
}
},
'stats': {
'has_stats': {
'id': 'has_stats',
'label': 'Has Stats?',
'value': False,
'description': 'Indicates whether there are statistics for this table',
'include': False,
"columns": {
"id": {"type": "integer", "comment": None, "index": 1, "name": "id"},
"name": {
"type": "text",
"comment": None,
"index": 2,
"name": "name",
},
},
'unique_id': 'test.model.test_table',
"stats": {
"has_stats": {
"id": "has_stats",
"label": "Has Stats?",
"value": False,
"description": "Indicates whether there are statistics for this table",
"include": False,
},
},
"unique_id": "test.model.test_table",
},
'test.model.other_test_table': {
'metadata': {
'owner': None,
'comment': None,
'name': 'other_test_table',
'type': 'BASE TABLE',
'schema': 'test_schema',
'database': 'test_database',
"test.model.other_test_table": {
"metadata": {
"owner": None,
"comment": None,
"name": "other_test_table",
"type": "BASE TABLE",
"schema": "test_schema",
"database": "test_database",
},
'columns': {
'id': {
'type': 'integer',
'comment': None,
'index': 1,
'name': 'id'
},
'email': {
'type': 'character varying',
'comment': None,
'index': 2,
'name': 'email',
}
},
'stats': {
'has_stats': {
'id': 'has_stats',
'label': 'Has Stats?',
'value': False,
'description': 'Indicates whether there are statistics for this table',
'include': False,
"columns": {
"id": {"type": "integer", "comment": None, "index": 1, "name": "id"},
"email": {
"type": "character varying",
"comment": None,
"index": 2,
"name": "email",
},
},
'unique_id': 'test.model.other_test_table',
"stats": {
"has_stats": {
"id": "has_stats",
"label": "Has Stats?",
"value": False,
"description": "Indicates whether there are statistics for this table",
"include": False,
},
},
"unique_id": "test.model.other_test_table",
},
'test.model.test_table_otherschema': {
'metadata': {
'owner': None,
'comment': None,
'name': 'test_table',
'type': 'BASE TABLE',
'schema': 'other_test_schema',
'database': 'test_database',
"test.model.test_table_otherschema": {
"metadata": {
"owner": None,
"comment": None,
"name": "test_table",
"type": "BASE TABLE",
"schema": "other_test_schema",
"database": "test_database",
},
'columns': {
'id': {
'type': 'integer',
'comment': None,
'index': 1,
'name': 'id'
},
'name': {
'type': 'text',
'comment': None,
'index': 2,
'name': 'name',
}
},
'stats': {
'has_stats': {
'id': 'has_stats',
'label': 'Has Stats?',
'value': False,
'description': 'Indicates whether there are statistics for this table',
'include': False,
"columns": {
"id": {"type": "integer", "comment": None, "index": 1, "name": "id"},
"name": {
"type": "text",
"comment": None,
"index": 2,
"name": "name",
},
},
'unique_id': 'test.model.test_table_otherschema',
"stats": {
"has_stats": {
"id": "has_stats",
"label": "Has Stats?",
"value": False,
"description": "Indicates whether there are statistics for this table",
"include": False,
},
},
"unique_id": "test.model.test_table_otherschema",
},
'test.model.test_table_otherdb': {
'metadata': {
'owner': None,
'comment': None,
'name': 'test_table',
'type': 'BASE TABLE',
'schema': 'test_schema',
'database': 'other_test_database',
"test.model.test_table_otherdb": {
"metadata": {
"owner": None,
"comment": None,
"name": "test_table",
"type": "BASE TABLE",
"schema": "test_schema",
"database": "other_test_database",
},
'columns': {
'id': {
'type': 'integer',
'comment': None,
'index': 1,
'name': 'id'
},
'name': {
'type': 'text',
'comment': None,
'index': 2,
'name': 'name',
}
},
'stats': {
'has_stats': {
'id': 'has_stats',
'label': 'Has Stats?',
'value': False,
'description': 'Indicates whether there are statistics for this table',
'include': False,
"columns": {
"id": {"type": "integer", "comment": None, "index": 1, "name": "id"},
"name": {
"type": "text",
"comment": None,
"index": 2,
"name": "name",
},
},
'unique_id': 'test.model.test_table_otherdb',
}
"stats": {
"has_stats": {
"id": "has_stats",
"label": "Has Stats?",
"value": False,
"description": "Indicates whether there are statistics for this table",
"include": False,
},
},
"unique_id": "test.model.test_table_otherdb",
},
}
self.map_uids([
(
'test_database', 'test_schema', 'test_table',
'test.model.test_table'
),
(
'test_database', 'test_schema', 'other_test_table',
'test.model.other_test_table'
),
(
'test_database', 'other_test_schema', 'test_table',
'test.model.test_table_otherschema'
),
(
'other_test_database', 'test_schema', 'test_table',
'test.model.test_table_otherdb'
),
])
self.map_uids(
[
("test_database", "test_schema", "test_table", "test.model.test_table"),
(
"test_database",
"test_schema",
"other_test_table",
"test.model.other_test_table",
),
(
"test_database",
"other_test_schema",
"test_table",
"test.model.test_table_otherschema",
),
(
"other_test_database",
"test_schema",
"test_table",
"test.model.test_table_otherdb",
),
]
)
result = self.generate_catalog_dict(columns)

View File

@@ -5,31 +5,31 @@ from .utils import MockMacro
def test_raise_duplicate_macros_different_package():
macro_1 = MockMacro(package='dbt', name='some_macro')
macro_2 = MockMacro(package='dbt-myadapter', name='some_macro')
macro_1 = MockMacro(package="dbt", name="some_macro")
macro_2 = MockMacro(package="dbt-myadapter", name="some_macro")
with pytest.raises(CompilationError) as exc:
raise_duplicate_macro_name(
node_1=macro_1,
node_2=macro_2,
namespace='dbt',
namespace="dbt",
)
assert 'dbt-myadapter' in str(exc.value)
assert 'some_macro' in str(exc.value)
assert "dbt-myadapter" in str(exc.value)
assert "some_macro" in str(exc.value)
assert 'namespace "dbt"' in str(exc.value)
assert '("dbt" and "dbt-myadapter" are both in the "dbt" namespace)' in str(exc.value)
def test_raise_duplicate_macros_same_package():
macro_1 = MockMacro(package='dbt', name='some_macro')
macro_2 = MockMacro(package='dbt', name='some_macro')
macro_1 = MockMacro(package="dbt", name="some_macro")
macro_2 = MockMacro(package="dbt", name="some_macro")
with pytest.raises(CompilationError) as exc:
raise_duplicate_macro_name(
node_1=macro_1,
node_2=macro_2,
namespace='dbt',
namespace="dbt",
)
assert 'some_macro' in str(exc.value)
assert "some_macro" in str(exc.value)
assert 'namespace "dbt"' in str(exc.value)
assert "are both in" not in str(exc.value)

View File

@@ -10,8 +10,9 @@ from dbt.helper_types import WarnErrorOptions
# Skip due to interface for flag updated
pytestmark = pytest.mark.skip
class TestFlags(TestCase):
class TestFlags(TestCase):
def setUp(self):
self.args = Namespace()
self.user_config = UserConfig()
@@ -22,15 +23,15 @@ class TestFlags(TestCase):
self.user_config.use_experimental_parser = True
flags.set_from_args(self.args, self.user_config)
self.assertEqual(flags.USE_EXPERIMENTAL_PARSER, True)
os.environ['DBT_USE_EXPERIMENTAL_PARSER'] = 'false'
os.environ["DBT_USE_EXPERIMENTAL_PARSER"] = "false"
flags.set_from_args(self.args, self.user_config)
self.assertEqual(flags.USE_EXPERIMENTAL_PARSER, False)
setattr(self.args, 'use_experimental_parser', True)
setattr(self.args, "use_experimental_parser", True)
flags.set_from_args(self.args, self.user_config)
self.assertEqual(flags.USE_EXPERIMENTAL_PARSER, True)
# cleanup
os.environ.pop('DBT_USE_EXPERIMENTAL_PARSER')
delattr(self.args, 'use_experimental_parser')
os.environ.pop("DBT_USE_EXPERIMENTAL_PARSER")
delattr(self.args, "use_experimental_parser")
flags.USE_EXPERIMENTAL_PARSER = False
self.user_config.use_experimental_parser = None
@@ -38,15 +39,15 @@ class TestFlags(TestCase):
self.user_config.static_parser = False
flags.set_from_args(self.args, self.user_config)
self.assertEqual(flags.STATIC_PARSER, False)
os.environ['DBT_STATIC_PARSER'] = 'true'
os.environ["DBT_STATIC_PARSER"] = "true"
flags.set_from_args(self.args, self.user_config)
self.assertEqual(flags.STATIC_PARSER, True)
setattr(self.args, 'static_parser', False)
setattr(self.args, "static_parser", False)
flags.set_from_args(self.args, self.user_config)
self.assertEqual(flags.STATIC_PARSER, False)
# cleanup
os.environ.pop('DBT_STATIC_PARSER')
delattr(self.args, 'static_parser')
os.environ.pop("DBT_STATIC_PARSER")
delattr(self.args, "static_parser")
flags.STATIC_PARSER = True
self.user_config.static_parser = None
@@ -54,15 +55,15 @@ class TestFlags(TestCase):
self.user_config.warn_error = False
flags.set_from_args(self.args, self.user_config)
self.assertEqual(flags.WARN_ERROR, False)
os.environ['DBT_WARN_ERROR'] = 'true'
os.environ["DBT_WARN_ERROR"] = "true"
flags.set_from_args(self.args, self.user_config)
self.assertEqual(flags.WARN_ERROR, True)
setattr(self.args, 'warn_error', False)
setattr(self.args, "warn_error", False)
flags.set_from_args(self.args, self.user_config)
self.assertEqual(flags.WARN_ERROR, False)
# cleanup
os.environ.pop('DBT_WARN_ERROR')
delattr(self.args, 'warn_error')
os.environ.pop("DBT_WARN_ERROR")
delattr(self.args, "warn_error")
flags.WARN_ERROR = False
self.user_config.warn_error = None
@@ -70,175 +71,175 @@ class TestFlags(TestCase):
self.user_config.warn_error_options = '{"include": "all"}'
flags.set_from_args(self.args, self.user_config)
self.assertEqual(flags.WARN_ERROR_OPTIONS, WarnErrorOptions(include="all"))
os.environ['DBT_WARN_ERROR_OPTIONS'] = '{"include": []}'
os.environ["DBT_WARN_ERROR_OPTIONS"] = '{"include": []}'
flags.set_from_args(self.args, self.user_config)
self.assertEqual(flags.WARN_ERROR_OPTIONS, WarnErrorOptions(include=[]))
setattr(self.args, 'warn_error_options', '{"include": "all"}')
setattr(self.args, "warn_error_options", '{"include": "all"}')
flags.set_from_args(self.args, self.user_config)
self.assertEqual(flags.WARN_ERROR_OPTIONS, WarnErrorOptions(include="all"))
# cleanup
os.environ.pop('DBT_WARN_ERROR_OPTIONS')
delattr(self.args, 'warn_error_options')
os.environ.pop("DBT_WARN_ERROR_OPTIONS")
delattr(self.args, "warn_error_options")
self.user_config.warn_error_options = None
# write_json
self.user_config.write_json = True
flags.set_from_args(self.args, self.user_config)
self.assertEqual(flags.WRITE_JSON, True)
os.environ['DBT_WRITE_JSON'] = 'false'
os.environ["DBT_WRITE_JSON"] = "false"
flags.set_from_args(self.args, self.user_config)
self.assertEqual(flags.WRITE_JSON, False)
setattr(self.args, 'write_json', True)
setattr(self.args, "write_json", True)
flags.set_from_args(self.args, self.user_config)
self.assertEqual(flags.WRITE_JSON, True)
# cleanup
os.environ.pop('DBT_WRITE_JSON')
delattr(self.args, 'write_json')
os.environ.pop("DBT_WRITE_JSON")
delattr(self.args, "write_json")
# partial_parse
self.user_config.partial_parse = True
flags.set_from_args(self.args, self.user_config)
self.assertEqual(flags.PARTIAL_PARSE, True)
os.environ['DBT_PARTIAL_PARSE'] = 'false'
os.environ["DBT_PARTIAL_PARSE"] = "false"
flags.set_from_args(self.args, self.user_config)
self.assertEqual(flags.PARTIAL_PARSE, False)
setattr(self.args, 'partial_parse', True)
setattr(self.args, "partial_parse", True)
flags.set_from_args(self.args, self.user_config)
self.assertEqual(flags.PARTIAL_PARSE, True)
# cleanup
os.environ.pop('DBT_PARTIAL_PARSE')
delattr(self.args, 'partial_parse')
os.environ.pop("DBT_PARTIAL_PARSE")
delattr(self.args, "partial_parse")
self.user_config.partial_parse = False
# use_colors
self.user_config.use_colors = True
flags.set_from_args(self.args, self.user_config)
self.assertEqual(flags.USE_COLORS, True)
os.environ['DBT_USE_COLORS'] = 'false'
os.environ["DBT_USE_COLORS"] = "false"
flags.set_from_args(self.args, self.user_config)
self.assertEqual(flags.USE_COLORS, False)
setattr(self.args, 'use_colors', True)
setattr(self.args, "use_colors", True)
flags.set_from_args(self.args, self.user_config)
self.assertEqual(flags.USE_COLORS, True)
# cleanup
os.environ.pop('DBT_USE_COLORS')
delattr(self.args, 'use_colors')
os.environ.pop("DBT_USE_COLORS")
delattr(self.args, "use_colors")
# debug
self.user_config.debug = True
flags.set_from_args(self.args, self.user_config)
self.assertEqual(flags.DEBUG, True)
os.environ['DBT_DEBUG'] = 'True'
os.environ["DBT_DEBUG"] = "True"
flags.set_from_args(self.args, self.user_config)
self.assertEqual(flags.DEBUG, True)
os.environ['DBT_DEBUG'] = 'False'
setattr(self.args, 'debug', True)
os.environ["DBT_DEBUG"] = "False"
setattr(self.args, "debug", True)
flags.set_from_args(self.args, self.user_config)
self.assertEqual(flags.DEBUG, True)
# cleanup
os.environ.pop('DBT_DEBUG')
delattr(self.args, 'debug')
os.environ.pop("DBT_DEBUG")
delattr(self.args, "debug")
self.user_config.debug = None
# log_format -- text, json, default
self.user_config.log_format = 'text'
self.user_config.log_format = "text"
flags.set_from_args(self.args, self.user_config)
self.assertEqual(flags.LOG_FORMAT, 'text')
os.environ['DBT_LOG_FORMAT'] = 'json'
self.assertEqual(flags.LOG_FORMAT, "text")
os.environ["DBT_LOG_FORMAT"] = "json"
flags.set_from_args(self.args, self.user_config)
self.assertEqual(flags.LOG_FORMAT, 'json')
setattr(self.args, 'log_format', 'text')
self.assertEqual(flags.LOG_FORMAT, "json")
setattr(self.args, "log_format", "text")
flags.set_from_args(self.args, self.user_config)
self.assertEqual(flags.LOG_FORMAT, 'text')
self.assertEqual(flags.LOG_FORMAT, "text")
# cleanup
os.environ.pop('DBT_LOG_FORMAT')
delattr(self.args, 'log_format')
os.environ.pop("DBT_LOG_FORMAT")
delattr(self.args, "log_format")
self.user_config.log_format = None
# version_check
self.user_config.version_check = True
flags.set_from_args(self.args, self.user_config)
self.assertEqual(flags.VERSION_CHECK, True)
os.environ['DBT_VERSION_CHECK'] = 'false'
os.environ["DBT_VERSION_CHECK"] = "false"
flags.set_from_args(self.args, self.user_config)
self.assertEqual(flags.VERSION_CHECK, False)
setattr(self.args, 'version_check', True)
setattr(self.args, "version_check", True)
flags.set_from_args(self.args, self.user_config)
self.assertEqual(flags.VERSION_CHECK, True)
# cleanup
os.environ.pop('DBT_VERSION_CHECK')
delattr(self.args, 'version_check')
os.environ.pop("DBT_VERSION_CHECK")
delattr(self.args, "version_check")
# fail_fast
self.user_config.fail_fast = True
flags.set_from_args(self.args, self.user_config)
self.assertEqual(flags.FAIL_FAST, True)
os.environ['DBT_FAIL_FAST'] = 'false'
os.environ["DBT_FAIL_FAST"] = "false"
flags.set_from_args(self.args, self.user_config)
self.assertEqual(flags.FAIL_FAST, False)
setattr(self.args, 'fail_fast', True)
setattr(self.args, "fail_fast", True)
flags.set_from_args(self.args, self.user_config)
self.assertEqual(flags.FAIL_FAST, True)
# cleanup
os.environ.pop('DBT_FAIL_FAST')
delattr(self.args, 'fail_fast')
os.environ.pop("DBT_FAIL_FAST")
delattr(self.args, "fail_fast")
self.user_config.fail_fast = False
# send_anonymous_usage_stats
self.user_config.send_anonymous_usage_stats = True
flags.set_from_args(self.args, self.user_config)
self.assertEqual(flags.SEND_ANONYMOUS_USAGE_STATS, True)
os.environ['DBT_SEND_ANONYMOUS_USAGE_STATS'] = 'false'
os.environ["DBT_SEND_ANONYMOUS_USAGE_STATS"] = "false"
flags.set_from_args(self.args, self.user_config)
self.assertEqual(flags.SEND_ANONYMOUS_USAGE_STATS, False)
setattr(self.args, 'send_anonymous_usage_stats', True)
setattr(self.args, "send_anonymous_usage_stats", True)
flags.set_from_args(self.args, self.user_config)
self.assertEqual(flags.SEND_ANONYMOUS_USAGE_STATS, True)
os.environ['DO_NOT_TRACK'] = '1'
os.environ["DO_NOT_TRACK"] = "1"
flags.set_from_args(self.args, self.user_config)
self.assertEqual(flags.SEND_ANONYMOUS_USAGE_STATS, False)
# cleanup
os.environ.pop('DBT_SEND_ANONYMOUS_USAGE_STATS')
os.environ.pop('DO_NOT_TRACK')
delattr(self.args, 'send_anonymous_usage_stats')
os.environ.pop("DBT_SEND_ANONYMOUS_USAGE_STATS")
os.environ.pop("DO_NOT_TRACK")
delattr(self.args, "send_anonymous_usage_stats")
# printer_width
self.user_config.printer_width = 100
flags.set_from_args(self.args, self.user_config)
self.assertEqual(flags.PRINTER_WIDTH, 100)
os.environ['DBT_PRINTER_WIDTH'] = '80'
os.environ["DBT_PRINTER_WIDTH"] = "80"
flags.set_from_args(self.args, self.user_config)
self.assertEqual(flags.PRINTER_WIDTH, 80)
setattr(self.args, 'printer_width', '120')
setattr(self.args, "printer_width", "120")
flags.set_from_args(self.args, self.user_config)
self.assertEqual(flags.PRINTER_WIDTH, 120)
# cleanup
os.environ.pop('DBT_PRINTER_WIDTH')
delattr(self.args, 'printer_width')
os.environ.pop("DBT_PRINTER_WIDTH")
delattr(self.args, "printer_width")
self.user_config.printer_width = None
# indirect_selection
self.user_config.indirect_selection = 'eager'
self.user_config.indirect_selection = "eager"
flags.set_from_args(self.args, self.user_config)
self.assertEqual(flags.INDIRECT_SELECTION, IndirectSelection.Eager)
self.user_config.indirect_selection = 'cautious'
self.user_config.indirect_selection = "cautious"
flags.set_from_args(self.args, self.user_config)
self.assertEqual(flags.INDIRECT_SELECTION, IndirectSelection.Cautious)
self.user_config.indirect_selection = 'buildable'
self.user_config.indirect_selection = "buildable"
flags.set_from_args(self.args, self.user_config)
self.assertEqual(flags.INDIRECT_SELECTION, IndirectSelection.Buildable)
self.user_config.indirect_selection = None
flags.set_from_args(self.args, self.user_config)
self.assertEqual(flags.INDIRECT_SELECTION, IndirectSelection.Eager)
os.environ['DBT_INDIRECT_SELECTION'] = 'cautious'
os.environ["DBT_INDIRECT_SELECTION"] = "cautious"
flags.set_from_args(self.args, self.user_config)
self.assertEqual(flags.INDIRECT_SELECTION, IndirectSelection.Cautious)
setattr(self.args, 'indirect_selection', 'cautious')
setattr(self.args, "indirect_selection", "cautious")
flags.set_from_args(self.args, self.user_config)
self.assertEqual(flags.INDIRECT_SELECTION, IndirectSelection.Cautious)
# cleanup
os.environ.pop('DBT_INDIRECT_SELECTION')
delattr(self.args, 'indirect_selection')
os.environ.pop("DBT_INDIRECT_SELECTION")
delattr(self.args, "indirect_selection")
self.user_config.indirect_selection = None
# quiet
@@ -259,29 +260,29 @@ class TestFlags(TestCase):
self.user_config.cache_selected_only = True
flags.set_from_args(self.args, self.user_config)
self.assertEqual(flags.CACHE_SELECTED_ONLY, True)
os.environ['DBT_CACHE_SELECTED_ONLY'] = 'false'
os.environ["DBT_CACHE_SELECTED_ONLY"] = "false"
flags.set_from_args(self.args, self.user_config)
self.assertEqual(flags.CACHE_SELECTED_ONLY, False)
setattr(self.args, 'cache_selected_only', True)
setattr(self.args, "cache_selected_only", True)
flags.set_from_args(self.args, self.user_config)
self.assertEqual(flags.CACHE_SELECTED_ONLY, True)
# cleanup
os.environ.pop('DBT_CACHE_SELECTED_ONLY')
delattr(self.args, 'cache_selected_only')
os.environ.pop("DBT_CACHE_SELECTED_ONLY")
delattr(self.args, "cache_selected_only")
self.user_config.cache_selected_only = False
# target_path/log_path
flags.set_from_args(self.args, self.user_config)
self.assertIsNone(flags.LOG_PATH)
os.environ['DBT_LOG_PATH'] = 'a/b/c'
os.environ["DBT_LOG_PATH"] = "a/b/c"
flags.set_from_args(self.args, self.user_config)
self.assertEqual(flags.LOG_PATH, 'a/b/c')
setattr(self.args, 'log_path', 'd/e/f')
self.assertEqual(flags.LOG_PATH, "a/b/c")
setattr(self.args, "log_path", "d/e/f")
flags.set_from_args(self.args, self.user_config)
self.assertEqual(flags.LOG_PATH, 'd/e/f')
self.assertEqual(flags.LOG_PATH, "d/e/f")
# cleanup
os.environ.pop('DBT_LOG_PATH')
delattr(self.args, 'log_path')
os.environ.pop("DBT_LOG_PATH")
delattr(self.args, "log_path")
def test__flags_are_mutually_exclusive(self):
# options from user config
@@ -289,52 +290,51 @@ class TestFlags(TestCase):
self.user_config.warn_error_options = '{"include":"all"}'
with pytest.raises(ValueError):
flags.set_from_args(self.args, self.user_config)
#cleanup
# cleanup
self.user_config.warn_error = None
self.user_config.warn_error_options = None
# options from args
setattr(self.args, 'warn_error', False)
setattr(self.args, 'warn_error_options', '{"include":"all"}')
setattr(self.args, "warn_error", False)
setattr(self.args, "warn_error_options", '{"include":"all"}')
with pytest.raises(ValueError):
flags.set_from_args(self.args, self.user_config)
# cleanup
delattr(self.args, 'warn_error')
delattr(self.args, 'warn_error_options')
delattr(self.args, "warn_error")
delattr(self.args, "warn_error_options")
# options from environment
os.environ['DBT_WARN_ERROR'] = 'false'
os.environ['DBT_WARN_ERROR_OPTIONS'] = '{"include": []}'
os.environ["DBT_WARN_ERROR"] = "false"
os.environ["DBT_WARN_ERROR_OPTIONS"] = '{"include": []}'
with pytest.raises(ValueError):
flags.set_from_args(self.args, self.user_config)
#cleanup
os.environ.pop('DBT_WARN_ERROR')
os.environ.pop('DBT_WARN_ERROR_OPTIONS')
# cleanup
os.environ.pop("DBT_WARN_ERROR")
os.environ.pop("DBT_WARN_ERROR_OPTIONS")
# options from user config + args
self.user_config.warn_error = False
setattr(self.args, 'warn_error_options', '{"include":"all"}')
setattr(self.args, "warn_error_options", '{"include":"all"}')
with pytest.raises(ValueError):
flags.set_from_args(self.args, self.user_config)
# cleanup
self.user_config.warn_error = None
delattr(self.args, 'warn_error_options')
delattr(self.args, "warn_error_options")
# options from user config + environ
self.user_config.warn_error = False
os.environ['DBT_WARN_ERROR_OPTIONS'] = '{"include": []}'
os.environ["DBT_WARN_ERROR_OPTIONS"] = '{"include": []}'
with pytest.raises(ValueError):
flags.set_from_args(self.args, self.user_config)
# cleanup
self.user_config.warn_error = None
os.environ.pop('DBT_WARN_ERROR_OPTIONS')
os.environ.pop("DBT_WARN_ERROR_OPTIONS")
# options from args + environ
setattr(self.args, 'warn_error', False)
os.environ['DBT_WARN_ERROR_OPTIONS'] = '{"include": []}'
setattr(self.args, "warn_error", False)
os.environ["DBT_WARN_ERROR_OPTIONS"] = '{"include": []}'
with pytest.raises(ValueError):
flags.set_from_args(self.args, self.user_config)
# cleanup
delattr(self.args, 'warn_error')
os.environ.pop('DBT_WARN_ERROR_OPTIONS')
delattr(self.args, "warn_error")
os.environ.pop("DBT_WARN_ERROR_OPTIONS")

View File

@@ -28,7 +28,6 @@ from .utils import config_from_parts_or_dicts, generate_name_macros, inject_plug
class GraphTest(unittest.TestCase):
def tearDown(self):
self.mock_filesystem_search.stop()
self.mock_hook_constructor.stop()
@@ -41,65 +40,69 @@ class GraphTest(unittest.TestCase):
self.graph_result = None
tracking.do_not_track()
self.profile = {
'outputs': {
'test': {
'type': 'postgres',
'threads': 4,
'host': 'thishostshouldnotexist',
'port': 5432,
'user': 'root',
'pass': 'password',
'dbname': 'dbt',
'schema': 'dbt_test'
"outputs": {
"test": {
"type": "postgres",
"threads": 4,
"host": "thishostshouldnotexist",
"port": 5432,
"user": "root",
"pass": "password",
"dbname": "dbt",
"schema": "dbt_test",
}
},
'target': 'test'
"target": "test",
}
self.macro_manifest = MacroManifest(
{n.unique_id: n for n in generate_name_macros('test_models_compile')})
{n.unique_id: n for n in generate_name_macros("test_models_compile")}
)
self.mock_models = [] # used by filesystem_searcher
# Create file filesystem searcher
self.filesystem_search = patch('dbt.parser.read_files.filesystem_search')
self.filesystem_search = patch("dbt.parser.read_files.filesystem_search")
def mock_filesystem_search(project, relative_dirs, extension, ignore_spec):
if 'sql' not in extension:
if "sql" not in extension:
return []
if 'models' not in relative_dirs:
if "models" not in relative_dirs:
return []
return [model.path for model in self.mock_models]
self.mock_filesystem_search = self.filesystem_search.start()
self.mock_filesystem_search.side_effect = mock_filesystem_search
# Create HookParser patcher
self.hook_patcher = patch.object(
dbt.parser.hooks.HookParser, '__new__'
)
self.hook_patcher = patch.object(dbt.parser.hooks.HookParser, "__new__")
def create_hook_patcher(cls, project, manifest, root_project):
result = MagicMock(project=project, manifest=manifest, root_project=root_project)
result.__iter__.side_effect = lambda: iter([])
return result
self.mock_hook_constructor = self.hook_patcher.start()
self.mock_hook_constructor.side_effect = create_hook_patcher
# Create the Manifest.state_check patcher
@patch('dbt.parser.manifest.ManifestLoader.build_manifest_state_check')
@patch("dbt.parser.manifest.ManifestLoader.build_manifest_state_check")
def _mock_state_check(self):
all_projects = self.all_projects
return ManifestStateCheck(
project_env_vars_hash=FileHash.from_contents(''),
profile_env_vars_hash=FileHash.from_contents(''),
vars_hash=FileHash.from_contents('vars'),
project_env_vars_hash=FileHash.from_contents(""),
profile_env_vars_hash=FileHash.from_contents(""),
vars_hash=FileHash.from_contents("vars"),
project_hashes={name: FileHash.from_contents(name) for name in all_projects},
profile_hash=FileHash.from_contents('profile'),
profile_hash=FileHash.from_contents("profile"),
)
self.load_state_check = patch('dbt.parser.manifest.ManifestLoader.build_manifest_state_check')
self.load_state_check = patch(
"dbt.parser.manifest.ManifestLoader.build_manifest_state_check"
)
self.mock_state_check = self.load_state_check.start()
self.mock_state_check.side_effect = _mock_state_check
# Create the source file patcher
self.load_source_file_patcher = patch('dbt.parser.read_files.load_source_file')
self.load_source_file_patcher = patch("dbt.parser.read_files.load_source_file")
self.mock_source_file = self.load_source_file_patcher.start()
def mock_load_source_file(path, parse_file_type, project_name, saved_files):
@@ -109,14 +112,15 @@ class GraphTest(unittest.TestCase):
source_file.project_name = project_name
source_file.parse_file_type = parse_file_type
return source_file
self.mock_source_file.side_effect = mock_load_source_file
@patch('dbt.parser.hooks.HookParser.get_path')
@patch("dbt.parser.hooks.HookParser.get_path")
def _mock_hook_path(self):
path = FilePath(
searched_path='.',
searched_path=".",
project_root=os.path.normcase(os.getcwd()),
relative_path='dbt_project.yml',
relative_path="dbt_project.yml",
modification_time=0.0,
)
return path
@@ -126,11 +130,11 @@ class GraphTest(unittest.TestCase):
extra_cfg = {}
cfg = {
'name': 'test_models_compile',
'version': '0.1',
'profile': 'test',
'project-root': os.path.abspath('.'),
'config-version': 2,
"name": "test_models_compile",
"version": "0.1",
"profile": "test",
"project-root": os.path.abspath("."),
"config-version": 2,
}
cfg.update(extra_cfg)
@@ -145,13 +149,13 @@ class GraphTest(unittest.TestCase):
def use_models(self, models):
for k, v in models.items():
path = FilePath(
searched_path='models',
searched_path="models",
project_root=os.path.normcase(os.getcwd()),
relative_path='{}.sql'.format(k),
relative_path="{}.sql".format(k),
modification_time=0.0,
)
# FileHash can't be empty or 'search_key' will be None
source_file = SourceFile(path=path, checksum=FileHash.from_contents('abc'))
source_file = SourceFile(path=path, checksum=FileHash.from_contents("abc"))
source_file.contents = v
self.mock_models.append(source_file)
@@ -164,9 +168,11 @@ class GraphTest(unittest.TestCase):
return loader.manifest
def test__single_model(self):
self.use_models({
'model_one': 'select * from events',
})
self.use_models(
{
"model_one": "select * from events",
}
)
config = self.get_config()
manifest = self.load_manifest(config)
@@ -174,19 +180,17 @@ class GraphTest(unittest.TestCase):
compiler = self.get_compiler(config)
linker = compiler.compile(manifest)
self.assertEqual(
list(linker.nodes()),
['model.test_models_compile.model_one'])
self.assertEqual(list(linker.nodes()), ["model.test_models_compile.model_one"])
self.assertEqual(
list(linker.edges()),
[])
self.assertEqual(list(linker.edges()), [])
def test__two_models_simple_ref(self):
self.use_models({
'model_one': 'select * from events',
'model_two': "select * from {{ref('model_one')}}",
})
self.use_models(
{
"model_one": "select * from events",
"model_two": "select * from {{ref('model_one')}}",
}
)
config = self.get_config()
manifest = self.load_manifest(config)
@@ -196,23 +200,30 @@ class GraphTest(unittest.TestCase):
self.assertCountEqual(
linker.nodes(),
[
'model.test_models_compile.model_one',
'model.test_models_compile.model_two',
]
"model.test_models_compile.model_one",
"model.test_models_compile.model_two",
],
)
self.assertCountEqual(
linker.edges(),
[('model.test_models_compile.model_one', 'model.test_models_compile.model_two',)]
[
(
"model.test_models_compile.model_one",
"model.test_models_compile.model_two",
)
],
)
def test__model_materializations(self):
self.use_models({
'model_one': 'select * from events',
'model_two': "select * from {{ref('model_one')}}",
'model_three': "select * from events",
'model_four': "select * from events",
})
self.use_models(
{
"model_one": "select * from events",
"model_two": "select * from {{ref('model_one')}}",
"model_three": "select * from events",
"model_four": "select * from events",
}
)
cfg = {
"models": {
@@ -220,8 +231,8 @@ class GraphTest(unittest.TestCase):
"test_models_compile": {
"model_one": {"materialized": "table"},
"model_two": {"materialized": "view"},
"model_three": {"materialized": "ephemeral"}
}
"model_three": {"materialized": "ephemeral"},
},
}
}
@@ -232,26 +243,21 @@ class GraphTest(unittest.TestCase):
"model_one": "table",
"model_two": "view",
"model_three": "ephemeral",
"model_four": "table"
"model_four": "table",
}
for model, expected in expected_materialization.items():
key = 'model.test_models_compile.{}'.format(model)
key = "model.test_models_compile.{}".format(model)
actual = manifest.nodes[key].config.materialized
self.assertEqual(actual, expected)
def test__model_incremental(self):
self.use_models({
'model_one': 'select * from events'
})
self.use_models({"model_one": "select * from events"})
cfg = {
"models": {
"test_models_compile": {
"model_one": {
"materialized": "incremental",
"unique_key": "id"
},
"model_one": {"materialized": "incremental", "unique_key": "id"},
}
}
}
@@ -261,48 +267,52 @@ class GraphTest(unittest.TestCase):
compiler = self.get_compiler(config)
linker = compiler.compile(manifest)
node = 'model.test_models_compile.model_one'
node = "model.test_models_compile.model_one"
self.assertEqual(list(linker.nodes()), [node])
self.assertEqual(list(linker.edges()), [])
self.assertEqual(manifest.nodes[node].config.materialized, 'incremental')
self.assertEqual(manifest.nodes[node].config.materialized, "incremental")
def test__dependency_list(self):
self.use_models({
'model_1': 'select * from events',
'model_2': 'select * from {{ ref("model_1") }}',
'model_3': '''
self.use_models(
{
"model_1": "select * from events",
"model_2": 'select * from {{ ref("model_1") }}',
"model_3": """
select * from {{ ref("model_1") }}
union all
select * from {{ ref("model_2") }}
''',
'model_4': 'select * from {{ ref("model_3") }}'
})
""",
"model_4": 'select * from {{ ref("model_3") }}',
}
)
config = self.get_config()
manifest = self.load_manifest(config)
compiler = self.get_compiler(config)
graph = compiler.compile(manifest)
models = ('model_1', 'model_2', 'model_3', 'model_4')
model_ids = ['model.test_models_compile.{}'.format(m) for m in models]
models = ("model_1", "model_2", "model_3", "model_4")
model_ids = ["model.test_models_compile.{}".format(m) for m in models]
manifest = MagicMock(nodes={
n: MagicMock(
unique_id=n,
name=n.split('.')[-1],
package_name='test_models_compile',
fqn=['test_models_compile', n],
empty=False,
config=MagicMock(enabled=True),
)
for n in model_ids
})
manifest = MagicMock(
nodes={
n: MagicMock(
unique_id=n,
name=n.split(".")[-1],
package_name="test_models_compile",
fqn=["test_models_compile", n],
empty=False,
config=MagicMock(enabled=True),
)
for n in model_ids
}
)
manifest.expect.side_effect = lambda n: MagicMock(unique_id=n)
selector = NodeSelector(graph, manifest)
# TODO: The "eager" string below needs to be replaced with programatic access
# to the default value for the indirect selection parameter in
# to the default value for the indirect selection parameter in
# dbt.cli.params.indirect_selection
#
# Doing that is actually a little tricky, so I'm punting it to a new ticket GH #6397
@@ -328,9 +338,9 @@ class GraphTest(unittest.TestCase):
is_partial_parsable, _ = loader.is_partial_parsable(manifest)
self.assertTrue(is_partial_parsable)
manifest.metadata.dbt_version = '0.0.1a1'
manifest.metadata.dbt_version = "0.0.1a1"
is_partial_parsable, _ = loader.is_partial_parsable(manifest)
self.assertFalse(is_partial_parsable)
manifest.metadata.dbt_version = '99999.99.99'
manifest.metadata.dbt_version = "99999.99.99"
is_partial_parsable, _ = loader.is_partial_parsable(manifest)
self.assertFalse(is_partial_parsable)

View File

@@ -1,4 +1,3 @@
import unittest
from unittest import mock
import pytest
@@ -23,7 +22,7 @@ def _get_graph():
integer_graph = nx.balanced_tree(2, 2, nx.DiGraph())
package_mapping = {
i: 'm.' + ('X' if i % 2 == 0 else 'Y') + '.' + letter
i: "m." + ("X" if i % 2 == 0 else "Y") + "." + letter
for (i, letter) in enumerate(string.ascii_lowercase)
}
@@ -34,7 +33,7 @@ def _get_graph():
def _get_manifest(graph):
nodes = {}
for unique_id in graph:
fqn = unique_id.split('.')
fqn = unique_id.split(".")
node = mock.MagicMock(
unique_id=unique_id,
fqn=fqn,
@@ -46,13 +45,13 @@ def _get_manifest(graph):
)
nodes[unique_id] = node
nodes['m.X.a'].tags = ['abc']
nodes['m.Y.b'].tags = ['abc', 'bcef']
nodes['m.X.c'].tags = ['abc', 'bcef']
nodes['m.Y.d'].tags = []
nodes['m.X.e'].tags = ['efg', 'bcef']
nodes['m.Y.f'].tags = ['efg', 'bcef']
nodes['m.X.g'].tags = ['efg']
nodes["m.X.a"].tags = ["abc"]
nodes["m.Y.b"].tags = ["abc", "bcef"]
nodes["m.X.c"].tags = ["abc", "bcef"]
nodes["m.Y.d"].tags = []
nodes["m.X.e"].tags = ["efg", "bcef"]
nodes["m.Y.f"].tags = ["efg", "bcef"]
nodes["m.X.g"].tags = ["efg"]
return mock.MagicMock(nodes=nodes)
@@ -70,64 +69,64 @@ def id_macro(arg):
if isinstance(arg, str):
return arg
try:
return '_'.join(arg)
return "_".join(arg)
except TypeError:
return arg
run_specs = [
# include by fqn
(['X.a'], [], {'m.X.a'}),
(["X.a"], [], {"m.X.a"}),
# include by tag
(['tag:abc'], [], {'m.X.a', 'm.Y.b', 'm.X.c'}),
(["tag:abc"], [], {"m.X.a", "m.Y.b", "m.X.c"}),
# exclude by tag
(['*'], ['tag:abc'], {'m.Y.d', 'm.X.e', 'm.Y.f', 'm.X.g'}),
(["*"], ["tag:abc"], {"m.Y.d", "m.X.e", "m.Y.f", "m.X.g"}),
# tag + fqn
(['tag:abc', 'a'], [], {'m.X.a', 'm.Y.b', 'm.X.c'}),
(['tag:abc', 'd'], [], {'m.X.a', 'm.Y.b', 'm.X.c', 'm.Y.d'}),
(["tag:abc", "a"], [], {"m.X.a", "m.Y.b", "m.X.c"}),
(["tag:abc", "d"], [], {"m.X.a", "m.Y.b", "m.X.c", "m.Y.d"}),
# multiple node selection across packages
(['X.a', 'b'], [], {'m.X.a', 'm.Y.b'}),
(['X.a+'], ['b'], {'m.X.a','m.X.c', 'm.Y.d','m.X.e','m.Y.f','m.X.g'}),
(["X.a", "b"], [], {"m.X.a", "m.Y.b"}),
(["X.a+"], ["b"], {"m.X.a", "m.X.c", "m.Y.d", "m.X.e", "m.Y.f", "m.X.g"}),
# children
(['X.c+'], [], {'m.X.c', 'm.Y.f', 'm.X.g'}),
(['X.a+1'], [], {'m.X.a', 'm.Y.b', 'm.X.c'}),
(['X.a+'], ['tag:efg'], {'m.X.a','m.Y.b','m.X.c', 'm.Y.d'}),
(["X.c+"], [], {"m.X.c", "m.Y.f", "m.X.g"}),
(["X.a+1"], [], {"m.X.a", "m.Y.b", "m.X.c"}),
(["X.a+"], ["tag:efg"], {"m.X.a", "m.Y.b", "m.X.c", "m.Y.d"}),
# parents
(['+Y.f'], [], {'m.X.c', 'm.Y.f', 'm.X.a'}),
(['1+Y.f'], [], {'m.X.c', 'm.Y.f'}),
(["+Y.f"], [], {"m.X.c", "m.Y.f", "m.X.a"}),
(["1+Y.f"], [], {"m.X.c", "m.Y.f"}),
# childrens parents
(['@X.c'], [], {'m.X.a', 'm.X.c', 'm.Y.f', 'm.X.g'}),
(["@X.c"], [], {"m.X.a", "m.X.c", "m.Y.f", "m.X.g"}),
# multiple selection/exclusion
(['tag:abc', 'tag:bcef'], [], {'m.X.a', 'm.Y.b', 'm.X.c', 'm.X.e', 'm.Y.f'}),
(['tag:abc', 'tag:bcef'], ['tag:efg'], {'m.X.a', 'm.Y.b', 'm.X.c'}),
(['tag:abc', 'tag:bcef'], ['tag:efg', 'a'], {'m.Y.b', 'm.X.c'}),
(["tag:abc", "tag:bcef"], [], {"m.X.a", "m.Y.b", "m.X.c", "m.X.e", "m.Y.f"}),
(["tag:abc", "tag:bcef"], ["tag:efg"], {"m.X.a", "m.Y.b", "m.X.c"}),
(["tag:abc", "tag:bcef"], ["tag:efg", "a"], {"m.Y.b", "m.X.c"}),
# intersections
(['a,a'], [], {'m.X.a'}),
(['+c,c+'], [], {'m.X.c'}),
(['a,b'], [], set()),
(['tag:abc,tag:bcef'], [], {'m.Y.b', 'm.X.c'}),
(['*,tag:abc,a'], [], {'m.X.a'}),
(['a,tag:abc,*'], [], {'m.X.a'}),
(['tag:abc,tag:bcef'], ['c'], {'m.Y.b'}),
(['tag:bcef,tag:efg'], ['tag:bcef,@b'], {'m.Y.f'}),
(['tag:bcef,tag:efg'], ['tag:bcef,@a'], set()),
(['*,@a,+b'], ['*,tag:abc,tag:bcef'], {'m.X.a'}),
(['tag:bcef,tag:efg', '*,tag:abc'], [], {'m.X.a', 'm.Y.b', 'm.X.c', 'm.X.e', 'm.Y.f'}),
(['tag:bcef,tag:efg', '*,tag:abc'], ['e'], {'m.X.a', 'm.Y.b', 'm.X.c', 'm.Y.f'}),
(['tag:bcef,tag:efg', '*,tag:abc'], ['e'], {'m.X.a', 'm.Y.b', 'm.X.c', 'm.Y.f'}),
(['tag:bcef,tag:efg', '*,tag:abc'], ['e', 'f'], {'m.X.a', 'm.Y.b', 'm.X.c'}),
(['tag:bcef,tag:efg', '*,tag:abc'], ['tag:abc,tag:bcef'], {'m.X.a', 'm.X.e', 'm.Y.f'}),
(['tag:bcef,tag:efg', '*,tag:abc'], ['tag:abc,tag:bcef', 'tag:abc,a'], {'m.X.e', 'm.Y.f'})
(["a,a"], [], {"m.X.a"}),
(["+c,c+"], [], {"m.X.c"}),
(["a,b"], [], set()),
(["tag:abc,tag:bcef"], [], {"m.Y.b", "m.X.c"}),
(["*,tag:abc,a"], [], {"m.X.a"}),
(["a,tag:abc,*"], [], {"m.X.a"}),
(["tag:abc,tag:bcef"], ["c"], {"m.Y.b"}),
(["tag:bcef,tag:efg"], ["tag:bcef,@b"], {"m.Y.f"}),
(["tag:bcef,tag:efg"], ["tag:bcef,@a"], set()),
(["*,@a,+b"], ["*,tag:abc,tag:bcef"], {"m.X.a"}),
(["tag:bcef,tag:efg", "*,tag:abc"], [], {"m.X.a", "m.Y.b", "m.X.c", "m.X.e", "m.Y.f"}),
(["tag:bcef,tag:efg", "*,tag:abc"], ["e"], {"m.X.a", "m.Y.b", "m.X.c", "m.Y.f"}),
(["tag:bcef,tag:efg", "*,tag:abc"], ["e"], {"m.X.a", "m.Y.b", "m.X.c", "m.Y.f"}),
(["tag:bcef,tag:efg", "*,tag:abc"], ["e", "f"], {"m.X.a", "m.Y.b", "m.X.c"}),
(["tag:bcef,tag:efg", "*,tag:abc"], ["tag:abc,tag:bcef"], {"m.X.a", "m.X.e", "m.Y.f"}),
(["tag:bcef,tag:efg", "*,tag:abc"], ["tag:abc,tag:bcef", "tag:abc,a"], {"m.X.e", "m.Y.f"}),
]
@pytest.mark.parametrize('include,exclude,expected', run_specs, ids=id_macro)
@pytest.mark.parametrize("include,exclude,expected", run_specs, ids=id_macro)
def test_run_specs(include, exclude, expected):
graph = _get_graph()
manifest = _get_manifest(graph)
selector = graph_selector.NodeSelector(graph, manifest)
# TODO: The "eager" string below needs to be replaced with programatic access
# to the default value for the indirect selection parameter in
# to the default value for the indirect selection parameter in
# dbt.cli.params.indirect_selection
#
# Doing that is actually a little tricky, so I'm punting it to a new ticket GH #6397
@@ -138,52 +137,61 @@ def test_run_specs(include, exclude, expected):
param_specs = [
('a', False, None, False, None, 'fqn', 'a', False),
('+a', True, None, False, None, 'fqn', 'a', False),
('256+a', True, 256, False, None, 'fqn', 'a', False),
('a+', False, None, True, None, 'fqn', 'a', False),
('a+256', False, None, True, 256, 'fqn', 'a', False),
('+a+', True, None, True, None, 'fqn', 'a', False),
('16+a+32', True, 16, True, 32, 'fqn', 'a', False),
('@a', False, None, False, None, 'fqn', 'a', True),
('a.b', False, None, False, None, 'fqn', 'a.b', False),
('+a.b', True, None, False, None, 'fqn', 'a.b', False),
('256+a.b', True, 256, False, None, 'fqn', 'a.b', False),
('a.b+', False, None, True, None, 'fqn', 'a.b', False),
('a.b+256', False, None, True, 256, 'fqn', 'a.b', False),
('+a.b+', True, None, True, None, 'fqn', 'a.b', False),
('16+a.b+32', True, 16, True, 32, 'fqn', 'a.b', False),
('@a.b', False, None, False, None, 'fqn', 'a.b', True),
('a.b.*', False, None, False, None, 'fqn', 'a.b.*', False),
('+a.b.*', True, None, False, None, 'fqn', 'a.b.*', False),
('256+a.b.*', True, 256, False, None, 'fqn', 'a.b.*', False),
('a.b.*+', False, None, True, None, 'fqn', 'a.b.*', False),
('a.b.*+256', False, None, True, 256, 'fqn', 'a.b.*', False),
('+a.b.*+', True, None, True, None, 'fqn', 'a.b.*', False),
('16+a.b.*+32', True, 16, True, 32, 'fqn', 'a.b.*', False),
('@a.b.*', False, None, False, None, 'fqn', 'a.b.*', True),
('tag:a', False, None, False, None, 'tag', 'a', False),
('+tag:a', True, None, False, None, 'tag', 'a', False),
('256+tag:a', True, 256, False, None, 'tag', 'a', False),
('tag:a+', False, None, True, None, 'tag', 'a', False),
('tag:a+256', False, None, True, 256, 'tag', 'a', False),
('+tag:a+', True, None, True, None, 'tag', 'a', False),
('16+tag:a+32', True, 16, True, 32, 'tag', 'a', False),
('@tag:a', False, None, False, None, 'tag', 'a', True),
('source:a', False, None, False, None, 'source', 'a', False),
('source:a+', False, None, True, None, 'source', 'a', False),
('source:a+1', False, None, True, 1, 'source', 'a', False),
('source:a+32', False, None, True, 32, 'source', 'a', False),
('@source:a', False, None, False, None, 'source', 'a', True),
("a", False, None, False, None, "fqn", "a", False),
("+a", True, None, False, None, "fqn", "a", False),
("256+a", True, 256, False, None, "fqn", "a", False),
("a+", False, None, True, None, "fqn", "a", False),
("a+256", False, None, True, 256, "fqn", "a", False),
("+a+", True, None, True, None, "fqn", "a", False),
("16+a+32", True, 16, True, 32, "fqn", "a", False),
("@a", False, None, False, None, "fqn", "a", True),
("a.b", False, None, False, None, "fqn", "a.b", False),
("+a.b", True, None, False, None, "fqn", "a.b", False),
("256+a.b", True, 256, False, None, "fqn", "a.b", False),
("a.b+", False, None, True, None, "fqn", "a.b", False),
("a.b+256", False, None, True, 256, "fqn", "a.b", False),
("+a.b+", True, None, True, None, "fqn", "a.b", False),
("16+a.b+32", True, 16, True, 32, "fqn", "a.b", False),
("@a.b", False, None, False, None, "fqn", "a.b", True),
("a.b.*", False, None, False, None, "fqn", "a.b.*", False),
("+a.b.*", True, None, False, None, "fqn", "a.b.*", False),
("256+a.b.*", True, 256, False, None, "fqn", "a.b.*", False),
("a.b.*+", False, None, True, None, "fqn", "a.b.*", False),
("a.b.*+256", False, None, True, 256, "fqn", "a.b.*", False),
("+a.b.*+", True, None, True, None, "fqn", "a.b.*", False),
("16+a.b.*+32", True, 16, True, 32, "fqn", "a.b.*", False),
("@a.b.*", False, None, False, None, "fqn", "a.b.*", True),
("tag:a", False, None, False, None, "tag", "a", False),
("+tag:a", True, None, False, None, "tag", "a", False),
("256+tag:a", True, 256, False, None, "tag", "a", False),
("tag:a+", False, None, True, None, "tag", "a", False),
("tag:a+256", False, None, True, 256, "tag", "a", False),
("+tag:a+", True, None, True, None, "tag", "a", False),
("16+tag:a+32", True, 16, True, 32, "tag", "a", False),
("@tag:a", False, None, False, None, "tag", "a", True),
("source:a", False, None, False, None, "source", "a", False),
("source:a+", False, None, True, None, "source", "a", False),
("source:a+1", False, None, True, 1, "source", "a", False),
("source:a+32", False, None, True, 32, "source", "a", False),
("@source:a", False, None, False, None, "source", "a", True),
]
@pytest.mark.parametrize(
'spec,parents,parents_depth,children,children_depth,filter_type,filter_value,childrens_parents',
"spec,parents,parents_depth,children,children_depth,filter_type,filter_value,childrens_parents",
param_specs,
ids=id_macro
ids=id_macro,
)
def test_parse_specs(spec, parents, parents_depth, children, children_depth, filter_type, filter_value, childrens_parents):
def test_parse_specs(
spec,
parents,
parents_depth,
children,
children_depth,
filter_type,
filter_value,
childrens_parents,
):
parsed = graph_selector.SelectionCriteria.from_single_spec(spec)
assert parsed.parents == parents
assert parsed.parents_depth == parents_depth
@@ -195,15 +203,15 @@ def test_parse_specs(spec, parents, parents_depth, children, children_depth, fil
invalid_specs = [
'@a+',
'@a.b+',
'@a.b*+',
'@tag:a+',
'@source:a+',
"@a+",
"@a.b+",
"@a.b*+",
"@tag:a+",
"@source:a+",
]
@pytest.mark.parametrize('invalid', invalid_specs, ids=lambda k: str(k))
@pytest.mark.parametrize("invalid", invalid_specs, ids=lambda k: str(k))
def test_invalid_specs(invalid):
with pytest.raises(dbt.exceptions.DbtRuntimeError):
graph_selector.SelectionCriteria.from_single_spec(invalid)

File diff suppressed because it is too large Load Diff

View File

@@ -24,10 +24,10 @@ class Union:
self.components = args
def __str__(self):
return f'Union(components={self.components})'
return f"Union(components={self.components})"
def __repr__(self):
return f'Union(components={self.components!r})'
return f"Union(components={self.components!r})"
def __eq__(self, other):
if not isinstance(other, SelectionUnion):
@@ -41,10 +41,10 @@ class Intersection:
self.components = args
def __str__(self):
return f'Intersection(components={self.components})'
return f"Intersection(components={self.components})"
def __repr__(self):
return f'Intersection(components={self.components!r})'
return f"Intersection(components={self.components!r})"
def __eq__(self, other):
if not isinstance(other, SelectionIntersection):
@@ -58,10 +58,10 @@ class Difference:
self.components = args
def __str__(self):
return f'Difference(components={self.components})'
return f"Difference(components={self.components})"
def __repr__(self):
return f'Difference(components={self.components!r})'
return f"Difference(components={self.components!r})"
def __eq__(self, other):
if not isinstance(other, SelectionDifference):
@@ -77,75 +77,86 @@ class Criteria:
self.kwargs = kwargs
def __str__(self):
return f'Criteria(method={self.method}, value={self.value}, **{self.kwargs})'
return f"Criteria(method={self.method}, value={self.value}, **{self.kwargs})"
def __repr__(self):
return f'Criteria(method={self.method!r}, value={self.value!r}, **{self.kwargs!r})'
return f"Criteria(method={self.method!r}, value={self.value!r}, **{self.kwargs!r})"
def __eq__(self, other):
if not isinstance(other, SelectionCriteria):
return False
return (
self.method == other.method and
self.value == other.value and
all(getattr(other, k) == v for k, v in self.kwargs.items())
self.method == other.method
and self.value == other.value
and all(getattr(other, k) == v for k, v in self.kwargs.items())
)
def test_parse_simple():
sf = parse_file('''\
sf = parse_file(
"""\
selectors:
- name: tagged_foo
description: Selector for foo-tagged models
definition:
tag: foo
''')
"""
)
assert len(sf.selectors) == 1
assert sf.selectors[0].description == 'Selector for foo-tagged models'
assert sf.selectors[0].description == "Selector for foo-tagged models"
parsed = cli.parse_from_selectors_definition(sf)
assert len(parsed) == 1
assert 'tagged_foo' in parsed
assert Criteria(
method=MethodName.Tag,
method_arguments=[],
value='foo',
children=False,
parents=False,
childrens_parents=False,
children_depth=None,
parents_depth=None,
) == parsed['tagged_foo']["definition"]
assert "tagged_foo" in parsed
assert (
Criteria(
method=MethodName.Tag,
method_arguments=[],
value="foo",
children=False,
parents=False,
childrens_parents=False,
children_depth=None,
parents_depth=None,
)
== parsed["tagged_foo"]["definition"]
)
def test_parse_simple_childrens_parents():
sf = parse_file('''\
sf = parse_file(
"""\
selectors:
- name: tagged_foo
definition:
method: tag
value: foo
childrens_parents: True
''')
"""
)
assert len(sf.selectors) == 1
parsed = cli.parse_from_selectors_definition(sf)
assert len(parsed) == 1
assert 'tagged_foo' in parsed
assert Criteria(
method=MethodName.Tag,
method_arguments=[],
value='foo',
children=False,
parents=False,
childrens_parents=True,
children_depth=None,
parents_depth=None,
) == parsed['tagged_foo']["definition"]
assert "tagged_foo" in parsed
assert (
Criteria(
method=MethodName.Tag,
method_arguments=[],
value="foo",
children=False,
parents=False,
childrens_parents=True,
children_depth=None,
parents_depth=None,
)
== parsed["tagged_foo"]["definition"]
)
def test_parse_simple_arguments_with_modifiers():
sf = parse_file('''\
sf = parse_file(
"""\
selectors:
- name: configured_view
definition:
@@ -154,26 +165,31 @@ def test_parse_simple_arguments_with_modifiers():
parents: True
children: True
children_depth: 2
''')
"""
)
assert len(sf.selectors) == 1
parsed = cli.parse_from_selectors_definition(sf)
assert len(parsed) == 1
assert 'configured_view' in parsed
assert Criteria(
method=MethodName.Config,
method_arguments=['materialized'],
value='view',
children=True,
parents=True,
childrens_parents=False,
children_depth=2,
parents_depth=None,
) == parsed['configured_view']["definition"]
assert "configured_view" in parsed
assert (
Criteria(
method=MethodName.Config,
method_arguments=["materialized"],
value="view",
children=True,
parents=True,
childrens_parents=False,
children_depth=2,
parents_depth=None,
)
== parsed["configured_view"]["definition"]
)
def test_parse_union():
sf = parse_file('''\
sf = parse_file(
"""\
selectors:
- name: views-or-foos
definition:
@@ -181,18 +197,23 @@ def test_parse_union():
- method: config.materialized
value: view
- tag: foo
''')
"""
)
assert len(sf.selectors) == 1
parsed = cli.parse_from_selectors_definition(sf)
assert 'views-or-foos' in parsed
assert Union(
Criteria(method=MethodName.Config, value='view', method_arguments=['materialized']),
Criteria(method=MethodName.Tag, value='foo', method_arguments=[])
) == parsed['views-or-foos']["definition"]
assert "views-or-foos" in parsed
assert (
Union(
Criteria(method=MethodName.Config, value="view", method_arguments=["materialized"]),
Criteria(method=MethodName.Tag, value="foo", method_arguments=[]),
)
== parsed["views-or-foos"]["definition"]
)
def test_parse_intersection():
sf = parse_file('''\
sf = parse_file(
"""\
selectors:
- name: views-and-foos
definition:
@@ -200,19 +221,24 @@ def test_parse_intersection():
- method: config.materialized
value: view
- tag: foo
''')
"""
)
assert len(sf.selectors) == 1
parsed = cli.parse_from_selectors_definition(sf)
assert 'views-and-foos' in parsed
assert Intersection(
Criteria(method=MethodName.Config, value='view', method_arguments=['materialized']),
Criteria(method=MethodName.Tag, value='foo', method_arguments=[]),
) == parsed['views-and-foos']["definition"]
assert "views-and-foos" in parsed
assert (
Intersection(
Criteria(method=MethodName.Config, value="view", method_arguments=["materialized"]),
Criteria(method=MethodName.Tag, value="foo", method_arguments=[]),
)
== parsed["views-and-foos"]["definition"]
)
def test_parse_union_excluding():
sf = parse_file('''\
sf = parse_file(
"""\
selectors:
- name: views-or-foos-not-bars
definition:
@@ -222,21 +248,28 @@ def test_parse_union_excluding():
- tag: foo
- exclude:
- tag: bar
''')
"""
)
assert len(sf.selectors) == 1
parsed = cli.parse_from_selectors_definition(sf)
assert 'views-or-foos-not-bars' in parsed
assert Difference(
Union(
Criteria(method=MethodName.Config, value='view', method_arguments=['materialized']),
Criteria(method=MethodName.Tag, value='foo', method_arguments=[])
),
Criteria(method=MethodName.Tag, value='bar', method_arguments=[]),
) == parsed['views-or-foos-not-bars']["definition"]
assert "views-or-foos-not-bars" in parsed
assert (
Difference(
Union(
Criteria(
method=MethodName.Config, value="view", method_arguments=["materialized"]
),
Criteria(method=MethodName.Tag, value="foo", method_arguments=[]),
),
Criteria(method=MethodName.Tag, value="bar", method_arguments=[]),
)
== parsed["views-or-foos-not-bars"]["definition"]
)
def test_parse_yaml_complex():
sf = parse_file('''\
sf = parse_file(
"""\
selectors:
- name: test_name
definition:
@@ -263,48 +296,66 @@ def test_parse_yaml_complex():
union:
- tag: nightly
- tag:weeknights_only
''')
"""
)
assert len(sf.selectors) == 2
parsed = cli.parse_from_selectors_definition(sf)
assert 'test_name' in parsed
assert 'weeknights' in parsed
assert Union(
Criteria(method=MethodName.Tag, value='nightly'),
Criteria(method=MethodName.Tag, value='weeknights_only'),
) == parsed['weeknights']["definition"]
assert Union(
Intersection(
Criteria(method=MethodName.Tag, value='foo'),
Criteria(method=MethodName.Tag, value='bar'),
Union(
Criteria(method=MethodName.Package, value='snowplow'),
Criteria(method=MethodName.Config, value='incremental', method_arguments=['materialized']),
),
),
assert "test_name" in parsed
assert "weeknights" in parsed
assert (
Union(
Criteria(method=MethodName.Tag, value="nightly"),
Criteria(method=MethodName.Tag, value="weeknights_only"),
)
== parsed["weeknights"]["definition"]
)
assert (
Union(
Criteria(method=MethodName.Path, value="models/snowplow/marketing/custom_events.sql"),
Criteria(method=MethodName.FQN, value='snowplow.marketing'),
),
Difference(
Intersection(
Criteria(method=MethodName.ResourceType, value='seed'),
Criteria(method=MethodName.Package, value='snowplow'),
),
Union(
Criteria(method=MethodName.FQN, value='country_codes'),
Intersection(
Criteria(method=MethodName.Tag, value='baz'),
Criteria(method=MethodName.Config, value='ephemeral', method_arguments=['materialized']),
Criteria(method=MethodName.Tag, value="foo"),
Criteria(method=MethodName.Tag, value="bar"),
Union(
Criteria(method=MethodName.Package, value="snowplow"),
Criteria(
method=MethodName.Config,
value="incremental",
method_arguments=["materialized"],
),
),
),
),
) == parsed['test_name']["definition"]
Union(
Criteria(
method=MethodName.Path, value="models/snowplow/marketing/custom_events.sql"
),
Criteria(method=MethodName.FQN, value="snowplow.marketing"),
),
Difference(
Intersection(
Criteria(method=MethodName.ResourceType, value="seed"),
Criteria(method=MethodName.Package, value="snowplow"),
),
Union(
Criteria(method=MethodName.FQN, value="country_codes"),
Intersection(
Criteria(method=MethodName.Tag, value="baz"),
Criteria(
method=MethodName.Config,
value="ephemeral",
method_arguments=["materialized"],
),
),
),
),
)
== parsed["test_name"]["definition"]
)
def test_parse_selection():
sf = parse_file('''\
sf = parse_file(
"""\
selectors:
- name: default
definition:
@@ -315,23 +366,31 @@ def test_parse_selection():
definition:
method: selector
value: default
''')
"""
)
assert len(sf.selectors) == 2
parsed = cli.parse_from_selectors_definition(sf)
assert 'default' in parsed
assert 'inherited' in parsed
assert Union(
Criteria(method=MethodName.Tag, value='foo'),
Criteria(method=MethodName.Tag, value='bar'),
) == parsed['default']["definition"]
assert Union(
Criteria(method=MethodName.Tag, value='foo'),
Criteria(method=MethodName.Tag, value='bar'),
) == parsed['inherited']["definition"]
assert "default" in parsed
assert "inherited" in parsed
assert (
Union(
Criteria(method=MethodName.Tag, value="foo"),
Criteria(method=MethodName.Tag, value="bar"),
)
== parsed["default"]["definition"]
)
assert (
Union(
Criteria(method=MethodName.Tag, value="foo"),
Criteria(method=MethodName.Tag, value="bar"),
)
== parsed["inherited"]["definition"]
)
def test_parse_selection_with_exclusion():
sf = parse_file('''\
sf = parse_file(
"""\
selectors:
- name: default
definition:
@@ -345,21 +404,28 @@ def test_parse_selection_with_exclusion():
value: default
- exclude:
- tag: bar
''')
"""
)
assert len(sf.selectors) == 2
parsed = cli.parse_from_selectors_definition(sf)
assert 'default' in parsed
assert 'inherited' in parsed
assert Union(
Criteria(method=MethodName.Tag, value='foo'),
Criteria(method=MethodName.Tag, value='bar'),
) == parsed['default']["definition"]
assert Difference(
assert "default" in parsed
assert "inherited" in parsed
assert (
Union(
Criteria(method=MethodName.Tag, value="foo"),
Criteria(method=MethodName.Tag, value="bar"),
)
== parsed["default"]["definition"]
)
assert (
Difference(
Union(
Criteria(method=MethodName.Tag, value='foo'),
Criteria(method=MethodName.Tag, value='bar'),
)
),
Criteria(method=MethodName.Tag, value='bar'),
) == parsed['inherited']["definition"]
Union(
Criteria(method=MethodName.Tag, value="foo"),
Criteria(method=MethodName.Tag, value="bar"),
)
),
Criteria(method=MethodName.Tag, value="bar"),
)
== parsed["inherited"]["definition"]
)

View File

@@ -12,7 +12,7 @@ import os
def test_raw_parse_simple():
raw = 'asdf'
raw = "asdf"
result = SelectionCriteria.from_single_spec(raw)
assert result.raw == raw
assert result.method == MethodName.FQN
@@ -26,7 +26,7 @@ def test_raw_parse_simple():
def test_raw_parse_simple_infer_path():
raw = os.path.join('asdf', '*')
raw = os.path.join("asdf", "*")
result = SelectionCriteria.from_single_spec(raw)
assert result.raw == raw
assert result.method == MethodName.Path
@@ -40,7 +40,7 @@ def test_raw_parse_simple_infer_path():
def test_raw_parse_simple_infer_path_modified():
raw = '@' + os.path.join('asdf', '*')
raw = "@" + os.path.join("asdf", "*")
result = SelectionCriteria.from_single_spec(raw)
assert result.raw == raw
assert result.method == MethodName.Path
@@ -54,12 +54,12 @@ def test_raw_parse_simple_infer_path_modified():
def test_raw_parse_simple_infer_fqn_parents():
raw = '+asdf'
raw = "+asdf"
result = SelectionCriteria.from_single_spec(raw)
assert result.raw == raw
assert result.method == MethodName.FQN
assert result.method_arguments == []
assert result.value == 'asdf'
assert result.value == "asdf"
assert not result.childrens_parents
assert not result.children
assert result.parents
@@ -68,12 +68,12 @@ def test_raw_parse_simple_infer_fqn_parents():
def test_raw_parse_simple_infer_fqn_children():
raw = 'asdf+'
raw = "asdf+"
result = SelectionCriteria.from_single_spec(raw)
assert result.raw == raw
assert result.method == MethodName.FQN
assert result.method_arguments == []
assert result.value == 'asdf'
assert result.value == "asdf"
assert not result.childrens_parents
assert result.children
assert not result.parents
@@ -82,12 +82,12 @@ def test_raw_parse_simple_infer_fqn_children():
def test_raw_parse_complex():
raw = '2+config.arg.secondarg:argument_value+4'
raw = "2+config.arg.secondarg:argument_value+4"
result = SelectionCriteria.from_single_spec(raw)
assert result.raw == raw
assert result.method == MethodName.Config
assert result.method_arguments == ['arg', 'secondarg']
assert result.value == 'argument_value'
assert result.method_arguments == ["arg", "secondarg"]
assert result.value == "argument_value"
assert not result.childrens_parents
assert result.children
assert result.parents
@@ -98,11 +98,11 @@ def test_raw_parse_complex():
def test_raw_parse_weird():
# you can have an empty method name (defaults to FQN/path) and you can have
# an empty value, so you can also have this...
result = SelectionCriteria.from_single_spec('')
assert result.raw == ''
result = SelectionCriteria.from_single_spec("")
assert result.raw == ""
assert result.method == MethodName.FQN
assert result.method_arguments == []
assert result.value == ''
assert result.value == ""
assert not result.childrens_parents
assert not result.children
assert not result.parents
@@ -112,40 +112,48 @@ def test_raw_parse_weird():
def test_raw_parse_invalid():
with pytest.raises(DbtRuntimeError):
SelectionCriteria.from_single_spec('invalid_method:something')
SelectionCriteria.from_single_spec("invalid_method:something")
with pytest.raises(DbtRuntimeError):
SelectionCriteria.from_single_spec('@foo+')
SelectionCriteria.from_single_spec("@foo+")
def test_intersection():
fqn_a = SelectionCriteria.from_single_spec('fqn:model_a')
fqn_b = SelectionCriteria.from_single_spec('fqn:model_b')
fqn_a = SelectionCriteria.from_single_spec("fqn:model_a")
fqn_b = SelectionCriteria.from_single_spec("fqn:model_b")
intersection = SelectionIntersection(components=[fqn_a, fqn_b])
assert list(intersection) == [fqn_a, fqn_b]
combined = intersection.combine_selections([{'model_a', 'model_b', 'model_c'}, {'model_c', 'model_d'}])
assert combined == {'model_c'}
combined = intersection.combine_selections(
[{"model_a", "model_b", "model_c"}, {"model_c", "model_d"}]
)
assert combined == {"model_c"}
def test_difference():
fqn_a = SelectionCriteria.from_single_spec('fqn:model_a')
fqn_b = SelectionCriteria.from_single_spec('fqn:model_b')
fqn_a = SelectionCriteria.from_single_spec("fqn:model_a")
fqn_b = SelectionCriteria.from_single_spec("fqn:model_b")
difference = SelectionDifference(components=[fqn_a, fqn_b])
assert list(difference) == [fqn_a, fqn_b]
combined = difference.combine_selections([{'model_a', 'model_b', 'model_c'}, {'model_c', 'model_d'}])
assert combined == {'model_a', 'model_b'}
combined = difference.combine_selections(
[{"model_a", "model_b", "model_c"}, {"model_c", "model_d"}]
)
assert combined == {"model_a", "model_b"}
fqn_c = SelectionCriteria.from_single_spec('fqn:model_c')
fqn_c = SelectionCriteria.from_single_spec("fqn:model_c")
difference = SelectionDifference(components=[fqn_a, fqn_b, fqn_c])
assert list(difference) == [fqn_a, fqn_b, fqn_c]
combined = difference.combine_selections([{'model_a', 'model_b', 'model_c'}, {'model_c', 'model_d'}, {'model_a'}])
assert combined == {'model_b'}
combined = difference.combine_selections(
[{"model_a", "model_b", "model_c"}, {"model_c", "model_d"}, {"model_a"}]
)
assert combined == {"model_b"}
def test_union():
fqn_a = SelectionCriteria.from_single_spec('fqn:model_a')
fqn_b = SelectionCriteria.from_single_spec('fqn:model_b')
fqn_c = SelectionCriteria.from_single_spec('fqn:model_c')
fqn_a = SelectionCriteria.from_single_spec("fqn:model_a")
fqn_b = SelectionCriteria.from_single_spec("fqn:model_b")
fqn_c = SelectionCriteria.from_single_spec("fqn:model_c")
difference = SelectionUnion(components=[fqn_a, fqn_b, fqn_c])
combined = difference.combine_selections([{'model_a', 'model_b'}, {'model_b', 'model_c'}, {'model_d'}])
assert combined == {'model_a', 'model_b', 'model_c', 'model_d'}
combined = difference.combine_selections(
[{"model_a", "model_b"}, {"model_b", "model_c"}, {"model_d"}]
)
assert combined == {"model_a", "model_b", "model_c", "model_d"}

View File

@@ -22,20 +22,20 @@ def raises(value):
def expected_id(arg):
if isinstance(arg, list):
return '_'.join(arg)
return "_".join(arg)
jinja_tests = [
# strings
(
'''foo: bar''',
returns('bar'),
returns('bar'),
"""foo: bar""",
returns("bar"),
returns("bar"),
),
(
'''foo: "bar"''',
returns('bar'),
returns('bar'),
returns("bar"),
returns("bar"),
),
(
'''foo: "'bar'"''',
@@ -49,34 +49,34 @@ jinja_tests = [
),
(
'''foo: "{{ 'bar' | as_text }}"''',
returns('bar'),
returns('bar'),
returns("bar"),
returns("bar"),
),
(
'''foo: "{{ 'bar' | as_bool }}"''',
returns('bar'),
returns("bar"),
raises(JinjaRenderingError),
),
(
'''foo: "{{ 'bar' | as_number }}"''',
returns('bar'),
returns("bar"),
raises(JinjaRenderingError),
),
(
'''foo: "{{ 'bar' | as_native }}"''',
returns('bar'),
returns('bar'),
returns("bar"),
returns("bar"),
),
# ints
(
'''foo: 1''',
returns('1'),
returns('1'),
"""foo: 1""",
returns("1"),
returns("1"),
),
(
'''foo: "1"''',
returns('1'),
returns('1'),
returns("1"),
returns("1"),
),
(
'''foo: "'1'"''',
@@ -90,13 +90,13 @@ jinja_tests = [
),
(
'''foo: "{{ 1 }}"''',
returns('1'),
returns('1'),
returns("1"),
returns("1"),
),
(
'''foo: "{{ '1' }}"''',
returns('1'),
returns('1'),
returns("1"),
returns("1"),
),
(
'''foo: "'{{ 1 }}'"''',
@@ -110,42 +110,42 @@ jinja_tests = [
),
(
'''foo: "{{ 1 | as_text }}"''',
returns('1'),
returns('1'),
returns("1"),
returns("1"),
),
(
'''foo: "{{ 1 | as_bool }}"''',
returns('1'),
returns("1"),
raises(JinjaRenderingError),
),
(
'''foo: "{{ 1 | as_number }}"''',
returns('1'),
returns("1"),
returns(1),
),
(
'''foo: "{{ 1 | as_native }}"''',
returns('1'),
returns("1"),
returns(1),
),
(
'''foo: "{{ '1' | as_text }}"''',
returns('1'),
returns('1'),
returns("1"),
returns("1"),
),
(
'''foo: "{{ '1' | as_bool }}"''',
returns('1'),
returns("1"),
raises(JinjaRenderingError),
),
(
'''foo: "{{ '1' | as_number }}"''',
returns('1'),
returns("1"),
returns(1),
),
(
'''foo: "{{ '1' | as_native }}"''',
returns('1'),
returns("1"),
returns(1),
),
# booleans.
@@ -155,27 +155,27 @@ jinja_tests = [
# unquoted true
(
'''foo: "{{ True }}"''',
returns('True'),
returns('True'),
returns("True"),
returns("True"),
),
(
'''foo: "{{ True | as_text }}"''',
returns('True'),
returns('True'),
returns("True"),
returns("True"),
),
(
'''foo: "{{ True | as_bool }}"''',
returns('True'),
returns("True"),
returns(True),
),
(
'''foo: "{{ True | as_number }}"''',
returns('True'),
returns("True"),
raises(JinjaRenderingError),
),
(
'''foo: "{{ True | as_native }}"''',
returns('True'),
returns("True"),
returns(True),
),
# unquoted true
@@ -238,8 +238,8 @@ jinja_tests = [
# unquoted True
(
'''foo: "{{ True }}"''',
returns('True'),
returns('True'),
returns("True"),
returns("True"),
),
(
'''foo: "{{ True | as_text }}"''',
@@ -270,7 +270,7 @@ jinja_tests = [
# 'True' -> string 'True' -> text -> str('True') -> 'True'
(
'''foo: "{{ 'True' | as_bool }}"''',
returns('True'),
returns("True"),
returns(True),
),
# quoted 'True' outside rendering
@@ -286,41 +286,41 @@ jinja_tests = [
),
# yaml turns 'yes' into a boolean true
(
'''foo: yes''',
returns('True'),
returns('True'),
"""foo: yes""",
returns("True"),
returns("True"),
),
(
'''foo: "yes"''',
returns('yes'),
returns('yes'),
returns("yes"),
returns("yes"),
),
# concatenation
(
'''foo: "{{ (a_int + 100) | as_native }}"''',
returns('200'),
returns("200"),
returns(200),
),
(
'''foo: "{{ (a_str ~ 100) | as_native }}"''',
returns('100100'),
returns("100100"),
returns(100100),
),
(
'''foo: "{{( a_int ~ 100) | as_native }}"''',
returns('100100'),
returns("100100"),
returns(100100),
),
# multiple nodes -> always str
(
'''foo: "{{ a_str | as_native }}{{ a_str | as_native }}"''',
returns('100100'),
returns('100100'),
returns("100100"),
returns("100100"),
),
(
'''foo: "{{ a_int | as_native }}{{ a_int | as_native }}"''',
returns('100100'),
returns('100100'),
returns("100100"),
returns("100100"),
),
(
'''foo: "'{{ a_int | as_native }}{{ a_int | as_native }}'"''',
@@ -328,57 +328,49 @@ jinja_tests = [
returns("'100100'"),
),
(
'''foo:''',
returns('None'),
returns('None'),
"""foo:""",
returns("None"),
returns("None"),
),
(
'''foo: null''',
returns('None'),
returns('None'),
"""foo: null""",
returns("None"),
returns("None"),
),
(
'''foo: ""''',
returns(''),
returns(''),
returns(""),
returns(""),
),
(
'''foo: "{{ '' | as_native }}"''',
returns(''),
returns(''),
returns(""),
returns(""),
),
# very annoying, but jinja 'none' is yaml 'null'.
(
'''foo: "{{ none | as_native }}"''',
returns('None'),
returns("None"),
returns(None),
),
# make sure we don't include comments in the output (see #2707)
(
'''foo: "{# #}hello"''',
returns('hello'),
returns('hello'),
returns("hello"),
returns("hello"),
),
(
'''foo: "{% if false %}{% endif %}hello"''',
returns('hello'),
returns('hello'),
returns("hello"),
returns("hello"),
),
]
@pytest.mark.parametrize(
'value,text_expectation,native_expectation',
jinja_tests,
ids=expected_id
)
@pytest.mark.parametrize("value,text_expectation,native_expectation", jinja_tests, ids=expected_id)
def test_jinja_rendering(value, text_expectation, native_expectation):
foo_value = yaml.safe_load(value)['foo']
ctx = {
'a_str': '100',
'a_int': 100,
'b_str': 'hello'
}
foo_value = yaml.safe_load(value)["foo"]
ctx = {"a_str": "100", "a_int": 100, "b_str": "hello"}
with text_expectation as text_result:
assert text_result == get_rendered(foo_value, ctx, native=False)
@@ -388,192 +380,221 @@ def test_jinja_rendering(value, text_expectation, native_expectation):
class TestJinja(unittest.TestCase):
def test_do(self):
s = '{% set my_dict = {} %}\n{% do my_dict.update(a=1) %}'
s = "{% set my_dict = {} %}\n{% do my_dict.update(a=1) %}"
template = get_template(s, {})
mod = template.make_module()
self.assertEqual(mod.my_dict, {'a': 1})
self.assertEqual(mod.my_dict, {"a": 1})
def test_regular_render(self):
s = '{{ "some_value" | as_native }}'
value = get_rendered(s, {}, native=False)
assert value == 'some_value'
s = '{{ 1991 | as_native }}'
assert value == "some_value"
s = "{{ 1991 | as_native }}"
value = get_rendered(s, {}, native=False)
assert value == '1991'
assert value == "1991"
s = '{{ "some_value" | as_text }}'
value = get_rendered(s, {}, native=False)
assert value == 'some_value'
s = '{{ 1991 | as_text }}'
assert value == "some_value"
s = "{{ 1991 | as_text }}"
value = get_rendered(s, {}, native=False)
assert value == '1991'
assert value == "1991"
def test_native_render(self):
s = '{{ "some_value" | as_native }}'
value = get_rendered(s, {}, native=True)
assert value == 'some_value'
s = '{{ 1991 | as_native }}'
assert value == "some_value"
s = "{{ 1991 | as_native }}"
value = get_rendered(s, {}, native=True)
assert value == 1991
s = '{{ "some_value" | as_text }}'
value = get_rendered(s, {}, native=True)
assert value == 'some_value'
s = '{{ 1991 | as_text }}'
assert value == "some_value"
s = "{{ 1991 | as_text }}"
value = get_rendered(s, {}, native=True)
assert value == '1991'
assert value == "1991"
class TestBlockLexer(unittest.TestCase):
def test_basic(self):
body = '{{ config(foo="bar") }}\r\nselect * from this.that\r\n'
block_data = ' \n\r\t{%- mytype foo %}'+body+'{%endmytype -%}'
blocks = extract_toplevel_blocks(block_data, allowed_blocks={'mytype'}, collect_raw_data=False)
block_data = " \n\r\t{%- mytype foo %}" + body + "{%endmytype -%}"
blocks = extract_toplevel_blocks(
block_data, allowed_blocks={"mytype"}, collect_raw_data=False
)
self.assertEqual(len(blocks), 1)
self.assertEqual(blocks[0].block_type_name, 'mytype')
self.assertEqual(blocks[0].block_name, 'foo')
self.assertEqual(blocks[0].block_type_name, "mytype")
self.assertEqual(blocks[0].block_name, "foo")
self.assertEqual(blocks[0].contents, body)
self.assertEqual(blocks[0].full_block, block_data)
def test_multiple(self):
body_one = '{{ config(foo="bar") }}\r\nselect * from this.that\r\n'
body_two = (
'{{ config(bar=1)}}\r\nselect * from {% if foo %} thing '
'{% else %} other_thing {% endif %}'
"{{ config(bar=1)}}\r\nselect * from {% if foo %} thing "
"{% else %} other_thing {% endif %}"
)
block_data = (
' {% mytype foo %}' + body_one + '{% endmytype %}' +
'\r\n{% othertype bar %}' + body_two + '{% endothertype %}'
" {% mytype foo %}"
+ body_one
+ "{% endmytype %}"
+ "\r\n{% othertype bar %}"
+ body_two
+ "{% endothertype %}"
)
blocks = extract_toplevel_blocks(
block_data, allowed_blocks={"mytype", "othertype"}, collect_raw_data=False
)
blocks = extract_toplevel_blocks(block_data, allowed_blocks={'mytype', 'othertype'}, collect_raw_data=False)
self.assertEqual(len(blocks), 2)
def test_comments(self):
body = '{{ config(foo="bar") }}\r\nselect * from this.that\r\n'
comment = '{# my comment #}'
block_data = ' \n\r\t{%- mytype foo %}'+body+'{%endmytype -%}'
blocks = extract_toplevel_blocks(comment+block_data, allowed_blocks={'mytype'}, collect_raw_data=False)
comment = "{# my comment #}"
block_data = " \n\r\t{%- mytype foo %}" + body + "{%endmytype -%}"
blocks = extract_toplevel_blocks(
comment + block_data, allowed_blocks={"mytype"}, collect_raw_data=False
)
self.assertEqual(len(blocks), 1)
self.assertEqual(blocks[0].block_type_name, 'mytype')
self.assertEqual(blocks[0].block_name, 'foo')
self.assertEqual(blocks[0].block_type_name, "mytype")
self.assertEqual(blocks[0].block_name, "foo")
self.assertEqual(blocks[0].contents, body)
self.assertEqual(blocks[0].full_block, block_data)
def test_evil_comments(self):
body = '{{ config(foo="bar") }}\r\nselect * from this.that\r\n'
comment = '{# external comment {% othertype bar %} select * from thing.other_thing{% endothertype %} #}'
block_data = ' \n\r\t{%- mytype foo %}'+body+'{%endmytype -%}'
blocks = extract_toplevel_blocks(comment+block_data, allowed_blocks={'mytype'}, collect_raw_data=False)
comment = "{# external comment {% othertype bar %} select * from thing.other_thing{% endothertype %} #}"
block_data = " \n\r\t{%- mytype foo %}" + body + "{%endmytype -%}"
blocks = extract_toplevel_blocks(
comment + block_data, allowed_blocks={"mytype"}, collect_raw_data=False
)
self.assertEqual(len(blocks), 1)
self.assertEqual(blocks[0].block_type_name, 'mytype')
self.assertEqual(blocks[0].block_name, 'foo')
self.assertEqual(blocks[0].block_type_name, "mytype")
self.assertEqual(blocks[0].block_name, "foo")
self.assertEqual(blocks[0].contents, body)
self.assertEqual(blocks[0].full_block, block_data)
def test_nested_comments(self):
body = '{# my comment #} {{ config(foo="bar") }}\r\nselect * from {# my other comment embedding {% endmytype %} #} this.that\r\n'
block_data = ' \n\r\t{%- mytype foo %}'+body+'{% endmytype -%}'
comment = '{# external comment {% othertype bar %} select * from thing.other_thing{% endothertype %} #}'
blocks = extract_toplevel_blocks(comment+block_data, allowed_blocks={'mytype'}, collect_raw_data=False)
block_data = " \n\r\t{%- mytype foo %}" + body + "{% endmytype -%}"
comment = "{# external comment {% othertype bar %} select * from thing.other_thing{% endothertype %} #}"
blocks = extract_toplevel_blocks(
comment + block_data, allowed_blocks={"mytype"}, collect_raw_data=False
)
self.assertEqual(len(blocks), 1)
self.assertEqual(blocks[0].block_type_name, 'mytype')
self.assertEqual(blocks[0].block_name, 'foo')
self.assertEqual(blocks[0].block_type_name, "mytype")
self.assertEqual(blocks[0].block_name, "foo")
self.assertEqual(blocks[0].contents, body)
self.assertEqual(blocks[0].full_block, block_data)
def test_complex_file(self):
blocks = extract_toplevel_blocks(complex_snapshot_file, allowed_blocks={'mytype', 'myothertype'}, collect_raw_data=False)
blocks = extract_toplevel_blocks(
complex_snapshot_file, allowed_blocks={"mytype", "myothertype"}, collect_raw_data=False
)
self.assertEqual(len(blocks), 3)
self.assertEqual(blocks[0].block_type_name, 'mytype')
self.assertEqual(blocks[0].block_name, 'foo')
self.assertEqual(blocks[0].full_block, '{% mytype foo %} some stuff {% endmytype %}')
self.assertEqual(blocks[0].contents, ' some stuff ')
self.assertEqual(blocks[1].block_type_name, 'mytype')
self.assertEqual(blocks[1].block_name, 'bar')
self.assertEqual(blocks[0].block_type_name, "mytype")
self.assertEqual(blocks[0].block_name, "foo")
self.assertEqual(blocks[0].full_block, "{% mytype foo %} some stuff {% endmytype %}")
self.assertEqual(blocks[0].contents, " some stuff ")
self.assertEqual(blocks[1].block_type_name, "mytype")
self.assertEqual(blocks[1].block_name, "bar")
self.assertEqual(blocks[1].full_block, bar_block)
self.assertEqual(blocks[1].contents, bar_block[16:-15].rstrip())
self.assertEqual(blocks[2].block_type_name, 'myothertype')
self.assertEqual(blocks[2].block_name, 'x')
self.assertEqual(blocks[2].block_type_name, "myothertype")
self.assertEqual(blocks[2].block_name, "x")
self.assertEqual(blocks[2].full_block, x_block.strip())
self.assertEqual(blocks[2].contents, x_block[len('\n{% myothertype x %}'):-len('{% endmyothertype %}\n')])
self.assertEqual(
blocks[2].contents,
x_block[len("\n{% myothertype x %}") : -len("{% endmyothertype %}\n")],
)
def test_peaceful_macro_coexistence(self):
body = '{# my macro #} {% macro foo(a, b) %} do a thing {%- endmacro %} {# my model #} {% a b %} test {% enda %}'
blocks = extract_toplevel_blocks(body, allowed_blocks={'macro', 'a'}, collect_raw_data=True)
body = "{# my macro #} {% macro foo(a, b) %} do a thing {%- endmacro %} {# my model #} {% a b %} test {% enda %}"
blocks = extract_toplevel_blocks(
body, allowed_blocks={"macro", "a"}, collect_raw_data=True
)
self.assertEqual(len(blocks), 4)
self.assertEqual(blocks[0].full_block, '{# my macro #} ')
self.assertEqual(blocks[1].block_type_name, 'macro')
self.assertEqual(blocks[1].block_name, 'foo')
self.assertEqual(blocks[1].contents, ' do a thing')
self.assertEqual(blocks[2].full_block, ' {# my model #} ')
self.assertEqual(blocks[3].block_type_name, 'a')
self.assertEqual(blocks[3].block_name, 'b')
self.assertEqual(blocks[3].contents, ' test ')
self.assertEqual(blocks[0].full_block, "{# my macro #} ")
self.assertEqual(blocks[1].block_type_name, "macro")
self.assertEqual(blocks[1].block_name, "foo")
self.assertEqual(blocks[1].contents, " do a thing")
self.assertEqual(blocks[2].full_block, " {# my model #} ")
self.assertEqual(blocks[3].block_type_name, "a")
self.assertEqual(blocks[3].block_name, "b")
self.assertEqual(blocks[3].contents, " test ")
def test_macro_with_trailing_data(self):
body = '{# my macro #} {% macro foo(a, b) %} do a thing {%- endmacro %} {# my model #} {% a b %} test {% enda %} raw data so cool'
blocks = extract_toplevel_blocks(body, allowed_blocks={'macro', 'a'}, collect_raw_data=True)
body = "{# my macro #} {% macro foo(a, b) %} do a thing {%- endmacro %} {# my model #} {% a b %} test {% enda %} raw data so cool"
blocks = extract_toplevel_blocks(
body, allowed_blocks={"macro", "a"}, collect_raw_data=True
)
self.assertEqual(len(blocks), 5)
self.assertEqual(blocks[0].full_block, '{# my macro #} ')
self.assertEqual(blocks[1].block_type_name, 'macro')
self.assertEqual(blocks[1].block_name, 'foo')
self.assertEqual(blocks[1].contents, ' do a thing')
self.assertEqual(blocks[2].full_block, ' {# my model #} ')
self.assertEqual(blocks[3].block_type_name, 'a')
self.assertEqual(blocks[3].block_name, 'b')
self.assertEqual(blocks[3].contents, ' test ')
self.assertEqual(blocks[4].full_block, ' raw data so cool')
self.assertEqual(blocks[0].full_block, "{# my macro #} ")
self.assertEqual(blocks[1].block_type_name, "macro")
self.assertEqual(blocks[1].block_name, "foo")
self.assertEqual(blocks[1].contents, " do a thing")
self.assertEqual(blocks[2].full_block, " {# my model #} ")
self.assertEqual(blocks[3].block_type_name, "a")
self.assertEqual(blocks[3].block_name, "b")
self.assertEqual(blocks[3].contents, " test ")
self.assertEqual(blocks[4].full_block, " raw data so cool")
def test_macro_with_crazy_args(self):
body = '''{% macro foo(a, b=asdf("cool this is 'embedded'" * 3) + external_var, c)%}cool{# block comment with {% endmacro %} in it #} stuff here {% endmacro %}'''
blocks = extract_toplevel_blocks(body, allowed_blocks={'macro'}, collect_raw_data=False)
body = """{% macro foo(a, b=asdf("cool this is 'embedded'" * 3) + external_var, c)%}cool{# block comment with {% endmacro %} in it #} stuff here {% endmacro %}"""
blocks = extract_toplevel_blocks(body, allowed_blocks={"macro"}, collect_raw_data=False)
self.assertEqual(len(blocks), 1)
self.assertEqual(blocks[0].block_type_name, 'macro')
self.assertEqual(blocks[0].block_name, 'foo')
self.assertEqual(blocks[0].contents, 'cool{# block comment with {% endmacro %} in it #} stuff here ')
self.assertEqual(blocks[0].block_type_name, "macro")
self.assertEqual(blocks[0].block_name, "foo")
self.assertEqual(
blocks[0].contents, "cool{# block comment with {% endmacro %} in it #} stuff here "
)
def test_materialization_parse(self):
body = '{% materialization xxx, default %} ... {% endmaterialization %}'
blocks = extract_toplevel_blocks(body, allowed_blocks={'materialization'}, collect_raw_data=False)
body = "{% materialization xxx, default %} ... {% endmaterialization %}"
blocks = extract_toplevel_blocks(
body, allowed_blocks={"materialization"}, collect_raw_data=False
)
self.assertEqual(len(blocks), 1)
self.assertEqual(blocks[0].block_type_name, 'materialization')
self.assertEqual(blocks[0].block_name, 'xxx')
self.assertEqual(blocks[0].block_type_name, "materialization")
self.assertEqual(blocks[0].block_name, "xxx")
self.assertEqual(blocks[0].full_block, body)
body = '{% materialization xxx, adapter="other" %} ... {% endmaterialization %}'
blocks = extract_toplevel_blocks(body, allowed_blocks={'materialization'}, collect_raw_data=False)
blocks = extract_toplevel_blocks(
body, allowed_blocks={"materialization"}, collect_raw_data=False
)
self.assertEqual(len(blocks), 1)
self.assertEqual(blocks[0].block_type_name, 'materialization')
self.assertEqual(blocks[0].block_name, 'xxx')
self.assertEqual(blocks[0].block_type_name, "materialization")
self.assertEqual(blocks[0].block_name, "xxx")
self.assertEqual(blocks[0].full_block, body)
def test_nested_not_ok(self):
# we don't allow nesting same blocks
body = '{% myblock a %} {% myblock b %} {% endmyblock %} {% endmyblock %}'
body = "{% myblock a %} {% myblock b %} {% endmyblock %} {% endmyblock %}"
with self.assertRaises(CompilationError):
extract_toplevel_blocks(body, allowed_blocks={'myblock'})
extract_toplevel_blocks(body, allowed_blocks={"myblock"})
def test_incomplete_block_failure(self):
fullbody = '{% myblock foo %} {% endmyblock %}'
for length in range(len('{% myblock foo %}'), len(fullbody)-1):
fullbody = "{% myblock foo %} {% endmyblock %}"
for length in range(len("{% myblock foo %}"), len(fullbody) - 1):
body = fullbody[:length]
with self.assertRaises(CompilationError):
extract_toplevel_blocks(body, allowed_blocks={'myblock'})
extract_toplevel_blocks(body, allowed_blocks={"myblock"})
def test_wrong_end_failure(self):
body = '{% myblock foo %} {% endotherblock %}'
body = "{% myblock foo %} {% endotherblock %}"
with self.assertRaises(CompilationError):
extract_toplevel_blocks(body, allowed_blocks={'myblock', 'otherblock'})
extract_toplevel_blocks(body, allowed_blocks={"myblock", "otherblock"})
def test_comment_no_end_failure(self):
body = '{# '
body = "{# "
with self.assertRaises(CompilationError):
extract_toplevel_blocks(body)
def test_comment_only(self):
body = '{# myblock #}'
body = "{# myblock #}"
blocks = extract_toplevel_blocks(body)
self.assertEqual(len(blocks), 1)
blocks = extract_toplevel_blocks(body, collect_raw_data=False)
@@ -582,164 +603,184 @@ class TestBlockLexer(unittest.TestCase):
def test_comment_block_self_closing(self):
# test the case where a comment start looks a lot like it closes itself
# (but it doesn't in jinja!)
body = '{#} {% myblock foo %} {#}'
body = "{#} {% myblock foo %} {#}"
blocks = extract_toplevel_blocks(body, collect_raw_data=False)
self.assertEqual(len(blocks), 0)
def test_embedded_self_closing_comment_block(self):
body = '{% myblock foo %} {#}{% endmyblock %} {#}{% endmyblock %}'
blocks = extract_toplevel_blocks(body, allowed_blocks={'myblock'}, collect_raw_data=False)
body = "{% myblock foo %} {#}{% endmyblock %} {#}{% endmyblock %}"
blocks = extract_toplevel_blocks(body, allowed_blocks={"myblock"}, collect_raw_data=False)
self.assertEqual(len(blocks), 1)
self.assertEqual(blocks[0].full_block, body)
self.assertEqual(blocks[0].contents, ' {#}{% endmyblock %} {#}')
self.assertEqual(blocks[0].contents, " {#}{% endmyblock %} {#}")
def test_set_statement(self):
body = '{% set x = 1 %}{% myblock foo %}hi{% endmyblock %}'
blocks = extract_toplevel_blocks(body, allowed_blocks={'myblock'}, collect_raw_data=False)
body = "{% set x = 1 %}{% myblock foo %}hi{% endmyblock %}"
blocks = extract_toplevel_blocks(body, allowed_blocks={"myblock"}, collect_raw_data=False)
self.assertEqual(len(blocks), 1)
self.assertEqual(blocks[0].full_block, '{% myblock foo %}hi{% endmyblock %}')
self.assertEqual(blocks[0].full_block, "{% myblock foo %}hi{% endmyblock %}")
def test_set_block(self):
body = '{% set x %}1{% endset %}{% myblock foo %}hi{% endmyblock %}'
blocks = extract_toplevel_blocks(body, allowed_blocks={'myblock'}, collect_raw_data=False)
body = "{% set x %}1{% endset %}{% myblock foo %}hi{% endmyblock %}"
blocks = extract_toplevel_blocks(body, allowed_blocks={"myblock"}, collect_raw_data=False)
self.assertEqual(len(blocks), 1)
self.assertEqual(blocks[0].full_block, '{% myblock foo %}hi{% endmyblock %}')
self.assertEqual(blocks[0].full_block, "{% myblock foo %}hi{% endmyblock %}")
def test_crazy_set_statement(self):
body = '{% set x = (thing("{% myblock foo %}")) %}{% otherblock bar %}x{% endotherblock %}{% set y = otherthing("{% myblock foo %}") %}'
blocks = extract_toplevel_blocks(body, allowed_blocks={'otherblock'}, collect_raw_data=False)
blocks = extract_toplevel_blocks(
body, allowed_blocks={"otherblock"}, collect_raw_data=False
)
self.assertEqual(len(blocks), 1)
self.assertEqual(blocks[0].full_block, '{% otherblock bar %}x{% endotherblock %}')
self.assertEqual(blocks[0].block_type_name, 'otherblock')
self.assertEqual(blocks[0].full_block, "{% otherblock bar %}x{% endotherblock %}")
self.assertEqual(blocks[0].block_type_name, "otherblock")
def test_do_statement(self):
body = '{% do thing.update() %}{% myblock foo %}hi{% endmyblock %}'
blocks = extract_toplevel_blocks(body, allowed_blocks={'myblock'}, collect_raw_data=False)
body = "{% do thing.update() %}{% myblock foo %}hi{% endmyblock %}"
blocks = extract_toplevel_blocks(body, allowed_blocks={"myblock"}, collect_raw_data=False)
self.assertEqual(len(blocks), 1)
self.assertEqual(blocks[0].full_block, '{% myblock foo %}hi{% endmyblock %}')
self.assertEqual(blocks[0].full_block, "{% myblock foo %}hi{% endmyblock %}")
def test_deceptive_do_statement(self):
body = '{% do thing %}{% myblock foo %}hi{% endmyblock %}'
blocks = extract_toplevel_blocks(body, allowed_blocks={'myblock'}, collect_raw_data=False)
body = "{% do thing %}{% myblock foo %}hi{% endmyblock %}"
blocks = extract_toplevel_blocks(body, allowed_blocks={"myblock"}, collect_raw_data=False)
self.assertEqual(len(blocks), 1)
self.assertEqual(blocks[0].full_block, '{% myblock foo %}hi{% endmyblock %}')
self.assertEqual(blocks[0].full_block, "{% myblock foo %}hi{% endmyblock %}")
def test_do_block(self):
body = '{% do %}thing.update(){% enddo %}{% myblock foo %}hi{% endmyblock %}'
blocks = extract_toplevel_blocks(body, allowed_blocks={'do', 'myblock'}, collect_raw_data=False)
body = "{% do %}thing.update(){% enddo %}{% myblock foo %}hi{% endmyblock %}"
blocks = extract_toplevel_blocks(
body, allowed_blocks={"do", "myblock"}, collect_raw_data=False
)
self.assertEqual(len(blocks), 2)
self.assertEqual(blocks[0].contents, 'thing.update()')
self.assertEqual(blocks[0].block_type_name, 'do')
self.assertEqual(blocks[1].full_block, '{% myblock foo %}hi{% endmyblock %}')
self.assertEqual(blocks[0].contents, "thing.update()")
self.assertEqual(blocks[0].block_type_name, "do")
self.assertEqual(blocks[1].full_block, "{% myblock foo %}hi{% endmyblock %}")
def test_crazy_do_statement(self):
body = '{% do (thing("{% myblock foo %}")) %}{% otherblock bar %}x{% endotherblock %}{% do otherthing("{% myblock foo %}") %}{% myblock x %}hi{% endmyblock %}'
blocks = extract_toplevel_blocks(body, allowed_blocks={'myblock', 'otherblock'}, collect_raw_data=False)
blocks = extract_toplevel_blocks(
body, allowed_blocks={"myblock", "otherblock"}, collect_raw_data=False
)
self.assertEqual(len(blocks), 2)
self.assertEqual(blocks[0].full_block, '{% otherblock bar %}x{% endotherblock %}')
self.assertEqual(blocks[0].block_type_name, 'otherblock')
self.assertEqual(blocks[1].full_block, '{% myblock x %}hi{% endmyblock %}')
self.assertEqual(blocks[1].block_type_name, 'myblock')
self.assertEqual(blocks[0].full_block, "{% otherblock bar %}x{% endotherblock %}")
self.assertEqual(blocks[0].block_type_name, "otherblock")
self.assertEqual(blocks[1].full_block, "{% myblock x %}hi{% endmyblock %}")
self.assertEqual(blocks[1].block_type_name, "myblock")
def test_awful_jinja(self):
blocks = extract_toplevel_blocks(
if_you_do_this_you_are_awful,
allowed_blocks={'snapshot', 'materialization'},
collect_raw_data=False
if_you_do_this_you_are_awful,
allowed_blocks={"snapshot", "materialization"},
collect_raw_data=False,
)
self.assertEqual(len(blocks), 2)
self.assertEqual(len([b for b in blocks if b.block_type_name == '__dbt__data']), 0)
self.assertEqual(blocks[0].block_type_name, 'snapshot')
self.assertEqual(blocks[0].contents, '\n '.join([
'''{% set x = ("{% endsnapshot %}" + (40 * '%})')) %}''',
'{# {% endsnapshot %} #}',
'{% embedded %}',
' some block data right here',
'{% endembedded %}'
]))
self.assertEqual(blocks[1].block_type_name, 'materialization')
self.assertEqual(blocks[1].contents, '\nhi\n')
self.assertEqual(len([b for b in blocks if b.block_type_name == "__dbt__data"]), 0)
self.assertEqual(blocks[0].block_type_name, "snapshot")
self.assertEqual(
blocks[0].contents,
"\n ".join(
[
"""{% set x = ("{% endsnapshot %}" + (40 * '%})')) %}""",
"{# {% endsnapshot %} #}",
"{% embedded %}",
" some block data right here",
"{% endembedded %}",
]
),
)
self.assertEqual(blocks[1].block_type_name, "materialization")
self.assertEqual(blocks[1].contents, "\nhi\n")
def test_quoted_endblock_within_block(self):
body = '{% myblock something -%} {% set x = ("{% endmyblock %}") %} {% endmyblock %}'
blocks = extract_toplevel_blocks(body, allowed_blocks={'myblock'}, collect_raw_data=False)
blocks = extract_toplevel_blocks(body, allowed_blocks={"myblock"}, collect_raw_data=False)
self.assertEqual(len(blocks), 1)
self.assertEqual(blocks[0].block_type_name, 'myblock')
self.assertEqual(blocks[0].block_type_name, "myblock")
self.assertEqual(blocks[0].contents, '{% set x = ("{% endmyblock %}") %} ')
def test_docs_block(self):
body = '{% docs __my_doc__ %} asdf {# nope {% enddocs %}} #} {% enddocs %} {% docs __my_other_doc__ %} asdf "{% enddocs %}'
blocks = extract_toplevel_blocks(body, allowed_blocks={'docs'}, collect_raw_data=False)
blocks = extract_toplevel_blocks(body, allowed_blocks={"docs"}, collect_raw_data=False)
self.assertEqual(len(blocks), 2)
self.assertEqual(blocks[0].block_type_name, 'docs')
self.assertEqual(blocks[0].contents, ' asdf {# nope {% enddocs %}} #} ')
self.assertEqual(blocks[0].block_name, '__my_doc__')
self.assertEqual(blocks[1].block_type_name, 'docs')
self.assertEqual(blocks[0].block_type_name, "docs")
self.assertEqual(blocks[0].contents, " asdf {# nope {% enddocs %}} #} ")
self.assertEqual(blocks[0].block_name, "__my_doc__")
self.assertEqual(blocks[1].block_type_name, "docs")
self.assertEqual(blocks[1].contents, ' asdf "')
self.assertEqual(blocks[1].block_name, '__my_other_doc__')
self.assertEqual(blocks[1].block_name, "__my_other_doc__")
def test_docs_block_expr(self):
body = '{% docs more_doc %} asdf {{ "{% enddocs %}" ~ "}}" }}{% enddocs %}'
blocks = extract_toplevel_blocks(body, allowed_blocks={'docs'}, collect_raw_data=False)
blocks = extract_toplevel_blocks(body, allowed_blocks={"docs"}, collect_raw_data=False)
self.assertEqual(len(blocks), 1)
self.assertEqual(blocks[0].block_type_name, 'docs')
self.assertEqual(blocks[0].block_type_name, "docs")
self.assertEqual(blocks[0].contents, ' asdf {{ "{% enddocs %}" ~ "}}" }}')
self.assertEqual(blocks[0].block_name, 'more_doc')
self.assertEqual(blocks[0].block_name, "more_doc")
def test_unclosed_model_quotes(self):
# test case for https://github.com/dbt-labs/dbt-core/issues/1533
body = '{% model my_model -%} select * from "something"."something_else{% endmodel %}'
blocks = extract_toplevel_blocks(body, allowed_blocks={'model'}, collect_raw_data=False)
blocks = extract_toplevel_blocks(body, allowed_blocks={"model"}, collect_raw_data=False)
self.assertEqual(len(blocks), 1)
self.assertEqual(blocks[0].block_type_name, 'model')
self.assertEqual(blocks[0].block_type_name, "model")
self.assertEqual(blocks[0].contents, 'select * from "something"."something_else')
self.assertEqual(blocks[0].block_name, 'my_model')
self.assertEqual(blocks[0].block_name, "my_model")
def test_if(self):
# if you conditionally define your macros/models, don't
body = '{% if true %}{% macro my_macro() %} adsf {% endmacro %}{% endif %}'
body = "{% if true %}{% macro my_macro() %} adsf {% endmacro %}{% endif %}"
with self.assertRaises(CompilationError):
extract_toplevel_blocks(body)
def test_if_innocuous(self):
body = '{% if true %}{% something %}asdfasd{% endsomething %}{% endif %}'
body = "{% if true %}{% something %}asdfasd{% endsomething %}{% endif %}"
blocks = extract_toplevel_blocks(body)
self.assertEqual(len(blocks), 1)
self.assertEqual(blocks[0].full_block, body)
def test_for(self):
# no for-loops over macros.
body = '{% for x in range(10) %}{% macro my_macro() %} adsf {% endmacro %}{% endfor %}'
body = "{% for x in range(10) %}{% macro my_macro() %} adsf {% endmacro %}{% endfor %}"
with self.assertRaises(CompilationError):
extract_toplevel_blocks(body)
def test_for_innocuous(self):
# no for-loops over macros.
body = '{% for x in range(10) %}{% something my_something %} adsf {% endsomething %}{% endfor %}'
body = "{% for x in range(10) %}{% something my_something %} adsf {% endsomething %}{% endfor %}"
blocks = extract_toplevel_blocks(body)
self.assertEqual(len(blocks), 1)
self.assertEqual(blocks[0].full_block, body)
def test_endif(self):
body = '{% snapshot foo %}select * from thing{% endsnapshot%}{% endif %}'
body = "{% snapshot foo %}select * from thing{% endsnapshot%}{% endif %}"
with self.assertRaises(CompilationError) as err:
extract_toplevel_blocks(body)
self.assertIn('Got an unexpected control flow end tag, got endif but never saw a preceeding if (@ 1:53)', str(err.exception))
self.assertIn(
"Got an unexpected control flow end tag, got endif but never saw a preceeding if (@ 1:53)",
str(err.exception),
)
def test_if_endfor(self):
body = '{% if x %}...{% endfor %}{% endif %}'
body = "{% if x %}...{% endfor %}{% endif %}"
with self.assertRaises(CompilationError) as err:
extract_toplevel_blocks(body)
self.assertIn('Got an unexpected control flow end tag, got endfor but expected endif next (@ 1:13)', str(err.exception))
self.assertIn(
"Got an unexpected control flow end tag, got endfor but expected endif next (@ 1:13)",
str(err.exception),
)
def test_if_endfor_newlines(self):
body = '{% if x %}\n ...\n {% endfor %}\n{% endif %}'
body = "{% if x %}\n ...\n {% endfor %}\n{% endif %}"
with self.assertRaises(CompilationError) as err:
extract_toplevel_blocks(body)
self.assertIn('Got an unexpected control flow end tag, got endfor but expected endif next (@ 3:4)', str(err.exception))
self.assertIn(
"Got an unexpected control flow end tag, got endfor but expected endif next (@ 3:4)",
str(err.exception),
)
bar_block = '''{% mytype bar %}
bar_block = """{% mytype bar %}
{# a comment
that inside it has
{% mytype baz %}
@@ -751,24 +792,28 @@ bar_block = '''{% mytype bar %}
some other stuff
{%- endmytype%}'''
{%- endmytype%}"""
x_block = '''
x_block = """
{% myothertype x %}
before
{##}
and after
{% endmyothertype %}
'''
"""
complex_snapshot_file = '''
complex_snapshot_file = (
"""
{#some stuff {% mytype foo %} #}
{% mytype foo %} some stuff {% endmytype %}
'''+bar_block+x_block
"""
+ bar_block
+ x_block
)
if_you_do_this_you_are_awful = '''
if_you_do_this_you_are_awful = """
{#} here is a comment with a block inside {% block x %} asdf {% endblock %} {#}
{% do
set('foo="bar"')
@@ -790,4 +835,4 @@ if_you_do_this_you_are_awful = '''
{% materialization whatever, adapter='thing' %}
hi
{% endmaterialization %}
'''
"""

View File

@@ -12,52 +12,51 @@ class MockContext:
def __init__(self, node):
self.timing = []
self.node = mock.MagicMock()
self.node._event_status = {
"node_status": RunningStatus.Started
}
self.node._event_status = {"node_status": RunningStatus.Started}
self.node.is_ephemeral_model = True
def noop_ephemeral_result(*args):
return None
class TestSqlCompileRunnerNoIntrospection(unittest.TestCase):
def setUp(self):
self.manifest = {'mock':'manifest'}
self.adapter = Plugin.adapter({})
self.adapter.connection_for = mock.MagicMock()
self.ephemeral_result = lambda: None
inject_adapter(self.adapter, Plugin)
self.manifest = {"mock": "manifest"}
self.adapter = Plugin.adapter({})
self.adapter.connection_for = mock.MagicMock()
self.ephemeral_result = lambda: None
inject_adapter(self.adapter, Plugin)
def tearDown(self):
clear_plugin(Plugin)
@mock.patch('dbt.lib._get_operation_node')
@mock.patch('dbt.task.sql.GenericSqlRunner.compile')
@mock.patch('dbt.task.sql.GenericSqlRunner.ephemeral_result', noop_ephemeral_result)
@mock.patch('dbt.task.base.ExecutionContext', MockContext)
@mock.patch("dbt.lib._get_operation_node")
@mock.patch("dbt.task.sql.GenericSqlRunner.compile")
@mock.patch("dbt.task.sql.GenericSqlRunner.ephemeral_result", noop_ephemeral_result)
@mock.patch("dbt.task.base.ExecutionContext", MockContext)
def test__compile_and_execute__with_connection(self, mock_compile, mock_get_node):
"""
By default, env var for allowing introspection is true, and calling this
method should defer to the parent method.
"""
mock_get_node.return_value = ({}, None, self.adapter)
compile_sql(self.manifest, 'some/path', None)
compile_sql(self.manifest, "some/path", None)
mock_compile.assert_called_once_with(self.manifest)
self.adapter.connection_for.assert_called_once()
@mock.patch('dbt.lib._get_operation_node')
@mock.patch('dbt.task.sql.GenericSqlRunner.compile')
@mock.patch('dbt.task.sql.GenericSqlRunner.ephemeral_result', noop_ephemeral_result)
@mock.patch('dbt.task.base.ExecutionContext', MockContext)
@mock.patch("dbt.lib._get_operation_node")
@mock.patch("dbt.task.sql.GenericSqlRunner.compile")
@mock.patch("dbt.task.sql.GenericSqlRunner.ephemeral_result", noop_ephemeral_result)
@mock.patch("dbt.task.base.ExecutionContext", MockContext)
def test__compile_and_execute__without_connection(self, mock_compile, mock_get_node):
"""
Ensure that compile is called but does not attempt warehouse connection
"""
with mock.patch.dict(os.environ, {"__DBT_ALLOW_INTROSPECTION": "0"}):
mock_get_node.return_value = ({}, None, self.adapter)
compile_sql(self.manifest, 'some/path', None)
compile_sql(self.manifest, "some/path", None)
mock_compile.assert_called_once_with(self.manifest)
self.adapter.connection_for.assert_not_called()

View File

@@ -4,6 +4,7 @@ import unittest
from unittest import mock
from dbt import compilation
try:
from queue import Empty
except ImportError:
@@ -15,27 +16,29 @@ from dbt.graph.cli import parse_difference
def _mock_manifest(nodes):
config = mock.MagicMock(enabled=True)
manifest = mock.MagicMock(nodes={
n: mock.MagicMock(
unique_id=n,
package_name='pkg',
name=n,
empty=False,
config=config,
fqn=['pkg', n],
) for n in nodes
})
manifest = mock.MagicMock(
nodes={
n: mock.MagicMock(
unique_id=n,
package_name="pkg",
name=n,
empty=False,
config=config,
fqn=["pkg", n],
)
for n in nodes
}
)
manifest.expect.side_effect = lambda n: mock.MagicMock(unique_id=n)
return manifest
class LinkerTest(unittest.TestCase):
def setUp(self):
self.linker = compilation.Linker()
def test_linker_add_node(self):
expected_nodes = ['A', 'B', 'C']
expected_nodes = ["A", "B", "C"]
for node in expected_nodes:
self.linker.add_node(node)
@@ -46,11 +49,11 @@ class LinkerTest(unittest.TestCase):
self.assertEqual(len(actual_nodes), len(expected_nodes))
def test_linker_write_graph(self):
expected_nodes = ['A', 'B', 'C']
expected_nodes = ["A", "B", "C"]
for node in expected_nodes:
self.linker.add_node(node)
manifest = _mock_manifest('ABC')
manifest = _mock_manifest("ABC")
(fd, fname) = tempfile.mkstemp()
os.close(fd)
try:
@@ -67,7 +70,7 @@ class LinkerTest(unittest.TestCase):
graph = compilation.Graph(self.linker.graph)
selector = NodeSelector(graph, manifest)
# TODO: The "eager" string below needs to be replaced with programatic access
# to the default value for the indirect selection parameter in
# to the default value for the indirect selection parameter in
# dbt.cli.params.indirect_selection
#
# Doing that is actually a little tricky, so I'm punting it to a new ticket GH #6397
@@ -75,62 +78,62 @@ class LinkerTest(unittest.TestCase):
return selector.get_graph_queue(spec)
def test_linker_add_dependency(self):
actual_deps = [('A', 'B'), ('A', 'C'), ('B', 'C')]
actual_deps = [("A", "B"), ("A", "C"), ("B", "C")]
for (l, r) in actual_deps:
self.linker.dependency(l, r)
queue = self._get_graph_queue(_mock_manifest('ABC'))
queue = self._get_graph_queue(_mock_manifest("ABC"))
got = queue.get(block=False)
self.assertEqual(got.unique_id, 'C')
self.assertEqual(got.unique_id, "C")
with self.assertRaises(Empty):
queue.get(block=False)
self.assertFalse(queue.empty())
queue.mark_done('C')
queue.mark_done("C")
self.assertFalse(queue.empty())
got = queue.get(block=False)
self.assertEqual(got.unique_id, 'B')
self.assertEqual(got.unique_id, "B")
with self.assertRaises(Empty):
queue.get(block=False)
self.assertFalse(queue.empty())
queue.mark_done('B')
queue.mark_done("B")
self.assertFalse(queue.empty())
got = queue.get(block=False)
self.assertEqual(got.unique_id, 'A')
self.assertEqual(got.unique_id, "A")
with self.assertRaises(Empty):
queue.get(block=False)
self.assertTrue(queue.empty())
queue.mark_done('A')
queue.mark_done("A")
self.assert_would_join(queue)
self.assertTrue(queue.empty())
def test_linker_add_disjoint_dependencies(self):
actual_deps = [('A', 'B')]
additional_node = 'Z'
actual_deps = [("A", "B")]
additional_node = "Z"
for (l, r) in actual_deps:
self.linker.dependency(l, r)
self.linker.add_node(additional_node)
queue = self._get_graph_queue(_mock_manifest('ABCZ'))
queue = self._get_graph_queue(_mock_manifest("ABCZ"))
# the first one we get must be B, it has the longest dep chain
first = queue.get(block=False)
self.assertEqual(first.unique_id, 'B')
self.assertEqual(first.unique_id, "B")
self.assertFalse(queue.empty())
queue.mark_done('B')
queue.mark_done("B")
self.assertFalse(queue.empty())
second = queue.get(block=False)
self.assertIn(second.unique_id, {'A', 'Z'})
self.assertIn(second.unique_id, {"A", "Z"})
self.assertFalse(queue.empty())
queue.mark_done(second.unique_id)
self.assertFalse(queue.empty())
third = queue.get(block=False)
self.assertIn(third.unique_id, {'A', 'Z'})
self.assertIn(third.unique_id, {"A", "Z"})
with self.assertRaises(Empty):
queue.get(block=False)
self.assertNotEqual(second.unique_id, third.unique_id)
@@ -140,38 +143,38 @@ class LinkerTest(unittest.TestCase):
self.assertTrue(queue.empty())
def test_linker_dependencies_limited_to_some_nodes(self):
actual_deps = [('A', 'B'), ('B', 'C'), ('C', 'D')]
actual_deps = [("A", "B"), ("B", "C"), ("C", "D")]
for (l, r) in actual_deps:
self.linker.dependency(l, r)
queue = self._get_graph_queue(_mock_manifest('ABCD'), ['B'])
queue = self._get_graph_queue(_mock_manifest("ABCD"), ["B"])
got = queue.get(block=False)
self.assertEqual(got.unique_id, 'B')
self.assertEqual(got.unique_id, "B")
self.assertTrue(queue.empty())
queue.mark_done('B')
queue.mark_done("B")
self.assert_would_join(queue)
queue_2 = queue = self._get_graph_queue(_mock_manifest('ABCD'), ['A', 'B'])
queue_2 = queue = self._get_graph_queue(_mock_manifest("ABCD"), ["A", "B"])
got = queue_2.get(block=False)
self.assertEqual(got.unique_id, 'B')
self.assertEqual(got.unique_id, "B")
self.assertFalse(queue_2.empty())
with self.assertRaises(Empty):
queue_2.get(block=False)
queue_2.mark_done('B')
queue_2.mark_done("B")
self.assertFalse(queue_2.empty())
got = queue_2.get(block=False)
self.assertEqual(got.unique_id, 'A')
self.assertEqual(got.unique_id, "A")
self.assertTrue(queue_2.empty())
with self.assertRaises(Empty):
queue_2.get(block=False)
self.assertTrue(queue_2.empty())
queue_2.mark_done('A')
queue_2.mark_done("A")
self.assert_would_join(queue_2)
def test__find_cycles__cycles(self):
actual_deps = [('A', 'B'), ('B', 'C'), ('C', 'A')]
actual_deps = [("A", "B"), ("B", "C"), ("C", "A")]
for (l, r) in actual_deps:
self.linker.dependency(l, r)
@@ -179,9 +182,9 @@ class LinkerTest(unittest.TestCase):
self.assertIsNotNone(self.linker.find_cycles())
def test__find_cycles__no_cycles(self):
actual_deps = [('A', 'B'), ('B', 'C'), ('C', 'D')]
actual_deps = [("A", "B"), ("B", "C"), ("C", "D")]
for (l, r) in actual_deps:
self.linker.dependency(l, r)
self.assertIsNone(self.linker.find_cycles())
self.assertIsNone(self.linker.find_cycles())

View File

@@ -1,16 +1,10 @@
import os
import unittest
from unittest.mock import MagicMock, patch
from dataclasses import dataclass, field
from typing import Dict, Any
from dbt.clients.jinja_static import statically_extract_macro_calls
from dbt.context.base import generate_base_context
class MacroCalls(unittest.TestCase):
def setUp(self):
self.macro_strings = [
"{% macro parent_macro() %} {% do return(nested_macro()) %} {% endmacro %}",
@@ -29,18 +23,18 @@ class MacroCalls(unittest.TestCase):
]
self.possible_macro_calls = [
['nested_macro'],
['load_result'],
['get_snapshot_unique_id'],
['get_columns_in_query'],
['get_snapshot_unique_id'],
['current_timestamp_backcompat'],
['test_some_kind4', 'foo_utils4.test_some_kind4'],
['test_some_kind5', 'foo_utils5.test_some_kind5'],
["nested_macro"],
["load_result"],
["get_snapshot_unique_id"],
["get_columns_in_query"],
["get_snapshot_unique_id"],
["current_timestamp_backcompat"],
["test_some_kind4", "foo_utils4.test_some_kind4"],
["test_some_kind5", "foo_utils5.test_some_kind5"],
]
def test_macro_calls(self):
cli_vars = {'local_utils_dispatch_list': ['foo_utils4']}
cli_vars = {"local_utils_dispatch_list": ["foo_utils4"]}
ctx = generate_base_context(cli_vars)
index = 0
@@ -48,5 +42,3 @@ class MacroCalls(unittest.TestCase):
possible_macro_calls = statically_extract_macro_calls(macro_string, ctx)
self.assertEqual(self.possible_macro_calls[index], possible_macro_calls)
index += 1

View File

@@ -1,9 +1,7 @@
import unittest
from unittest import mock
from dbt.contracts.graph.nodes import (
Macro
)
from dbt.contracts.graph.nodes import Macro
from dbt.context.macro_resolver import MacroResolver
@@ -11,8 +9,8 @@ def mock_macro(name, package_name):
macro = mock.MagicMock(
__class__=Macro,
package_name=package_name,
resource_type='macro',
unique_id=f'macro.{package_name}.{name}',
resource_type="macro",
unique_id=f"macro.{package_name}.{name}",
)
# Mock(name=...) does not set the `name` attribute, this does.
macro.name = name
@@ -20,20 +18,19 @@ def mock_macro(name, package_name):
class TestMacroResolver(unittest.TestCase):
def test_resolver(self):
data = [
{'package_name': 'my_test', 'name': 'unique'},
{'package_name': 'my_test', 'name': 'macro_xx'},
{'package_name': 'one', 'name': 'unique'},
{'package_name': 'one', 'name': 'not_null'},
{'package_name': 'two', 'name': 'macro_a'},
{'package_name': 'two', 'name': 'macro_b'},
{"package_name": "my_test", "name": "unique"},
{"package_name": "my_test", "name": "macro_xx"},
{"package_name": "one", "name": "unique"},
{"package_name": "one", "name": "not_null"},
{"package_name": "two", "name": "macro_a"},
{"package_name": "two", "name": "macro_b"},
]
macros = {}
for mdata in data:
macro = mock_macro(mdata['name'], mdata['package_name'])
macro = mock_macro(mdata["name"], mdata["package_name"])
macros[macro.unique_id] = macro
resolver = MacroResolver(macros, 'my_test', ['one'])
assert(resolver)
self.assertEqual(resolver.get_macro_id('one', 'not_null'), 'macro.one.not_null')
resolver = MacroResolver(macros, "my_test", ["one"])
assert resolver
self.assertEqual(resolver.get_macro_id("one", "not_null"), "macro.one.not_null")

File diff suppressed because it is too large Load Diff

View File

@@ -1,4 +1,3 @@
import dbt.exceptions
import textwrap
import yaml
from collections import OrderedDict
@@ -14,9 +13,9 @@ def get_selector_dict(txt: str) -> OrderedDict:
class SelectorUnitTest(unittest.TestCase):
def test_compare_cli_non_cli(self):
dct = get_selector_dict('''\
dct = get_selector_dict(
"""\
selectors:
- name: nightly_diet_snowplow
description: "This uses more CLI-style syntax"
@@ -51,73 +50,82 @@ class SelectorUnitTest(unittest.TestCase):
value: incremental
- method: fqn
value: export_performance_timing
''')
"""
)
sel_dict = SelectorDict.parse_from_selectors_list(dct['selectors'])
assert(sel_dict)
with_strings = sel_dict['nightly_diet_snowplow']['definition']
no_strings = sel_dict['nightly_diet_snowplow_full']['definition']
sel_dict = SelectorDict.parse_from_selectors_list(dct["selectors"])
assert sel_dict
with_strings = sel_dict["nightly_diet_snowplow"]["definition"]
no_strings = sel_dict["nightly_diet_snowplow_full"]["definition"]
self.assertEqual(with_strings, no_strings)
def test_single_string_definition(self):
dct = get_selector_dict('''\
dct = get_selector_dict(
"""\
selectors:
- name: nightly_selector
definition:
'tag:nightly'
''')
"""
)
sel_dict = SelectorDict.parse_from_selectors_list(dct['selectors'])
assert(sel_dict)
expected = {'method': 'tag', 'value': 'nightly'}
definition = sel_dict['nightly_selector']['definition']
sel_dict = SelectorDict.parse_from_selectors_list(dct["selectors"])
assert sel_dict
expected = {"method": "tag", "value": "nightly"}
definition = sel_dict["nightly_selector"]["definition"]
self.assertEqual(expected, definition)
def test_single_key_value_definition(self):
dct = get_selector_dict('''\
dct = get_selector_dict(
"""\
selectors:
- name: nightly_selector
definition:
tag: nightly
''')
"""
)
sel_dict = SelectorDict.parse_from_selectors_list(dct['selectors'])
assert(sel_dict)
expected = {'method': 'tag', 'value': 'nightly'}
definition = sel_dict['nightly_selector']['definition']
sel_dict = SelectorDict.parse_from_selectors_list(dct["selectors"])
assert sel_dict
expected = {"method": "tag", "value": "nightly"}
definition = sel_dict["nightly_selector"]["definition"]
self.assertEqual(expected, definition)
def test_parent_definition(self):
dct = get_selector_dict('''\
dct = get_selector_dict(
"""\
selectors:
- name: kpi_nightly_selector
definition:
'+exposure:kpi_nightly'
''')
"""
)
sel_dict = SelectorDict.parse_from_selectors_list(dct['selectors'])
assert(sel_dict)
expected = {'method': 'exposure', 'value': 'kpi_nightly', 'parents': True}
definition = sel_dict['kpi_nightly_selector']['definition']
sel_dict = SelectorDict.parse_from_selectors_list(dct["selectors"])
assert sel_dict
expected = {"method": "exposure", "value": "kpi_nightly", "parents": True}
definition = sel_dict["kpi_nightly_selector"]["definition"]
self.assertEqual(expected, definition)
def test_plus_definition(self):
dct = get_selector_dict('''\
dct = get_selector_dict(
"""\
selectors:
- name: my_model_children_selector
definition:
'my_model+2'
''')
"""
)
sel_dict = SelectorDict.parse_from_selectors_list(dct['selectors'])
assert(sel_dict)
expected = {'method': 'fqn', 'value': 'my_model', 'children': True, 'children_depth': '2'}
definition = sel_dict['my_model_children_selector']['definition']
sel_dict = SelectorDict.parse_from_selectors_list(dct["selectors"])
assert sel_dict
expected = {"method": "fqn", "value": "my_model", "children": True, "children_depth": "2"}
definition = sel_dict["my_model_children_selector"]["definition"]
self.assertEqual(expected, definition)
def test_selector_definition(self):
dct = get_selector_dict('''\
dct = get_selector_dict(
"""\
selectors:
- name: default
definition:
@@ -129,16 +137,18 @@ class SelectorUnitTest(unittest.TestCase):
definition:
method: selector
value: default
''')
"""
)
sel_dict = SelectorDict.parse_from_selectors_list(dct['selectors'])
assert(sel_dict)
definition = sel_dict['default']['definition']
expected = sel_dict['inherited']['definition']
sel_dict = SelectorDict.parse_from_selectors_list(dct["selectors"])
assert sel_dict
definition = sel_dict["default"]["definition"]
expected = sel_dict["inherited"]["definition"]
self.assertEqual(expected, definition)
def test_selector_definition_with_exclusion(self):
dct = get_selector_dict('''\
dct = get_selector_dict(
"""\
selectors:
- name: default
definition:
@@ -162,26 +172,28 @@ class SelectorUnitTest(unittest.TestCase):
- tag: bar
- exclude:
- tag: bar
''')
"""
)
sel_dict = SelectorDict.parse_from_selectors_list((dct['selectors']))
assert(sel_dict)
definition = sel_dict['inherited']['definition']
expected = sel_dict['comparison']['definition']
sel_dict = SelectorDict.parse_from_selectors_list((dct["selectors"]))
assert sel_dict
definition = sel_dict["inherited"]["definition"]
expected = sel_dict["comparison"]["definition"]
self.assertEqual(expected, definition)
def test_missing_selector(self):
dct = get_selector_dict('''\
dct = get_selector_dict(
"""\
selectors:
- name: inherited
definition:
method: selector
value: default
''')
"""
)
with self.assertRaises(DbtSelectorsError) as err:
sel_dict = SelectorDict.parse_from_selectors_list((dct['selectors']))
SelectorDict.parse_from_selectors_list((dct["selectors"]))
self.assertEqual(
'Existing selector definition for default not found.',
str(err.exception.msg)
"Existing selector definition for default not found.", str(err.exception.msg)
)

View File

@@ -1,89 +1,91 @@
from dataclasses import dataclass, field
from dbt.dataclass_schema import dbtClassMixin
from typing import List, Dict
import pytest
from dbt.contracts.graph.model_config import MergeBehavior, ShowBehavior, CompareBehavior
@dataclass
class ThingWithMergeBehavior(dbtClassMixin):
default_behavior: int
appended: List[str] = field(metadata={'merge': MergeBehavior.Append})
updated: Dict[str, int] = field(metadata={'merge': MergeBehavior.Update})
clobbered: str = field(metadata={'merge': MergeBehavior.Clobber})
keysappended: Dict[str, int] = field(metadata={'merge': MergeBehavior.DictKeyAppend})
appended: List[str] = field(metadata={"merge": MergeBehavior.Append})
updated: Dict[str, int] = field(metadata={"merge": MergeBehavior.Update})
clobbered: str = field(metadata={"merge": MergeBehavior.Clobber})
keysappended: Dict[str, int] = field(metadata={"merge": MergeBehavior.DictKeyAppend})
def test_merge_behavior_meta():
existing = {'foo': 'bar'}
existing = {"foo": "bar"}
initial_existing = existing.copy()
assert set(MergeBehavior) == {MergeBehavior.Append, MergeBehavior.Update, MergeBehavior.Clobber, MergeBehavior.DictKeyAppend}
assert set(MergeBehavior) == {
MergeBehavior.Append,
MergeBehavior.Update,
MergeBehavior.Clobber,
MergeBehavior.DictKeyAppend,
}
for behavior in MergeBehavior:
assert behavior.meta() == {'merge': behavior}
assert behavior.meta(existing) == {'merge': behavior, 'foo': 'bar'}
assert behavior.meta() == {"merge": behavior}
assert behavior.meta(existing) == {"merge": behavior, "foo": "bar"}
assert existing == initial_existing
def test_merge_behavior_from_field():
fields = [f[0] for f in ThingWithMergeBehavior._get_fields()]
fields = {name: f for f, name in ThingWithMergeBehavior._get_fields()}
assert set(fields) == {'default_behavior', 'appended', 'updated', 'clobbered', 'keysappended'}
assert MergeBehavior.from_field(fields['default_behavior']) == MergeBehavior.Clobber
assert MergeBehavior.from_field(fields['appended']) == MergeBehavior.Append
assert MergeBehavior.from_field(fields['updated']) == MergeBehavior.Update
assert MergeBehavior.from_field(fields['clobbered']) == MergeBehavior.Clobber
assert MergeBehavior.from_field(fields['keysappended']) == MergeBehavior.DictKeyAppend
assert set(fields) == {"default_behavior", "appended", "updated", "clobbered", "keysappended"}
assert MergeBehavior.from_field(fields["default_behavior"]) == MergeBehavior.Clobber
assert MergeBehavior.from_field(fields["appended"]) == MergeBehavior.Append
assert MergeBehavior.from_field(fields["updated"]) == MergeBehavior.Update
assert MergeBehavior.from_field(fields["clobbered"]) == MergeBehavior.Clobber
assert MergeBehavior.from_field(fields["keysappended"]) == MergeBehavior.DictKeyAppend
@dataclass
class ThingWithShowBehavior(dbtClassMixin):
default_behavior: int
hidden: str = field(metadata={'show_hide': ShowBehavior.Hide})
shown: float = field(metadata={'show_hide': ShowBehavior.Show})
hidden: str = field(metadata={"show_hide": ShowBehavior.Hide})
shown: float = field(metadata={"show_hide": ShowBehavior.Show})
def test_show_behavior_meta():
existing = {'foo': 'bar'}
existing = {"foo": "bar"}
initial_existing = existing.copy()
assert set(ShowBehavior) == {ShowBehavior.Hide, ShowBehavior.Show}
for behavior in ShowBehavior:
assert behavior.meta() == {'show_hide': behavior}
assert behavior.meta(existing) == {'show_hide': behavior, 'foo': 'bar'}
assert behavior.meta() == {"show_hide": behavior}
assert behavior.meta(existing) == {"show_hide": behavior, "foo": "bar"}
assert existing == initial_existing
def test_show_behavior_from_field():
fields = [f[0] for f in ThingWithShowBehavior._get_fields()]
fields = {name: f for f, name in ThingWithShowBehavior._get_fields()}
assert set(fields) == {'default_behavior', 'hidden', 'shown'}
assert ShowBehavior.from_field(fields['default_behavior']) == ShowBehavior.Show
assert ShowBehavior.from_field(fields['hidden']) == ShowBehavior.Hide
assert ShowBehavior.from_field(fields['shown']) == ShowBehavior.Show
assert set(fields) == {"default_behavior", "hidden", "shown"}
assert ShowBehavior.from_field(fields["default_behavior"]) == ShowBehavior.Show
assert ShowBehavior.from_field(fields["hidden"]) == ShowBehavior.Hide
assert ShowBehavior.from_field(fields["shown"]) == ShowBehavior.Show
@dataclass
class ThingWithCompareBehavior(dbtClassMixin):
default_behavior: int
included: float = field(metadata={'compare': CompareBehavior.Include})
excluded: str = field(metadata={'compare': CompareBehavior.Exclude})
included: float = field(metadata={"compare": CompareBehavior.Include})
excluded: str = field(metadata={"compare": CompareBehavior.Exclude})
def test_compare_behavior_meta():
existing = {'foo': 'bar'}
existing = {"foo": "bar"}
initial_existing = existing.copy()
assert set(CompareBehavior) == {CompareBehavior.Include, CompareBehavior.Exclude}
for behavior in CompareBehavior:
assert behavior.meta() == {'compare': behavior}
assert behavior.meta(existing) == {'compare': behavior, 'foo': 'bar'}
assert behavior.meta() == {"compare": behavior}
assert behavior.meta(existing) == {"compare": behavior, "foo": "bar"}
assert existing == initial_existing
def test_compare_behavior_from_field():
fields = [f[0] for f in ThingWithCompareBehavior._get_fields()]
fields = {name: f for f, name in ThingWithCompareBehavior._get_fields()}
assert set(fields) == {'default_behavior', 'included', 'excluded'}
assert CompareBehavior.from_field(fields['default_behavior']) == CompareBehavior.Include
assert CompareBehavior.from_field(fields['included']) == CompareBehavior.Include
assert CompareBehavior.from_field(fields['excluded']) == CompareBehavior.Exclude
assert set(fields) == {"default_behavior", "included", "excluded"}
assert CompareBehavior.from_field(fields["default_behavior"]) == CompareBehavior.Include
assert CompareBehavior.from_field(fields["included"]) == CompareBehavior.Include
assert CompareBehavior.from_field(fields["excluded"]) == CompareBehavior.Exclude

View File

@@ -6,13 +6,12 @@ from .utils import config_from_parts_or_dicts, normalize
from dbt.contracts.files import SourceFile, FileHash, FilePath
from dbt.contracts.graph.manifest import Manifest, ManifestStateCheck
from dbt.parser.search import FileBlock
from dbt.parser import manifest
class MatchingHash(FileHash):
def __init__(self):
return super().__init__('', '')
return super().__init__("", "")
def __eq__(self, other):
return True
@@ -20,7 +19,7 @@ class MatchingHash(FileHash):
class MismatchedHash(FileHash):
def __init__(self):
return super().__init__('', '')
return super().__init__("", "")
def __eq__(self, other):
return False
@@ -29,53 +28,52 @@ class MismatchedHash(FileHash):
class TestLoader(unittest.TestCase):
def setUp(self):
profile_data = {
'target': 'test',
'quoting': {},
'outputs': {
'test': {
'type': 'postgres',
'host': 'localhost',
'schema': 'analytics',
'user': 'test',
'pass': 'test',
'dbname': 'test',
'port': 1,
"target": "test",
"quoting": {},
"outputs": {
"test": {
"type": "postgres",
"host": "localhost",
"schema": "analytics",
"user": "test",
"pass": "test",
"dbname": "test",
"port": 1,
}
}
},
}
root_project = {
'name': 'root',
'version': '0.1',
'profile': 'test',
'project-root': normalize('/usr/src/app'),
'config-version': 2,
"name": "root",
"version": "0.1",
"profile": "test",
"project-root": normalize("/usr/src/app"),
"config-version": 2,
}
self.root_project_config = config_from_parts_or_dicts(
project=root_project,
profile=profile_data,
cli_vars='{"test_schema_name": "foo"}'
project=root_project, profile=profile_data, cli_vars='{"test_schema_name": "foo"}'
)
self.parser = mock.MagicMock()
# Create the Manifest.state_check patcher
@patch('dbt.parser.manifest.ManifestLoader.build_manifest_state_check')
@patch("dbt.parser.manifest.ManifestLoader.build_manifest_state_check")
def _mock_state_check(self):
config = self.root_project
all_projects = self.all_projects
return ManifestStateCheck(
vars_hash=FileHash.from_contents('vars'),
vars_hash=FileHash.from_contents("vars"),
project_hashes={name: FileHash.from_contents(name) for name in all_projects},
profile_hash=FileHash.from_contents('profile'),
profile_hash=FileHash.from_contents("profile"),
)
self.load_state_check = patch('dbt.parser.manifest.ManifestLoader.build_manifest_state_check')
self.load_state_check = patch(
"dbt.parser.manifest.ManifestLoader.build_manifest_state_check"
)
self.mock_state_check = self.load_state_check.start()
self.mock_state_check.side_effect = _mock_state_check
self.loader = manifest.ManifestLoader(
self.root_project_config,
{'root': self.root_project_config}
self.root_project_config, {"root": self.root_project_config}
)
def _new_manifest(self):

File diff suppressed because it is too large Load Diff

View File

@@ -1,8 +1,6 @@
import unittest
from unittest import mock
import time
import dbt.exceptions
from dbt.parser.partial import PartialParsing
from dbt.contracts.graph.manifest import Manifest
from dbt.contracts.graph.nodes import ModelNode
@@ -12,51 +10,81 @@ from .utils import normalize
class TestPartialParsing(unittest.TestCase):
def setUp(self):
project_name = 'my_test'
project_root = '/users/root'
project_name = "my_test"
project_root = "/users/root"
sql_model_file = SourceFile(
path=FilePath(project_root=project_root, searched_path='models', relative_path='my_model.sql', modification_time=time.time()),
checksum=FileHash.from_contents('abcdef'),
path=FilePath(
project_root=project_root,
searched_path="models",
relative_path="my_model.sql",
modification_time=time.time(),
),
checksum=FileHash.from_contents("abcdef"),
project_name=project_name,
parse_file_type=ParseFileType.Model,
nodes=['model.my_test.my_model'],
nodes=["model.my_test.my_model"],
env_vars=[],
)
sql_model_file_untouched = SourceFile(
path=FilePath(project_root=project_root, searched_path='models', relative_path='my_model_untouched.sql', modification_time=time.time()),
checksum=FileHash.from_contents('abcdef'),
path=FilePath(
project_root=project_root,
searched_path="models",
relative_path="my_model_untouched.sql",
modification_time=time.time(),
),
checksum=FileHash.from_contents("abcdef"),
project_name=project_name,
parse_file_type=ParseFileType.Model,
nodes=['model.my_test.my_model_untouched'],
nodes=["model.my_test.my_model_untouched"],
env_vars=[],
)
python_model_file = SourceFile(
path=FilePath(project_root=project_root, searched_path='models', relative_path='python_model.py', modification_time=time.time()),
checksum=FileHash.from_contents('lalala'),
path=FilePath(
project_root=project_root,
searched_path="models",
relative_path="python_model.py",
modification_time=time.time(),
),
checksum=FileHash.from_contents("lalala"),
project_name=project_name,
parse_file_type=ParseFileType.Model,
nodes=['model.my_test.python_model'],
nodes=["model.my_test.python_model"],
env_vars=[],
)
python_model_file_untouched = SourceFile(
path=FilePath(project_root=project_root, searched_path='models', relative_path='python_model_untouched.py', modification_time=time.time()),
checksum=FileHash.from_contents('lalala'),
path=FilePath(
project_root=project_root,
searched_path="models",
relative_path="python_model_untouched.py",
modification_time=time.time(),
),
checksum=FileHash.from_contents("lalala"),
project_name=project_name,
parse_file_type=ParseFileType.Model,
nodes=['model.my_test.python_model_untouched'],
nodes=["model.my_test.python_model_untouched"],
env_vars=[],
)
schema_file = SchemaSourceFile(
path=FilePath(project_root=project_root, searched_path='models', relative_path='schema.yml', modification_time=time.time()),
checksum=FileHash.from_contents('ghijkl'),
path=FilePath(
project_root=project_root,
searched_path="models",
relative_path="schema.yml",
modification_time=time.time(),
),
checksum=FileHash.from_contents("ghijkl"),
project_name=project_name,
parse_file_type=ParseFileType.Schema,
dfy={'version': 2, 'models': [{'name': 'my_model', 'description': 'Test model'}, {'name': 'python_model', 'description': 'python'}]},
ndp=['model.my_test.my_model'],
dfy={
"version": 2,
"models": [
{"name": "my_model", "description": "Test model"},
{"name": "python_model", "description": "python"},
],
},
ndp=["model.my_test.my_model"],
env_vars={},
)
self.saved_files = {
@@ -65,23 +93,27 @@ class TestPartialParsing(unittest.TestCase):
python_model_file.file_id: python_model_file,
sql_model_file_untouched.file_id: sql_model_file_untouched,
python_model_file_untouched.file_id: python_model_file_untouched,
}
sql_model_node = self.get_model('my_model')
sql_model_node_untouched = self.get_model('my_model_untouched')
python_model_node = self.get_python_model('python_model')
python_model_node_untouched = self.get_python_model('python_model_untouched')
}
sql_model_node = self.get_model("my_model")
sql_model_node_untouched = self.get_model("my_model_untouched")
python_model_node = self.get_python_model("python_model")
python_model_node_untouched = self.get_python_model("python_model_untouched")
nodes = {
sql_model_node.unique_id: sql_model_node,
python_model_node.unique_id: python_model_node,
sql_model_node_untouched.unique_id: sql_model_node_untouched,
python_model_node_untouched.unique_id: python_model_node_untouched,
}
}
self.saved_manifest = Manifest(files=self.saved_files, nodes=nodes)
self.new_files = {
sql_model_file.file_id: SourceFile.from_dict(sql_model_file.to_dict()),
python_model_file.file_id: SourceFile.from_dict(python_model_file.to_dict()),
sql_model_file_untouched.file_id: SourceFile.from_dict(sql_model_file_untouched.to_dict()),
python_model_file_untouched.file_id: SourceFile.from_dict(python_model_file_untouched.to_dict()),
sql_model_file_untouched.file_id: SourceFile.from_dict(
sql_model_file_untouched.to_dict()
),
python_model_file_untouched.file_id: SourceFile.from_dict(
python_model_file_untouched.to_dict()
),
schema_file.file_id: SchemaSourceFile.from_dict(schema_file.to_dict()),
}
@@ -89,38 +121,38 @@ class TestPartialParsing(unittest.TestCase):
def get_model(self, name):
return ModelNode(
package_name='my_test',
path=f'{name}.sql',
original_file_path=f'models/{name}.sql',
language='sql',
raw_code='select * from wherever',
package_name="my_test",
path=f"{name}.sql",
original_file_path=f"models/{name}.sql",
language="sql",
raw_code="select * from wherever",
name=name,
resource_type=NodeType.Model,
unique_id=f'model.my_test.{name}',
fqn=['my_test', 'models', name],
database='test_db',
schema='test_schema',
alias='bar',
checksum=FileHash.from_contents(''),
patch_path='my_test://' + normalize('models/schema.yml'),
unique_id=f"model.my_test.{name}",
fqn=["my_test", "models", name],
database="test_db",
schema="test_schema",
alias="bar",
checksum=FileHash.from_contents(""),
patch_path="my_test://" + normalize("models/schema.yml"),
)
def get_python_model(self, name):
return ModelNode(
package_name='my_test',
path=f'{name}.py',
original_file_path=f'models/{name}.py',
raw_code='import something',
language='python',
package_name="my_test",
path=f"{name}.py",
original_file_path=f"models/{name}.py",
raw_code="import something",
language="python",
name=name,
resource_type=NodeType.Model,
unique_id=f'model.my_test.{name}',
fqn=['my_test', 'models', name],
database='test_db',
schema='test_schema',
alias='bar',
checksum=FileHash.from_contents(''),
patch_path='my_test://' + normalize('models/schema.yml'),
unique_id=f"model.my_test.{name}",
fqn=["my_test", "models", name],
database="test_db",
schema="test_schema",
alias="bar",
checksum=FileHash.from_contents(""),
patch_path="my_test://" + normalize("models/schema.yml"),
)
def test_simple(self):
@@ -130,25 +162,35 @@ class TestPartialParsing(unittest.TestCase):
self.assertTrue(self.partial_parsing.skip_parsing())
# Change a model file
sql_model_file_id = 'my_test://' + normalize('models/my_model.sql')
self.partial_parsing.new_files[sql_model_file_id].checksum = FileHash.from_contents('xyzabc')
python_model_file_id = 'my_test://' + normalize('models/python_model.py')
self.partial_parsing.new_files[python_model_file_id].checksum = FileHash.from_contents('ohohoh')
sql_model_file_id = "my_test://" + normalize("models/my_model.sql")
self.partial_parsing.new_files[sql_model_file_id].checksum = FileHash.from_contents(
"xyzabc"
)
python_model_file_id = "my_test://" + normalize("models/python_model.py")
self.partial_parsing.new_files[python_model_file_id].checksum = FileHash.from_contents(
"ohohoh"
)
self.partial_parsing.build_file_diff()
self.assertFalse(self.partial_parsing.skip_parsing())
pp_files = self.partial_parsing.get_parsing_files()
pp_files["my_test"]["ModelParser"] = set(pp_files["my_test"]["ModelParser"])
# models has 'patch_path' so we expect to see a SchemaParser file listed
schema_file_id = 'my_test://' + normalize('models/schema.yml')
expected_pp_files = {'my_test': {'ModelParser': set([sql_model_file_id, python_model_file_id]), 'SchemaParser': [schema_file_id]}}
schema_file_id = "my_test://" + normalize("models/schema.yml")
expected_pp_files = {
"my_test": {
"ModelParser": set([sql_model_file_id, python_model_file_id]),
"SchemaParser": [schema_file_id],
}
}
self.assertEqual(pp_files, expected_pp_files)
expected_pp_dict = {'version': 2, 'models': [{'name': 'my_model', 'description': 'Test model'}, {'name': 'python_model', 'description': 'python'}]}
schema_file = self.saved_files[schema_file_id]
schema_file_model_names = set([model['name'] for model in schema_file.pp_dict['models']])
expected_model_names = set(['python_model', 'my_model'])
schema_file_model_names = set([model["name"] for model in schema_file.pp_dict["models"]])
expected_model_names = set(["python_model", "my_model"])
self.assertEqual(schema_file_model_names, expected_model_names)
schema_file_model_descriptions = set([model['description'] for model in schema_file.pp_dict['models']])
expected_model_descriptions = set(['Test model', 'python'])
schema_file_model_descriptions = set(
[model["description"] for model in schema_file.pp_dict["models"]]
)
expected_model_descriptions = set(["Test model", "python"])
self.assertEqual(schema_file_model_descriptions, expected_model_descriptions)

View File

@@ -3,7 +3,6 @@ import decimal
import unittest
from unittest import mock
import dbt.flags as flags
from dbt.task.debug import DebugTask
from dbt.adapters.base.query_headers import MacroQueryStringSetter
@@ -16,32 +15,38 @@ from dbt.exceptions import DbtValidationError, DbtConfigError
from psycopg2 import extensions as psycopg2_extensions
from psycopg2 import DatabaseError
from .utils import config_from_parts_or_dicts, inject_adapter, mock_connection, TestAdapterConversions, load_internal_manifest_macros, clear_plugin
from .utils import (
config_from_parts_or_dicts,
inject_adapter,
mock_connection,
TestAdapterConversions,
load_internal_manifest_macros,
clear_plugin,
)
class TestPostgresAdapter(unittest.TestCase):
def setUp(self):
project_cfg = {
'name': 'X',
'version': '0.1',
'profile': 'test',
'project-root': '/tmp/dbt/does-not-exist',
'config-version': 2,
"name": "X",
"version": "0.1",
"profile": "test",
"project-root": "/tmp/dbt/does-not-exist",
"config-version": 2,
}
profile_cfg = {
'outputs': {
'test': {
'type': 'postgres',
'dbname': 'postgres',
'user': 'root',
'host': 'thishostshouldnotexist',
'pass': 'password',
'port': 5432,
'schema': 'public',
"outputs": {
"test": {
"type": "postgres",
"dbname": "postgres",
"user": "root",
"host": "thishostshouldnotexist",
"pass": "password",
"port": 5432,
"schema": "public",
}
},
'target': 'test'
"target": "test",
}
self.config = config_from_parts_or_dicts(project_cfg, profile_cfg)
@@ -54,28 +59,27 @@ class TestPostgresAdapter(unittest.TestCase):
inject_adapter(self._adapter, PostgresPlugin)
return self._adapter
@mock.patch('dbt.adapters.postgres.connections.psycopg2')
@mock.patch("dbt.adapters.postgres.connections.psycopg2")
def test_acquire_connection_validations(self, psycopg2):
try:
connection = self.adapter.acquire_connection('dummy')
connection = self.adapter.acquire_connection("dummy")
except DbtValidationError as e:
self.fail('got DbtValidationError: {}'.format(str(e)))
self.fail("got DbtValidationError: {}".format(str(e)))
except BaseException as e:
self.fail('acquiring connection failed with unknown exception: {}'
.format(str(e)))
self.assertEqual(connection.type, 'postgres')
self.fail("acquiring connection failed with unknown exception: {}".format(str(e)))
self.assertEqual(connection.type, "postgres")
psycopg2.connect.assert_not_called()
connection.handle
psycopg2.connect.assert_called_once()
@mock.patch('dbt.adapters.postgres.connections.psycopg2')
@mock.patch("dbt.adapters.postgres.connections.psycopg2")
def test_acquire_connection(self, psycopg2):
connection = self.adapter.acquire_connection('dummy')
connection = self.adapter.acquire_connection("dummy")
psycopg2.connect.assert_not_called()
connection.handle
self.assertEqual(connection.state, 'open')
self.assertEqual(connection.state, "open")
self.assertNotEqual(connection.handle, None)
psycopg2.connect.assert_called_once()
@@ -84,245 +88,258 @@ class TestPostgresAdapter(unittest.TestCase):
def test_cancel_open_connections_master(self):
key = self.adapter.connections.get_thread_identifier()
self.adapter.connections.thread_connections[key] = mock_connection('master')
self.adapter.connections.thread_connections[key] = mock_connection("master")
self.assertEqual(len(list(self.adapter.cancel_open_connections())), 0)
def test_cancel_open_connections_single(self):
master = mock_connection('master')
model = mock_connection('model')
master = mock_connection("master")
model = mock_connection("model")
key = self.adapter.connections.get_thread_identifier()
model.handle.get_backend_pid.return_value = 42
self.adapter.connections.thread_connections.update({
key: master,
1: model,
})
with mock.patch.object(self.adapter.connections, 'add_query') as add_query:
self.adapter.connections.thread_connections.update(
{
key: master,
1: model,
}
)
with mock.patch.object(self.adapter.connections, "add_query") as add_query:
query_result = mock.MagicMock()
add_query.return_value = (None, query_result)
self.assertEqual(len(list(self.adapter.cancel_open_connections())), 1)
add_query.assert_called_once_with('select pg_terminate_backend(42)')
add_query.assert_called_once_with("select pg_terminate_backend(42)")
master.handle.get_backend_pid.assert_not_called()
@mock.patch('dbt.adapters.postgres.connections.psycopg2')
@mock.patch("dbt.adapters.postgres.connections.psycopg2")
def test_default_connect_timeout(self, psycopg2):
connection = self.adapter.acquire_connection('dummy')
connection = self.adapter.acquire_connection("dummy")
psycopg2.connect.assert_not_called()
connection.handle
psycopg2.connect.assert_called_once_with(
dbname='postgres',
user='root',
host='thishostshouldnotexist',
password='password',
dbname="postgres",
user="root",
host="thishostshouldnotexist",
password="password",
port=5432,
connect_timeout=10,
application_name='dbt')
application_name="dbt",
)
@mock.patch('dbt.adapters.postgres.connections.psycopg2')
@mock.patch("dbt.adapters.postgres.connections.psycopg2")
def test_changed_connect_timeout(self, psycopg2):
self.config.credentials = self.config.credentials.replace(connect_timeout=30)
connection = self.adapter.acquire_connection('dummy')
connection = self.adapter.acquire_connection("dummy")
psycopg2.connect.assert_not_called()
connection.handle
psycopg2.connect.assert_called_once_with(
dbname='postgres',
user='root',
host='thishostshouldnotexist',
password='password',
dbname="postgres",
user="root",
host="thishostshouldnotexist",
password="password",
port=5432,
connect_timeout=30,
application_name='dbt')
application_name="dbt",
)
@mock.patch('dbt.adapters.postgres.connections.psycopg2')
@mock.patch("dbt.adapters.postgres.connections.psycopg2")
def test_default_keepalive(self, psycopg2):
connection = self.adapter.acquire_connection('dummy')
connection = self.adapter.acquire_connection("dummy")
psycopg2.connect.assert_not_called()
connection.handle
psycopg2.connect.assert_called_once_with(
dbname='postgres',
user='root',
host='thishostshouldnotexist',
password='password',
dbname="postgres",
user="root",
host="thishostshouldnotexist",
password="password",
port=5432,
connect_timeout=10,
application_name='dbt')
application_name="dbt",
)
@mock.patch('dbt.adapters.postgres.connections.psycopg2')
@mock.patch("dbt.adapters.postgres.connections.psycopg2")
def test_changed_keepalive(self, psycopg2):
self.config.credentials = self.config.credentials.replace(keepalives_idle=256)
connection = self.adapter.acquire_connection('dummy')
connection = self.adapter.acquire_connection("dummy")
psycopg2.connect.assert_not_called()
connection.handle
psycopg2.connect.assert_called_once_with(
dbname='postgres',
user='root',
host='thishostshouldnotexist',
password='password',
dbname="postgres",
user="root",
host="thishostshouldnotexist",
password="password",
port=5432,
connect_timeout=10,
keepalives_idle=256,
application_name='dbt')
application_name="dbt",
)
@mock.patch('dbt.adapters.postgres.connections.psycopg2')
@mock.patch("dbt.adapters.postgres.connections.psycopg2")
def test_default_application_name(self, psycopg2):
connection = self.adapter.acquire_connection('dummy')
connection = self.adapter.acquire_connection("dummy")
psycopg2.connect.assert_not_called()
connection.handle
psycopg2.connect.assert_called_once_with(
dbname='postgres',
user='root',
host='thishostshouldnotexist',
password='password',
dbname="postgres",
user="root",
host="thishostshouldnotexist",
password="password",
port=5432,
connect_timeout=10,
application_name='dbt')
application_name="dbt",
)
@mock.patch('dbt.adapters.postgres.connections.psycopg2')
@mock.patch("dbt.adapters.postgres.connections.psycopg2")
def test_changed_application_name(self, psycopg2):
self.config.credentials = self.config.credentials.replace(application_name='myapp')
connection = self.adapter.acquire_connection('dummy')
self.config.credentials = self.config.credentials.replace(application_name="myapp")
connection = self.adapter.acquire_connection("dummy")
psycopg2.connect.assert_not_called()
connection.handle
psycopg2.connect.assert_called_once_with(
dbname='postgres',
user='root',
host='thishostshouldnotexist',
password='password',
dbname="postgres",
user="root",
host="thishostshouldnotexist",
password="password",
port=5432,
connect_timeout=10,
application_name='myapp')
application_name="myapp",
)
@mock.patch('dbt.adapters.postgres.connections.psycopg2')
@mock.patch("dbt.adapters.postgres.connections.psycopg2")
def test_role(self, psycopg2):
self.config.credentials = self.config.credentials.replace(role='somerole')
connection = self.adapter.acquire_connection('dummy')
self.config.credentials = self.config.credentials.replace(role="somerole")
connection = self.adapter.acquire_connection("dummy")
cursor = connection.handle.cursor()
cursor.execute.assert_called_once_with('set role somerole')
cursor.execute.assert_called_once_with("set role somerole")
@mock.patch('dbt.adapters.postgres.connections.psycopg2')
@mock.patch("dbt.adapters.postgres.connections.psycopg2")
def test_search_path(self, psycopg2):
self.config.credentials = self.config.credentials.replace(search_path="test")
connection = self.adapter.acquire_connection('dummy')
connection = self.adapter.acquire_connection("dummy")
psycopg2.connect.assert_not_called()
connection.handle
psycopg2.connect.assert_called_once_with(
dbname='postgres',
user='root',
host='thishostshouldnotexist',
password='password',
dbname="postgres",
user="root",
host="thishostshouldnotexist",
password="password",
port=5432,
connect_timeout=10,
application_name='dbt',
options="-c search_path=test")
application_name="dbt",
options="-c search_path=test",
)
@mock.patch('dbt.adapters.postgres.connections.psycopg2')
@mock.patch("dbt.adapters.postgres.connections.psycopg2")
def test_sslmode(self, psycopg2):
self.config.credentials = self.config.credentials.replace(sslmode="require")
connection = self.adapter.acquire_connection('dummy')
connection = self.adapter.acquire_connection("dummy")
psycopg2.connect.assert_not_called()
connection.handle
psycopg2.connect.assert_called_once_with(
dbname='postgres',
user='root',
host='thishostshouldnotexist',
password='password',
dbname="postgres",
user="root",
host="thishostshouldnotexist",
password="password",
port=5432,
connect_timeout=10,
sslmode="require",
application_name='dbt')
application_name="dbt",
)
@mock.patch('dbt.adapters.postgres.connections.psycopg2')
@mock.patch("dbt.adapters.postgres.connections.psycopg2")
def test_ssl_parameters(self, psycopg2):
self.config.credentials = self.config.credentials.replace(sslmode="verify-ca")
self.config.credentials = self.config.credentials.replace(sslcert="service.crt")
self.config.credentials = self.config.credentials.replace(sslkey="service.key")
self.config.credentials = self.config.credentials.replace(sslrootcert="ca.crt")
connection = self.adapter.acquire_connection('dummy')
connection = self.adapter.acquire_connection("dummy")
psycopg2.connect.assert_not_called()
connection.handle
psycopg2.connect.assert_called_once_with(
dbname='postgres',
user='root',
host='thishostshouldnotexist',
password='password',
dbname="postgres",
user="root",
host="thishostshouldnotexist",
password="password",
port=5432,
connect_timeout=10,
sslmode="verify-ca",
sslcert="service.crt",
sslkey="service.key",
sslrootcert="ca.crt",
application_name='dbt')
application_name="dbt",
)
@mock.patch('dbt.adapters.postgres.connections.psycopg2')
@mock.patch("dbt.adapters.postgres.connections.psycopg2")
def test_schema_with_space(self, psycopg2):
self.config.credentials = self.config.credentials.replace(search_path="test test")
connection = self.adapter.acquire_connection('dummy')
connection = self.adapter.acquire_connection("dummy")
psycopg2.connect.assert_not_called()
connection.handle
psycopg2.connect.assert_called_once_with(
dbname='postgres',
user='root',
host='thishostshouldnotexist',
password='password',
dbname="postgres",
user="root",
host="thishostshouldnotexist",
password="password",
port=5432,
connect_timeout=10,
application_name='dbt',
options="-c search_path=test\ test")
application_name="dbt",
options="-c search_path=test\ test", # noqa: [W605]
)
@mock.patch('dbt.adapters.postgres.connections.psycopg2')
@mock.patch("dbt.adapters.postgres.connections.psycopg2")
def test_set_zero_keepalive(self, psycopg2):
self.config.credentials = self.config.credentials.replace(keepalives_idle=0)
connection = self.adapter.acquire_connection('dummy')
connection = self.adapter.acquire_connection("dummy")
psycopg2.connect.assert_not_called()
connection.handle
psycopg2.connect.assert_called_once_with(
dbname='postgres',
user='root',
host='thishostshouldnotexist',
password='password',
dbname="postgres",
user="root",
host="thishostshouldnotexist",
password="password",
port=5432,
connect_timeout=10,
application_name='dbt')
application_name="dbt",
)
@mock.patch.object(PostgresAdapter, 'execute_macro')
@mock.patch.object(PostgresAdapter, '_get_catalog_schemas')
@mock.patch.object(PostgresAdapter, "execute_macro")
@mock.patch.object(PostgresAdapter, "_get_catalog_schemas")
def test_get_catalog_various_schemas(self, mock_get_schemas, mock_execute):
column_names = ['table_database', 'table_schema', 'table_name']
column_names = ["table_database", "table_schema", "table_name"]
rows = [
('dbt', 'foo', 'bar'),
('dbt', 'FOO', 'baz'),
('dbt', None, 'bar'),
('dbt', 'quux', 'bar'),
('dbt', 'skip', 'bar'),
("dbt", "foo", "bar"),
("dbt", "FOO", "baz"),
("dbt", None, "bar"),
("dbt", "quux", "bar"),
("dbt", "skip", "bar"),
]
mock_execute.return_value = agate.Table(rows=rows,
column_names=column_names)
mock_execute.return_value = agate.Table(rows=rows, column_names=column_names)
mock_get_schemas.return_value.items.return_value = [(mock.MagicMock(database='dbt'), {'foo', 'FOO', 'quux'})]
mock_get_schemas.return_value.items.return_value = [
(mock.MagicMock(database="dbt"), {"foo", "FOO", "quux"})
]
mock_manifest = mock.MagicMock()
mock_manifest.get_used_schemas.return_value = {('dbt', 'foo'),
('dbt', 'quux')}
mock_manifest.get_used_schemas.return_value = {("dbt", "foo"), ("dbt", "quux")}
catalog, exceptions = self.adapter.get_catalog(mock_manifest)
self.assertEqual(
set(map(tuple, catalog)),
{('dbt', 'foo', 'bar'), ('dbt', 'FOO', 'baz'), ('dbt', 'quux', 'bar')}
{("dbt", "foo", "bar"), ("dbt", "FOO", "baz"), ("dbt", "quux", "bar")},
)
self.assertEqual(exceptions, [])
@@ -330,31 +347,31 @@ class TestPostgresAdapter(unittest.TestCase):
class TestConnectingPostgresAdapter(unittest.TestCase):
def setUp(self):
self.target_dict = {
'type': 'postgres',
'dbname': 'postgres',
'user': 'root',
'host': 'thishostshouldnotexist',
'pass': 'password',
'port': 5432,
'schema': 'public'
"type": "postgres",
"dbname": "postgres",
"user": "root",
"host": "thishostshouldnotexist",
"pass": "password",
"port": 5432,
"schema": "public",
}
profile_cfg = {
'outputs': {
'test': self.target_dict,
"outputs": {
"test": self.target_dict,
},
'target': 'test'
"target": "test",
}
project_cfg = {
'name': 'X',
'version': '0.1',
'profile': 'test',
'project-root': '/tmp/dbt/does-not-exist',
'quoting': {
'identifier': False,
'schema': True,
"name": "X",
"version": "0.1",
"profile": "test",
"project-root": "/tmp/dbt/does-not-exist",
"quoting": {
"identifier": False,
"schema": True,
},
'config-version': 2,
"config-version": 2,
}
self.config = config_from_parts_or_dicts(project_cfg, profile_cfg)
@@ -362,31 +379,35 @@ class TestConnectingPostgresAdapter(unittest.TestCase):
self.handle = mock.MagicMock(spec=psycopg2_extensions.connection)
self.cursor = self.handle.cursor.return_value
self.mock_execute = self.cursor.execute
self.patcher = mock.patch('dbt.adapters.postgres.connections.psycopg2')
self.patcher = mock.patch("dbt.adapters.postgres.connections.psycopg2")
self.psycopg2 = self.patcher.start()
# Create the Manifest.state_check patcher
@mock.patch('dbt.parser.manifest.ManifestLoader.build_manifest_state_check')
@mock.patch("dbt.parser.manifest.ManifestLoader.build_manifest_state_check")
def _mock_state_check(self):
config = self.root_project
all_projects = self.all_projects
return ManifestStateCheck(
vars_hash=FileHash.from_contents('vars'),
vars_hash=FileHash.from_contents("vars"),
project_hashes={name: FileHash.from_contents(name) for name in all_projects},
profile_hash=FileHash.from_contents('profile'),
profile_hash=FileHash.from_contents("profile"),
)
self.load_state_check = mock.patch('dbt.parser.manifest.ManifestLoader.build_manifest_state_check')
self.load_state_check = mock.patch(
"dbt.parser.manifest.ManifestLoader.build_manifest_state_check"
)
self.mock_state_check = self.load_state_check.start()
self.mock_state_check.side_effect = _mock_state_check
self.psycopg2.connect.return_value = self.handle
self.adapter = PostgresAdapter(self.config)
self.adapter._macro_manifest_lazy = load_internal_manifest_macros(self.config)
self.adapter.connections.query_header = MacroQueryStringSetter(self.config, self.adapter._macro_manifest_lazy)
self.adapter.connections.query_header = MacroQueryStringSetter(
self.config, self.adapter._macro_manifest_lazy
)
self.qh_patch = mock.patch.object(self.adapter.connections.query_header, 'add')
self.qh_patch = mock.patch.object(self.adapter.connections.query_header, "add")
self.mock_query_header_add = self.qh_patch.start()
self.mock_query_header_add.side_effect = lambda q: '/* dbt */\n{}'.format(q)
self.mock_query_header_add.side_effect = lambda q: "/* dbt */\n{}".format(q)
self.adapter.acquire_connection()
inject_adapter(self.adapter, PostgresPlugin)
@@ -400,73 +421,79 @@ class TestConnectingPostgresAdapter(unittest.TestCase):
def test_quoting_on_drop_schema(self):
relation = self.adapter.Relation.create(
database='postgres', schema='test_schema',
database="postgres",
schema="test_schema",
quote_policy=self.adapter.config.quoting,
)
self.adapter.drop_schema(relation)
self.mock_execute.assert_has_calls([
mock.call('/* dbt */\ndrop schema if exists "test_schema" cascade', None)
])
self.mock_execute.assert_has_calls(
[mock.call('/* dbt */\ndrop schema if exists "test_schema" cascade', None)]
)
def test_quoting_on_drop(self):
relation = self.adapter.Relation.create(
database='postgres',
schema='test_schema',
identifier='test_table',
type='table',
database="postgres",
schema="test_schema",
identifier="test_table",
type="table",
quote_policy=self.adapter.config.quoting,
)
self.adapter.drop_relation(relation)
self.mock_execute.assert_has_calls([
mock.call('/* dbt */\ndrop table if exists "postgres"."test_schema".test_table cascade', None)
])
self.mock_execute.assert_has_calls(
[
mock.call(
'/* dbt */\ndrop table if exists "postgres"."test_schema".test_table cascade',
None,
)
]
)
def test_quoting_on_truncate(self):
relation = self.adapter.Relation.create(
database='postgres',
schema='test_schema',
identifier='test_table',
type='table',
database="postgres",
schema="test_schema",
identifier="test_table",
type="table",
quote_policy=self.adapter.config.quoting,
)
self.adapter.truncate_relation(relation)
self.mock_execute.assert_has_calls([
mock.call('/* dbt */\ntruncate table "postgres"."test_schema".test_table', None)
])
self.mock_execute.assert_has_calls(
[mock.call('/* dbt */\ntruncate table "postgres"."test_schema".test_table', None)]
)
def test_quoting_on_rename(self):
from_relation = self.adapter.Relation.create(
database='postgres',
schema='test_schema',
identifier='table_a',
type='table',
database="postgres",
schema="test_schema",
identifier="table_a",
type="table",
quote_policy=self.adapter.config.quoting,
)
to_relation = self.adapter.Relation.create(
database='postgres',
schema='test_schema',
identifier='table_b',
type='table',
database="postgres",
schema="test_schema",
identifier="table_b",
type="table",
quote_policy=self.adapter.config.quoting,
)
self.adapter.rename_relation(
from_relation=from_relation,
to_relation=to_relation
self.adapter.rename_relation(from_relation=from_relation, to_relation=to_relation)
self.mock_execute.assert_has_calls(
[
mock.call(
'/* dbt */\nalter table "postgres"."test_schema".table_a rename to table_b',
None,
)
]
)
self.mock_execute.assert_has_calls([
mock.call('/* dbt */\nalter table "postgres"."test_schema".table_a rename to table_b', None)
])
def test_debug_connection_ok(self):
DebugTask.validate_connection(self.target_dict)
self.mock_execute.assert_has_calls([
mock.call('/* dbt */\nselect 1 as id', None)
])
self.mock_execute.assert_has_calls([mock.call("/* dbt */\nselect 1 as id", None)])
def test_debug_connection_fail_nopass(self):
del self.target_dict['pass']
del self.target_dict["pass"]
with self.assertRaises(DbtConfigError):
DebugTask.validate_connection(self.target_dict)
@@ -474,113 +501,115 @@ class TestConnectingPostgresAdapter(unittest.TestCase):
self.mock_execute.side_effect = DatabaseError()
with self.assertRaises(DbtConfigError):
DebugTask.validate_connection(self.target_dict)
self.mock_execute.assert_has_calls([
mock.call('/* dbt */\nselect 1 as id', None)
])
self.mock_execute.assert_has_calls([mock.call("/* dbt */\nselect 1 as id", None)])
def test_dbname_verification_is_case_insensitive(self):
# Override adapter settings from setUp()
self.target_dict['dbname'] = 'Postgres'
self.target_dict["dbname"] = "Postgres"
profile_cfg = {
'outputs': {
'test': self.target_dict,
"outputs": {
"test": self.target_dict,
},
'target': 'test'
"target": "test",
}
project_cfg = {
'name': 'X',
'version': '0.1',
'profile': 'test',
'project-root': '/tmp/dbt/does-not-exist',
'quoting': {
'identifier': False,
'schema': True,
"name": "X",
"version": "0.1",
"profile": "test",
"project-root": "/tmp/dbt/does-not-exist",
"quoting": {
"identifier": False,
"schema": True,
},
'config-version': 2,
"config-version": 2,
}
self.config = config_from_parts_or_dicts(project_cfg, profile_cfg)
self.adapter.cleanup_connections()
self._adapter = PostgresAdapter(self.config)
self.adapter.verify_database('postgres')
self.adapter.verify_database("postgres")
class TestPostgresFilterCatalog(unittest.TestCase):
def test__catalog_filter_table(self):
manifest = mock.MagicMock()
manifest.get_used_schemas.return_value = [['a', 'B'], ['a', '1234']]
column_names = ['table_name', 'table_database', 'table_schema', 'something']
manifest.get_used_schemas.return_value = [["a", "B"], ["a", "1234"]]
column_names = ["table_name", "table_database", "table_schema", "something"]
rows = [
['foo', 'a', 'b', '1234'], # include
['foo', 'a', '1234', '1234'], # include, w/ table schema as str
['foo', 'c', 'B', '1234'], # skip
['1234', 'A', 'B', '1234'], # include, w/ table name as str
["foo", "a", "b", "1234"], # include
["foo", "a", "1234", "1234"], # include, w/ table schema as str
["foo", "c", "B", "1234"], # skip
["1234", "A", "B", "1234"], # include, w/ table name as str
]
table = agate.Table(
rows, column_names, agate_helper.DEFAULT_TYPE_TESTER
)
table = agate.Table(rows, column_names, agate_helper.DEFAULT_TYPE_TESTER)
result = PostgresAdapter._catalog_filter_table(table, manifest)
assert len(result) == 3
for row in result.rows:
assert isinstance(row['table_schema'], str)
assert isinstance(row['table_database'], str)
assert isinstance(row['table_name'], str)
assert isinstance(row['something'], decimal.Decimal)
assert isinstance(row["table_schema"], str)
assert isinstance(row["table_database"], str)
assert isinstance(row["table_name"], str)
assert isinstance(row["something"], decimal.Decimal)
class TestPostgresAdapterConversions(TestAdapterConversions):
def test_convert_text_type(self):
rows = [
['', 'a1', 'stringval1'],
['', 'a2', 'stringvalasdfasdfasdfa'],
['', 'a3', 'stringval3'],
["", "a1", "stringval1"],
["", "a2", "stringvalasdfasdfasdfa"],
["", "a3", "stringval3"],
]
agate_table = self._make_table_of(rows, agate.Text)
expected = ['text', 'text', 'text']
expected = ["text", "text", "text"]
for col_idx, expect in enumerate(expected):
assert PostgresAdapter.convert_text_type(agate_table, col_idx) == expect
def test_convert_number_type(self):
rows = [
['', '23.98', '-1'],
['', '12.78', '-2'],
['', '79.41', '-3'],
["", "23.98", "-1"],
["", "12.78", "-2"],
["", "79.41", "-3"],
]
agate_table = self._make_table_of(rows, agate.Number)
expected = ['integer', 'float8', 'integer']
expected = ["integer", "float8", "integer"]
for col_idx, expect in enumerate(expected):
assert PostgresAdapter.convert_number_type(agate_table, col_idx) == expect
def test_convert_boolean_type(self):
rows = [
['', 'false', 'true'],
['', 'false', 'false'],
['', 'false', 'true'],
["", "false", "true"],
["", "false", "false"],
["", "false", "true"],
]
agate_table = self._make_table_of(rows, agate.Boolean)
expected = ['boolean', 'boolean', 'boolean']
expected = ["boolean", "boolean", "boolean"]
for col_idx, expect in enumerate(expected):
assert PostgresAdapter.convert_boolean_type(agate_table, col_idx) == expect
def test_convert_datetime_type(self):
rows = [
['', '20190101T01:01:01Z', '2019-01-01 01:01:01'],
['', '20190102T01:01:01Z', '2019-01-01 01:01:01'],
['', '20190103T01:01:01Z', '2019-01-01 01:01:01'],
["", "20190101T01:01:01Z", "2019-01-01 01:01:01"],
["", "20190102T01:01:01Z", "2019-01-01 01:01:01"],
["", "20190103T01:01:01Z", "2019-01-01 01:01:01"],
]
agate_table = self._make_table_of(
rows, [agate.DateTime, agate_helper.ISODateTime, agate.DateTime]
)
expected = [
"timestamp without time zone",
"timestamp without time zone",
"timestamp without time zone",
]
agate_table = self._make_table_of(rows, [agate.DateTime, agate_helper.ISODateTime, agate.DateTime])
expected = ['timestamp without time zone', 'timestamp without time zone', 'timestamp without time zone']
for col_idx, expect in enumerate(expected):
assert PostgresAdapter.convert_datetime_type(agate_table, col_idx) == expect
def test_convert_date_type(self):
rows = [
['', '2019-01-01', '2019-01-04'],
['', '2019-01-02', '2019-01-04'],
['', '2019-01-03', '2019-01-04'],
["", "2019-01-01", "2019-01-04"],
["", "2019-01-02", "2019-01-04"],
["", "2019-01-03", "2019-01-04"],
]
agate_table = self._make_table_of(rows, agate.Date)
expected = ['date', 'date', 'date']
expected = ["date", "date", "date"]
for col_idx, expect in enumerate(expected):
assert PostgresAdapter.convert_date_type(agate_table, col_idx) == expect
@@ -588,11 +617,11 @@ class TestPostgresAdapterConversions(TestAdapterConversions):
# dbt's default type testers actually don't have a TimeDelta at all.
agate.TimeDelta
rows = [
['', '120s', '10s'],
['', '3m', '11s'],
['', '1h', '12s'],
["", "120s", "10s"],
["", "3m", "11s"],
["", "1h", "12s"],
]
agate_table = self._make_table_of(rows, agate.TimeDelta)
expected = ['time', 'time', 'time']
expected = ["time", "time", "time"]
for col_idx, expect in enumerate(expected):
assert PostgresAdapter.convert_time_type(agate_table, col_idx) == expect

View File

@@ -7,27 +7,26 @@ from test.unit.utils import config_from_parts_or_dicts
class TestQueryHeaders(TestCase):
def setUp(self):
self.profile_cfg = {
'outputs': {
'test': {
'type': 'postgres',
'dbname': 'postgres',
'user': 'test',
'host': 'test',
'pass': 'test',
'port': 5432,
'schema': 'test'
"outputs": {
"test": {
"type": "postgres",
"dbname": "postgres",
"user": "test",
"host": "test",
"pass": "test",
"port": 5432,
"schema": "test",
},
},
'target': 'test'
"target": "test",
}
self.project_cfg = {
'name': 'query_headers',
'version': '0.1',
'profile': 'test',
'config-version': 2,
"name": "query_headers",
"version": "0.1",
"profile": "test",
"config-version": 2,
}
self.query = "SELECT 1;"
@@ -35,25 +34,17 @@ class TestQueryHeaders(TestCase):
config = config_from_parts_or_dicts(self.project_cfg, self.profile_cfg)
query_header = MacroQueryStringSetter(config, mock.MagicMock(macros={}))
sql = query_header.add(self.query)
self.assertTrue(re.match(f'^\/\*.*\*\/\n{self.query}$', sql))
self.assertTrue(re.match(f"^\/\*.*\*\/\n{self.query}$", sql)) # noqa: [W605]
def test_append_comment(self):
self.project_cfg.update({
'query-comment': {
'comment': 'executed by dbt',
'append': True
}
})
self.project_cfg.update({"query-comment": {"comment": "executed by dbt", "append": True}})
config = config_from_parts_or_dicts(self.project_cfg, self.profile_cfg)
query_header = MacroQueryStringSetter(config, mock.MagicMock(macros={}))
sql = query_header.add(self.query)
self.assertEqual(sql, f'{self.query[:-1]}\n/* executed by dbt */;')
self.assertEqual(sql, f"{self.query[:-1]}\n/* executed by dbt */;")
def test_disable_query_comment(self):
self.project_cfg.update({
'query-comment': ''
})
self.project_cfg.update({"query-comment": ""})
config = config_from_parts_or_dicts(self.project_cfg, self.profile_cfg)
query_header = MacroQueryStringSetter(config, mock.MagicMock(macros={}))
self.assertEqual(query_header.add(self.query), self.query)

View File

@@ -3,7 +3,8 @@ import unittest
from dbt.exceptions import ConnectionError
from dbt.clients.registry import _get_with_retries
class testRegistryGetRequestException(unittest.TestCase):
def test_registry_request_error_catching(self):
# using non routable IP to test connection error logic in the _get_with_retries function
self.assertRaises(ConnectionError, _get_with_retries, '', 'http://0.0.0.0')
self.assertRaises(ConnectionError, _get_with_retries, "", "http://0.0.0.0")

View File

@@ -2,9 +2,7 @@ import dbt.exceptions
import textwrap
import yaml
import unittest
from dbt.config.selectors import (
selector_config_from_data
)
from dbt.config.selectors import selector_config_from_data
from dbt.config.selectors import SelectorConfig
@@ -16,9 +14,9 @@ def get_selector_dict(txt: str) -> dict:
class SelectorUnitTest(unittest.TestCase):
def test_parse_multiple_excludes(self):
dct = get_selector_dict('''\
dct = get_selector_dict(
"""\
selectors:
- name: mult_excl
definition:
@@ -31,15 +29,16 @@ class SelectorUnitTest(unittest.TestCase):
- exclude:
- method: tag
value: daily
''')
"""
)
with self.assertRaisesRegex(
dbt.exceptions.DbtSelectorsError,
'cannot provide multiple exclude arguments'
dbt.exceptions.DbtSelectorsError, "cannot provide multiple exclude arguments"
):
selector_config_from_data(dct)
def test_parse_set_op_plus(self):
dct = get_selector_dict('''\
dct = get_selector_dict(
"""\
selectors:
- name: union_plus
definition:
@@ -51,30 +50,32 @@ class SelectorUnitTest(unittest.TestCase):
value: hourly
- method: tag
value: foo
''')
"""
)
with self.assertRaisesRegex(
dbt.exceptions.DbtSelectorsError,
'Valid root-level selector definitions'
dbt.exceptions.DbtSelectorsError, "Valid root-level selector definitions"
):
selector_config_from_data(dct)
def test_parse_multiple_methods(self):
dct = get_selector_dict('''\
dct = get_selector_dict(
"""\
selectors:
- name: mult_methods
definition:
- tag:hourly
- tag:nightly
- fqn:start
''')
"""
)
with self.assertRaisesRegex(
dbt.exceptions.DbtSelectorsError,
'Valid root-level selector definitions'
dbt.exceptions.DbtSelectorsError, "Valid root-level selector definitions"
):
selector_config_from_data(dct)
def test_parse_set_with_method(self):
dct = get_selector_dict('''\
dct = get_selector_dict(
"""\
selectors:
- name: mixed_syntaxes
definition:
@@ -87,15 +88,17 @@ class SelectorUnitTest(unittest.TestCase):
- exclude:
- method: tag
value: m5678
''')
"""
)
with self.assertRaisesRegex(
dbt.exceptions.DbtSelectorsError,
"Only a single 'union' or 'intersection' key is allowed"
dbt.exceptions.DbtSelectorsError,
"Only a single 'union' or 'intersection' key is allowed",
):
selector_config_from_data(dct)
def test_complex_sector(self):
dct = get_selector_dict('''\
dct = get_selector_dict(
"""\
selectors:
- name: nightly_diet_snowplow
definition:
@@ -116,12 +119,14 @@ class SelectorUnitTest(unittest.TestCase):
value: incremental
- method: fqn
value: export_performance_timing
''')
"""
)
selectors = selector_config_from_data(dct)
assert(isinstance(selectors, SelectorConfig))
assert isinstance(selectors, SelectorConfig)
def test_exclude_not_list(self):
dct = get_selector_dict('''\
dct = get_selector_dict(
"""\
selectors:
- name: summa_exclude
definition:
@@ -131,58 +136,54 @@ class SelectorUnitTest(unittest.TestCase):
- exclude:
method: tag
value: daily
''')
with self.assertRaisesRegex(
dbt.exceptions.DbtSelectorsError,
"Expected a list"
):
"""
)
with self.assertRaisesRegex(dbt.exceptions.DbtSelectorsError, "Expected a list"):
selector_config_from_data(dct)
def test_invalid_key(self):
dct = get_selector_dict('''\
dct = get_selector_dict(
"""\
selectors:
- name: summa_nothing
definition:
method: tag
key: nightly
''')
with self.assertRaisesRegex(
dbt.exceptions.DbtSelectorsError,
"Expected either 1 key"
):
"""
)
with self.assertRaisesRegex(dbt.exceptions.DbtSelectorsError, "Expected either 1 key"):
selector_config_from_data(dct)
def test_invalid_single_def(self):
dct = get_selector_dict('''\
dct = get_selector_dict(
"""\
selectors:
- name: summa_nothing
definition:
fubar: tag
''')
with self.assertRaisesRegex(
dbt.exceptions.DbtSelectorsError,
"not a valid method name"
):
"""
)
with self.assertRaisesRegex(dbt.exceptions.DbtSelectorsError, "not a valid method name"):
selector_config_from_data(dct)
def test_method_no_value(self):
dct = get_selector_dict('''\
dct = get_selector_dict(
"""\
selectors:
- name: summa_nothing
definition:
method: tag
''')
with self.assertRaisesRegex(
dbt.exceptions.DbtSelectorsError,
"not a valid method name"
):
"""
)
with self.assertRaisesRegex(dbt.exceptions.DbtSelectorsError, "not a valid method name"):
selector_config_from_data(dct)
def test_multiple_default_true(self):
"""Test selector_config_from_data returns the correct error when multiple
default values are set
"""
dct = get_selector_dict('''\
"""Test selector_config_from_data returns the correct error when multiple
default values are set
"""
dct = get_selector_dict(
"""\
selectors:
- name: summa_nothing
definition:
@@ -194,9 +195,9 @@ class SelectorUnitTest(unittest.TestCase):
method: tag
value: daily
default: true
''')
with self.assertRaisesRegex(
dbt.exceptions.DbtSelectorsError,
'Found multiple selectors with `default: true`:'
):
selector_config_from_data(dct)
"""
)
with self.assertRaisesRegex(
dbt.exceptions.DbtSelectorsError, "Found multiple selectors with `default: true`:"
):
selector_config_from_data(dct)

View File

@@ -3,9 +3,16 @@ import itertools
from typing import List
from dbt.exceptions import VersionsNotCompatibleError
from dbt.semver import VersionSpecifier, UnboundedVersionSpecifier, \
VersionRange, reduce_versions, versions_compatible, \
resolve_to_specific_version, filter_installable
from dbt.semver import (
VersionSpecifier,
UnboundedVersionSpecifier,
VersionRange,
reduce_versions,
versions_compatible,
resolve_to_specific_version,
filter_installable,
)
def semver_regex_versioning(versions: List[str]) -> bool:
for version_string in versions:
@@ -15,6 +22,7 @@ def semver_regex_versioning(versions: List[str]) -> bool:
return False
return True
def create_range(start_version_string, end_version_string):
start = UnboundedVersionSpecifier()
end = UnboundedVersionSpecifier()
@@ -29,14 +37,11 @@ def create_range(start_version_string, end_version_string):
class TestSemver(unittest.TestCase):
def assertVersionSetResult(self, inputs, output_range):
expected = create_range(*output_range)
for permutation in itertools.permutations(inputs):
self.assertEqual(
reduce_versions(*permutation),
expected)
self.assertEqual(reduce_versions(*permutation), expected)
def assertInvalidVersionSet(self, inputs):
for permutation in itertools.permutations(inputs):
@@ -44,181 +49,250 @@ class TestSemver(unittest.TestCase):
reduce_versions(*permutation)
def test__versions_compatible(self):
self.assertTrue(
versions_compatible('0.0.1', '0.0.1'))
self.assertFalse(
versions_compatible('0.0.1', '0.0.2'))
self.assertTrue(
versions_compatible('>0.0.1', '0.0.2'))
self.assertFalse(
versions_compatible('0.4.5a1', '0.4.5a2'))
self.assertTrue(versions_compatible("0.0.1", "0.0.1"))
self.assertFalse(versions_compatible("0.0.1", "0.0.2"))
self.assertTrue(versions_compatible(">0.0.1", "0.0.2"))
self.assertFalse(versions_compatible("0.4.5a1", "0.4.5a2"))
def test__semver_regex_versions(self):
self.assertTrue(semver_regex_versioning(
['0.0.4','1.2.3','10.20.30','1.1.2-prerelease+meta','1.1.2+meta','1.1.2+meta-valid',
'1.0.0-alpha','1.0.0-beta','1.0.0-alpha.beta','1.0.0-alpha.beta.1','1.0.0-alpha.1',
'1.0.0-alpha0.valid','1.0.0-alpha.0valid','1.0.0-alpha-a.b-c-somethinglong+build.1-aef.1-its-okay',
'1.0.0-rc.1+build.1','2.0.0-rc.1+build.123','1.2.3-beta','10.2.3-DEV-SNAPSHOT','1.2.3-SNAPSHOT-123',
'1.0.0','2.0.0','1.1.7','2.0.0+build.1848','2.0.1-alpha.1227','1.0.0-alpha+beta','1.2.3----RC-SNAPSHOT.12.9.1--.12+788',
'1.2.3----R-S.12.9.1--.12+meta','1.2.3----RC-SNAPSHOT.12.9.1--.12','1.0.0+0.build.1-rc.10000aaa-kk-0.1',
'99999999999999999999999.999999999999999999.99999999999999999','1.0.0-0A.is.legal']))
self.assertTrue(
semver_regex_versioning(
[
"0.0.4",
"1.2.3",
"10.20.30",
"1.1.2-prerelease+meta",
"1.1.2+meta",
"1.1.2+meta-valid",
"1.0.0-alpha",
"1.0.0-beta",
"1.0.0-alpha.beta",
"1.0.0-alpha.beta.1",
"1.0.0-alpha.1",
"1.0.0-alpha0.valid",
"1.0.0-alpha.0valid",
"1.0.0-alpha-a.b-c-somethinglong+build.1-aef.1-its-okay",
"1.0.0-rc.1+build.1",
"2.0.0-rc.1+build.123",
"1.2.3-beta",
"10.2.3-DEV-SNAPSHOT",
"1.2.3-SNAPSHOT-123",
"1.0.0",
"2.0.0",
"1.1.7",
"2.0.0+build.1848",
"2.0.1-alpha.1227",
"1.0.0-alpha+beta",
"1.2.3----RC-SNAPSHOT.12.9.1--.12+788",
"1.2.3----R-S.12.9.1--.12+meta",
"1.2.3----RC-SNAPSHOT.12.9.1--.12",
"1.0.0+0.build.1-rc.10000aaa-kk-0.1",
"99999999999999999999999.999999999999999999.99999999999999999",
"1.0.0-0A.is.legal",
]
)
)
self.assertFalse(semver_regex_versioning(
['1','1.2','1.2.3-0123','1.2.3-0123.0123','1.1.2+.123','+invalid','-invalid','-invalid+invalid',
'-invalid.01','alpha','alpha.beta','alpha.beta.1','alpha.1','alpha+beta','alpha_beta','alpha.',
'alpha..','beta','1.0.0-alpha_beta','-alpha.','1.0.0-alpha..','1.0.0-alpha..1','1.0.0-alpha...1',
'1.0.0-alpha....1','1.0.0-alpha.....1','1.0.0-alpha......1','1.0.0-alpha.......1','01.1.1','1.01.1',
'1.1.01','1.2','1.2.3.DEV','1.2-SNAPSHOT','1.2.31.2.3----RC-SNAPSHOT.12.09.1--..12+788','1.2-RC-SNAPSHOT',
'-1.0.3-gamma+b7718','+justmeta','9.8.7+meta+meta','9.8.7-whatever+meta+meta',
'99999999999999999999999.999999999999999999.99999999999999999----RC-SNAPSHOT.12.09.1--------------------------------..12']))
self.assertFalse(
semver_regex_versioning(
[
"1",
"1.2",
"1.2.3-0123",
"1.2.3-0123.0123",
"1.1.2+.123",
"+invalid",
"-invalid",
"-invalid+invalid",
"-invalid.01",
"alpha",
"alpha.beta",
"alpha.beta.1",
"alpha.1",
"alpha+beta",
"alpha_beta",
"alpha.",
"alpha..",
"beta",
"1.0.0-alpha_beta",
"-alpha.",
"1.0.0-alpha..",
"1.0.0-alpha..1",
"1.0.0-alpha...1",
"1.0.0-alpha....1",
"1.0.0-alpha.....1",
"1.0.0-alpha......1",
"1.0.0-alpha.......1",
"01.1.1",
"1.01.1",
"1.1.01",
"1.2",
"1.2.3.DEV",
"1.2-SNAPSHOT",
"1.2.31.2.3----RC-SNAPSHOT.12.09.1--..12+788",
"1.2-RC-SNAPSHOT",
"-1.0.3-gamma+b7718",
"+justmeta",
"9.8.7+meta+meta",
"9.8.7-whatever+meta+meta",
"99999999999999999999999.999999999999999999.99999999999999999----RC-SNAPSHOT.12.09.1--------------------------------..12",
]
)
)
def test__reduce_versions(self):
self.assertVersionSetResult(
['0.0.1', '0.0.1'],
['=0.0.1', '=0.0.1'])
self.assertVersionSetResult(["0.0.1", "0.0.1"], ["=0.0.1", "=0.0.1"])
self.assertVersionSetResult(
['0.0.1'],
['=0.0.1', '=0.0.1'])
self.assertVersionSetResult(["0.0.1"], ["=0.0.1", "=0.0.1"])
self.assertVersionSetResult(
['>0.0.1'],
['>0.0.1', None])
self.assertVersionSetResult([">0.0.1"], [">0.0.1", None])
self.assertVersionSetResult(
['<0.0.1'],
[None, '<0.0.1'])
self.assertVersionSetResult(["<0.0.1"], [None, "<0.0.1"])
self.assertVersionSetResult(
['>0.0.1', '0.0.2'],
['=0.0.2', '=0.0.2'])
self.assertVersionSetResult([">0.0.1", "0.0.2"], ["=0.0.2", "=0.0.2"])
self.assertVersionSetResult(
['0.0.2', '>=0.0.2'],
['=0.0.2', '=0.0.2'])
self.assertVersionSetResult(["0.0.2", ">=0.0.2"], ["=0.0.2", "=0.0.2"])
self.assertVersionSetResult(
['>0.0.1', '>0.0.2', '>0.0.3'],
['>0.0.3', None])
self.assertVersionSetResult([">0.0.1", ">0.0.2", ">0.0.3"], [">0.0.3", None])
self.assertVersionSetResult(
['>0.0.1', '<0.0.3'],
['>0.0.1', '<0.0.3'])
self.assertVersionSetResult([">0.0.1", "<0.0.3"], [">0.0.1", "<0.0.3"])
self.assertVersionSetResult(
['>0.0.1', '0.0.2', '<0.0.3'],
['=0.0.2', '=0.0.2'])
self.assertVersionSetResult([">0.0.1", "0.0.2", "<0.0.3"], ["=0.0.2", "=0.0.2"])
self.assertVersionSetResult(
['>0.0.1', '>=0.0.1', '<0.0.3'],
['>0.0.1', '<0.0.3'])
self.assertVersionSetResult([">0.0.1", ">=0.0.1", "<0.0.3"], [">0.0.1", "<0.0.3"])
self.assertVersionSetResult(
['>0.0.1', '<0.0.3', '<=0.0.3'],
['>0.0.1', '<0.0.3'])
self.assertVersionSetResult([">0.0.1", "<0.0.3", "<=0.0.3"], [">0.0.1", "<0.0.3"])
self.assertVersionSetResult(
['>0.0.1', '>0.0.2', '<0.0.3', '<0.0.4'],
['>0.0.2', '<0.0.3'])
self.assertVersionSetResult([">0.0.1", ">0.0.2", "<0.0.3", "<0.0.4"], [">0.0.2", "<0.0.3"])
self.assertVersionSetResult(
['<=0.0.3', '>=0.0.3'],
['>=0.0.3', '<=0.0.3'])
self.assertVersionSetResult(["<=0.0.3", ">=0.0.3"], [">=0.0.3", "<=0.0.3"])
self.assertInvalidVersionSet(['>0.0.2', '0.0.1'])
self.assertInvalidVersionSet(['>0.0.2', '0.0.2'])
self.assertInvalidVersionSet(['<0.0.2', '0.0.2'])
self.assertInvalidVersionSet(['<0.0.2', '>0.0.3'])
self.assertInvalidVersionSet(['<=0.0.3', '>0.0.3'])
self.assertInvalidVersionSet(['<0.0.3', '>=0.0.3'])
self.assertInvalidVersionSet(['<0.0.3', '>0.0.3'])
self.assertInvalidVersionSet([">0.0.2", "0.0.1"])
self.assertInvalidVersionSet([">0.0.2", "0.0.2"])
self.assertInvalidVersionSet(["<0.0.2", "0.0.2"])
self.assertInvalidVersionSet(["<0.0.2", ">0.0.3"])
self.assertInvalidVersionSet(["<=0.0.3", ">0.0.3"])
self.assertInvalidVersionSet(["<0.0.3", ">=0.0.3"])
self.assertInvalidVersionSet(["<0.0.3", ">0.0.3"])
def test__resolve_to_specific_version(self):
self.assertEqual(
resolve_to_specific_version(
create_range('>0.0.1', None),
['0.0.1', '0.0.2']),
'0.0.2')
resolve_to_specific_version(create_range(">0.0.1", None), ["0.0.1", "0.0.2"]), "0.0.2"
)
self.assertEqual(
resolve_to_specific_version(create_range(">=0.0.2", None), ["0.0.1", "0.0.2"]), "0.0.2"
)
self.assertEqual(
resolve_to_specific_version(create_range(">=0.0.3", None), ["0.0.1", "0.0.2"]), None
)
self.assertEqual(
resolve_to_specific_version(
create_range('>=0.0.2', None),
['0.0.1', '0.0.2']),
'0.0.2')
create_range(">=0.0.3", "<0.0.5"), ["0.0.3", "0.0.4", "0.0.5"]
),
"0.0.4",
)
self.assertEqual(
resolve_to_specific_version(
create_range('>=0.0.3', None),
['0.0.1', '0.0.2']),
None)
create_range(None, "<=0.0.5"), ["0.0.3", "0.1.4", "0.0.5"]
),
"0.0.5",
)
self.assertEqual(
resolve_to_specific_version(
create_range('>=0.0.3', '<0.0.5'),
['0.0.3', '0.0.4', '0.0.5']),
'0.0.4')
create_range("=0.4.5a2", "=0.4.5a2"), ["0.4.5a1", "0.4.5a2"]
),
"0.4.5a2",
)
self.assertEqual(
resolve_to_specific_version(create_range("=0.7.6", "=0.7.6"), ["0.7.6-b1", "0.7.6"]),
"0.7.6",
)
self.assertEqual(
resolve_to_specific_version(
create_range(None, '<=0.0.5'),
['0.0.3', '0.1.4', '0.0.5']),
'0.0.5')
create_range(">=1.0.0", None), ["1.0.0", "1.1.0a1", "1.1.0", "1.2.0a1"]
),
"1.2.0a1",
)
self.assertEqual(
resolve_to_specific_version(
create_range('=0.4.5a2', '=0.4.5a2'),
['0.4.5a1', '0.4.5a2']),
'0.4.5a2')
create_range(">=1.0.0", "<1.2.0"), ["1.0.0", "1.1.0a1", "1.1.0", "1.2.0a1"]
),
"1.1.0",
)
self.assertEqual(
resolve_to_specific_version(
create_range('=0.7.6', '=0.7.6'),
['0.7.6-b1', '0.7.6']),
'0.7.6')
create_range(">=1.0.0", None), ["1.0.0", "1.1.0a1", "1.1.0", "1.2.0a1", "1.2.0"]
),
"1.2.0",
)
self.assertEqual(
resolve_to_specific_version(
create_range('>=1.0.0', None),
['1.0.0', '1.1.0a1', '1.1.0', '1.2.0a1']),
'1.2.0a1')
self.assertEqual(
resolve_to_specific_version(
create_range('>=1.0.0', '<1.2.0'),
['1.0.0', '1.1.0a1', '1.1.0', '1.2.0a1']),
'1.1.0')
self.assertEqual(
resolve_to_specific_version(
create_range('>=1.0.0', None),
['1.0.0', '1.1.0a1', '1.1.0', '1.2.0a1', '1.2.0']),
'1.2.0')
self.assertEqual(
resolve_to_specific_version(
create_range('>=1.0.0', '<1.2.0'),
['1.0.0', '1.1.0a1', '1.1.0', '1.2.0a1', '1.2.0']),
'1.1.0')
create_range(">=1.0.0", "<1.2.0"),
["1.0.0", "1.1.0a1", "1.1.0", "1.2.0a1", "1.2.0"],
),
"1.1.0",
)
self.assertEqual(
resolve_to_specific_version(
# https://github.com/dbt-labs/dbt-core/issues/7039
# 10 is greater than 9
create_range('>0.9.0', '<0.10.0'),
['0.9.0', '0.9.1', '0.10.0']),
'0.9.1')
create_range(">0.9.0", "<0.10.0"),
["0.9.0", "0.9.1", "0.10.0"],
),
"0.9.1",
)
def test__filter_installable(self):
installable = filter_installable(
['1.1.0', '1.2.0a1', '1.0.0','2.1.0-alpha','2.2.0asdf','2.1.0','2.2.0','2.2.0-fishtown-beta','2.2.0-2'],
install_prerelease=True
[
"1.1.0",
"1.2.0a1",
"1.0.0",
"2.1.0-alpha",
"2.2.0asdf",
"2.1.0",
"2.2.0",
"2.2.0-fishtown-beta",
"2.2.0-2",
],
install_prerelease=True,
)
expected = ['1.0.0', '1.1.0', '1.2.0a1','2.1.0-alpha','2.1.0','2.2.0-2','2.2.0asdf','2.2.0-fishtown-beta','2.2.0']
expected = [
"1.0.0",
"1.1.0",
"1.2.0a1",
"2.1.0-alpha",
"2.1.0",
"2.2.0-2",
"2.2.0asdf",
"2.2.0-fishtown-beta",
"2.2.0",
]
assert installable == expected
installable = filter_installable(
['1.1.0', '1.2.0a1', '1.0.0','2.1.0-alpha','2.2.0asdf','2.1.0','2.2.0','2.2.0-fishtown-beta'],
install_prerelease=False
[
"1.1.0",
"1.2.0a1",
"1.0.0",
"2.1.0-alpha",
"2.2.0asdf",
"2.1.0",
"2.2.0",
"2.2.0-fishtown-beta",
],
install_prerelease=False,
)
expected = ['1.0.0', '1.1.0','2.1.0','2.2.0']
expected = ["1.0.0", "1.1.0", "2.1.0", "2.2.0"]
assert installable == expected

View File

@@ -1,18 +1,19 @@
import unittest
from dbt.adapters.sql.connections import SQLConnectionManager
class TestProcessSQLResult(unittest.TestCase):
def test_duplicated_columns(self):
cols_with_one_dupe = ['a', 'b', 'a', 'd']
rows = [(1, 2, 3, 4)]
self.assertEqual(
SQLConnectionManager.process_results(cols_with_one_dupe, rows),
[{"a": 1, "b": 2, "a_2": 3, "d": 4}]
)
cols_with_more_dupes = ['a', 'a', 'a', 'b']
rows = [(1, 2, 3, 4)]
self.assertEqual(
SQLConnectionManager.process_results(cols_with_more_dupes, rows),
[{"a": 1, "a_2": 2, "a_3": 3, "b": 4}]
)
class TestProcessSQLResult(unittest.TestCase):
def test_duplicated_columns(self):
cols_with_one_dupe = ["a", "b", "a", "d"]
rows = [(1, 2, 3, 4)]
self.assertEqual(
SQLConnectionManager.process_results(cols_with_one_dupe, rows),
[{"a": 1, "b": 2, "a_2": 3, "d": 4}],
)
cols_with_more_dupes = ["a", "a", "a", "b"]
rows = [(1, 2, 3, 4)]
self.assertEqual(
SQLConnectionManager.process_results(cols_with_more_dupes, rows),
[{"a": 1, "a_2": 2, "a_3": 3, "b": 4}],
)

View File

@@ -3,7 +3,6 @@ import shutil
import stat
import unittest
import tarfile
import io
import pathspec
from pathlib import Path
from tempfile import mkdtemp, NamedTemporaryFile
@@ -12,46 +11,47 @@ from dbt.exceptions import ExecutableError, WorkingDirectoryError
import dbt.clients.system
class SystemClient(unittest.TestCase):
def setUp(self):
super().setUp()
self.tmp_dir = mkdtemp()
self.profiles_path = '{}/profiles.yml'.format(self.tmp_dir)
self.profiles_path = "{}/profiles.yml".format(self.tmp_dir)
def set_up_profile(self):
with open(self.profiles_path, 'w') as f:
f.write('ORIGINAL_TEXT')
with open(self.profiles_path, "w") as f:
f.write("ORIGINAL_TEXT")
def get_profile_text(self):
with open(self.profiles_path, 'r') as f:
with open(self.profiles_path, "r") as f:
return f.read()
def tearDown(self):
try:
shutil.rmtree(self.tmp_dir)
except:
except Exception as e: # noqa: [F841]
pass
def test__make_file_when_exists(self):
self.set_up_profile()
written = dbt.clients.system.make_file(self.profiles_path, contents='NEW_TEXT')
written = dbt.clients.system.make_file(self.profiles_path, contents="NEW_TEXT")
self.assertFalse(written)
self.assertEqual(self.get_profile_text(), 'ORIGINAL_TEXT')
self.assertEqual(self.get_profile_text(), "ORIGINAL_TEXT")
def test__make_file_when_not_exists(self):
written = dbt.clients.system.make_file(self.profiles_path, contents='NEW_TEXT')
written = dbt.clients.system.make_file(self.profiles_path, contents="NEW_TEXT")
self.assertTrue(written)
self.assertEqual(self.get_profile_text(), 'NEW_TEXT')
self.assertEqual(self.get_profile_text(), "NEW_TEXT")
def test__make_file_with_overwrite(self):
self.set_up_profile()
written = dbt.clients.system.make_file(self.profiles_path, contents='NEW_TEXT', overwrite=True)
written = dbt.clients.system.make_file(
self.profiles_path, contents="NEW_TEXT", overwrite=True
)
self.assertTrue(written)
self.assertEqual(self.get_profile_text(), 'NEW_TEXT')
self.assertEqual(self.get_profile_text(), "NEW_TEXT")
def test__make_dir_from_str(self):
test_dir_str = self.tmp_dir + "/test_make_from_str/sub_dir"
@@ -62,7 +62,6 @@ class SystemClient(unittest.TestCase):
test_dir_pathobj = Path(self.tmp_dir + "/test_make_from_pathobj/sub_dir")
dbt.clients.system.make_directory(test_dir_pathobj)
self.assertTrue(test_dir_pathobj.is_dir())
class TestRunCmd(unittest.TestCase):
@@ -70,20 +69,21 @@ class TestRunCmd(unittest.TestCase):
Don't mock out subprocess, in order to expose any OS-level differences.
"""
not_a_file = 'zzzbbfasdfasdfsdaq'
not_a_file = "zzzbbfasdfasdfsdaq"
def setUp(self):
self.tempdir = mkdtemp()
self.run_dir = os.path.join(self.tempdir, 'run_dir')
self.does_not_exist = os.path.join(self.tempdir, 'does_not_exist')
self.empty_file = os.path.join(self.tempdir, 'empty_file')
if os.name == 'nt':
self.exists_cmd = ['cmd', '/C', 'echo', 'hello']
self.run_dir = os.path.join(self.tempdir, "run_dir")
self.does_not_exist = os.path.join(self.tempdir, "does_not_exist")
self.empty_file = os.path.join(self.tempdir, "empty_file")
if os.name == "nt":
self.exists_cmd = ["cmd", "/C", "echo", "hello"]
else:
self.exists_cmd = ['echo', 'hello']
self.exists_cmd = ["echo", "hello"]
os.mkdir(self.run_dir)
with open(self.empty_file, 'w') as fp:
with open(self.empty_file, "w") as fp: # noqa: [F841]
pass # "touch"
def tearDown(self):
@@ -95,8 +95,8 @@ class TestRunCmd(unittest.TestCase):
msg = str(exc.exception).lower()
self.assertIn('path', msg)
self.assertIn('could not find', msg)
self.assertIn("path", msg)
self.assertIn("could not find", msg)
self.assertIn(self.does_not_exist.lower(), msg)
def test__not_exe(self):
@@ -104,19 +104,19 @@ class TestRunCmd(unittest.TestCase):
dbt.clients.system.run_cmd(self.run_dir, [self.empty_file])
msg = str(exc.exception).lower()
if os.name == 'nt':
if os.name == "nt":
# on windows, this means it's not an executable at all!
self.assertIn('not executable', msg)
self.assertIn("not executable", msg)
else:
# on linux, this means you don't have executable permissions on it
self.assertIn('permissions', msg)
self.assertIn("permissions", msg)
self.assertIn(self.empty_file.lower(), msg)
def test__cwd_does_not_exist(self):
with self.assertRaises(WorkingDirectoryError) as exc:
dbt.clients.system.run_cmd(self.does_not_exist, self.exists_cmd)
msg = str(exc.exception).lower()
self.assertIn('does not exist', msg)
self.assertIn("does not exist", msg)
self.assertIn(self.does_not_exist.lower(), msg)
def test__cwd_not_directory(self):
@@ -124,7 +124,7 @@ class TestRunCmd(unittest.TestCase):
dbt.clients.system.run_cmd(self.empty_file, self.exists_cmd)
msg = str(exc.exception).lower()
self.assertIn('not a directory', msg)
self.assertIn("not a directory", msg)
self.assertIn(self.empty_file.lower(), msg)
def test__cwd_no_permissions(self):
@@ -132,7 +132,7 @@ class TestRunCmd(unittest.TestCase):
# `psexec` (to get SYSTEM privs), use `icacls` to set permissions on
# the directory for the test user. I'm pretty sure windows users can't
# create files that they themselves cannot access.
if os.name == 'nt':
if os.name == "nt":
return
# read-only -> cannot cd to it
@@ -142,114 +142,100 @@ class TestRunCmd(unittest.TestCase):
dbt.clients.system.run_cmd(self.run_dir, self.exists_cmd)
msg = str(exc.exception).lower()
self.assertIn('permissions', msg)
self.assertIn("permissions", msg)
self.assertIn(self.run_dir.lower(), msg)
def test__ok(self):
out, err = dbt.clients.system.run_cmd(self.run_dir, self.exists_cmd)
self.assertEqual(out.strip(), b'hello')
self.assertEqual(err.strip(), b'')
self.assertEqual(out.strip(), b"hello")
self.assertEqual(err.strip(), b"")
class TestFindMatching(unittest.TestCase):
def setUp(self):
self.base_dir = mkdtemp()
self.tempdir = mkdtemp(dir=self.base_dir)
def test_find_matching_lowercase_file_pattern(self):
with NamedTemporaryFile(
prefix='sql-files', suffix='.sql', dir=self.tempdir
) as named_file:
with NamedTemporaryFile(prefix="sql-files", suffix=".sql", dir=self.tempdir) as named_file:
file_path = os.path.dirname(named_file.name)
relative_path = os.path.basename(file_path)
out = dbt.clients.system.find_matching(
self.base_dir,
[relative_path],
'*.sql',
"*.sql",
)
expected_output = [{
'searched_path': relative_path,
'absolute_path': named_file.name,
'relative_path': os.path.basename(named_file.name),
'modification_time': out[0]['modification_time'],
}]
expected_output = [
{
"searched_path": relative_path,
"absolute_path": named_file.name,
"relative_path": os.path.basename(named_file.name),
"modification_time": out[0]["modification_time"],
}
]
self.assertEqual(out, expected_output)
def test_find_matching_uppercase_file_pattern(self):
with NamedTemporaryFile(prefix='sql-files', suffix='.SQL', dir=self.tempdir) as named_file:
with NamedTemporaryFile(prefix="sql-files", suffix=".SQL", dir=self.tempdir) as named_file:
file_path = os.path.dirname(named_file.name)
relative_path = os.path.basename(file_path)
out = dbt.clients.system.find_matching(
self.base_dir,
[relative_path],
'*.sql'
)
expected_output = [{
'searched_path': relative_path,
'absolute_path': named_file.name,
'relative_path': os.path.basename(named_file.name),
'modification_time': out[0]['modification_time'],
}]
out = dbt.clients.system.find_matching(self.base_dir, [relative_path], "*.sql")
expected_output = [
{
"searched_path": relative_path,
"absolute_path": named_file.name,
"relative_path": os.path.basename(named_file.name),
"modification_time": out[0]["modification_time"],
}
]
self.assertEqual(out, expected_output)
def test_find_matching_file_pattern_not_found(self):
with NamedTemporaryFile(
prefix='sql-files', suffix='.SQLT', dir=self.tempdir
):
out = dbt.clients.system.find_matching(
self.tempdir,
[''],
'*.sql'
)
with NamedTemporaryFile(prefix="sql-files", suffix=".SQLT", dir=self.tempdir):
out = dbt.clients.system.find_matching(self.tempdir, [""], "*.sql")
self.assertEqual(out, [])
def test_ignore_spec(self):
with NamedTemporaryFile(
prefix='sql-files', suffix='.sql', dir=self.tempdir
):
with NamedTemporaryFile(prefix="sql-files", suffix=".sql", dir=self.tempdir):
out = dbt.clients.system.find_matching(
self.tempdir,
[''],
'*.sql',
[""],
"*.sql",
pathspec.PathSpec.from_lines(
pathspec.patterns.GitWildMatchPattern, "sql-files*".splitlines()
)
),
)
self.assertEqual(out, [])
def tearDown(self):
try:
shutil.rmtree(self.base_dir)
except:
except Exception as e: # noqa: [F841]
pass
class TestUntarPackage(unittest.TestCase):
def setUp(self):
self.base_dir = mkdtemp()
self.tempdir = mkdtemp(dir=self.base_dir)
self.tempdest = mkdtemp(dir=self.base_dir)
def tearDown(self):
try:
shutil.rmtree(self.base_dir)
except:
except Exception as e: # noqa: [F841]
pass
def test_untar_package_success(self):
# set up a valid tarball to test against
with NamedTemporaryFile(
prefix='my-package.2', suffix='.tar.gz', dir=self.tempdir, delete=False
prefix="my-package.2", suffix=".tar.gz", dir=self.tempdir, delete=False
) as named_tar_file:
tar_file_full_path = named_tar_file.name
with NamedTemporaryFile(
prefix='a', suffix='.txt', dir=self.tempdir
) as file_a:
file_a.write(b'some text in the text file')
with NamedTemporaryFile(prefix="a", suffix=".txt", dir=self.tempdir) as file_a:
file_a.write(b"some text in the text file")
relative_file_a = os.path.basename(file_a.name)
with tarfile.open(fileobj=named_tar_file, mode='w:gz') as tar:
with tarfile.open(fileobj=named_tar_file, mode="w:gz") as tar:
tar.addfile(tarfile.TarInfo(relative_file_a), open(file_a.name))
# now we test can test that we can untar the file successfully
@@ -261,22 +247,22 @@ class TestUntarPackage(unittest.TestCase):
def test_untar_package_failure(self):
# create a text file then rename it as a tar (so it's invalid)
with NamedTemporaryFile(
prefix='a', suffix='.txt', dir=self.tempdir, delete=False
) as file_a:
file_a.write(b'some text in the text file')
txt_file_name = file_a.name
file_path= os.path.dirname(txt_file_name)
tar_file_path = os.path.join(file_path, 'mypackage.2.tar.gz')
prefix="a", suffix=".txt", dir=self.tempdir, delete=False
) as file_a:
file_a.write(b"some text in the text file")
txt_file_name = file_a.name
file_path = os.path.dirname(txt_file_name)
tar_file_path = os.path.join(file_path, "mypackage.2.tar.gz")
os.rename(txt_file_name, tar_file_path)
# now that we're set up, test that untarring the file fails
with self.assertRaises(tarfile.ReadError) as exc:
with self.assertRaises(tarfile.ReadError) as exc: # noqa: [F841]
dbt.clients.system.untar_package(tar_file_path, self.tempdest)
def test_untar_package_empty(self):
# create a tarball with nothing in it
with NamedTemporaryFile(
prefix='my-empty-package.2', suffix='.tar.gz', dir=self.tempdir
prefix="my-empty-package.2", suffix=".tar.gz", dir=self.tempdir
) as named_file:
# make sure we throw an error for the empty file

View File

@@ -3,7 +3,7 @@ import datetime
import shutil
import tempfile
import unittest
from unittest.mock import MagicMock
class TestTracking(unittest.TestCase):
def setUp(self):
@@ -16,10 +16,7 @@ class TestTracking(unittest.TestCase):
def test_tracking_initial(self):
assert dbt.tracking.active_user is None
dbt.tracking.initialize_from_flags(
True,
self.tempdir
)
dbt.tracking.initialize_from_flags(True, self.tempdir)
assert isinstance(dbt.tracking.active_user, dbt.tracking.User)
invocation_id = dbt.tracking.active_user.invocation_id
@@ -77,13 +74,8 @@ class TestTracking(unittest.TestCase):
def test_initialize_from_flags(self):
for send_anonymous_usage_stats in [True, False]:
with self.subTest(
send_anonymous_usage_stats=send_anonymous_usage_stats
):
with self.subTest(send_anonymous_usage_stats=send_anonymous_usage_stats):
dbt.tracking.initialize_from_flags(
send_anonymous_usage_stats,
self.tempdir
)
dbt.tracking.initialize_from_flags(send_anonymous_usage_stats, self.tempdir)
assert dbt.tracking.active_user.do_not_track != send_anonymous_usage_stats

View File

@@ -5,60 +5,64 @@ import dbt.utils
class TestDeepMerge(unittest.TestCase):
def test__simple_cases(self):
cases = [
{'args': [{}, {'a': 1}],
'expected': {'a': 1},
'description': 'one key into empty'},
{'args': [{}, {'b': 1}, {'a': 1}],
'expected': {'a': 1, 'b': 1},
'description': 'three merges'},
{"args": [{}, {"a": 1}], "expected": {"a": 1}, "description": "one key into empty"},
{
"args": [{}, {"b": 1}, {"a": 1}],
"expected": {"a": 1, "b": 1},
"description": "three merges",
},
]
for case in cases:
actual = dbt.utils.deep_merge(*case['args'])
actual = dbt.utils.deep_merge(*case["args"])
self.assertEqual(
case['expected'], actual,
'failed on {} (actual {}, expected {})'.format(
case['description'], actual, case['expected']))
case["expected"],
actual,
"failed on {} (actual {}, expected {})".format(
case["description"], actual, case["expected"]
),
)
class TestMerge(unittest.TestCase):
def test__simple_cases(self):
cases = [
{'args': [{}, {'a': 1}],
'expected': {'a': 1},
'description': 'one key into empty'},
{'args': [{}, {'b': 1}, {'a': 1}],
'expected': {'a': 1, 'b': 1},
'description': 'three merges'},
{"args": [{}, {"a": 1}], "expected": {"a": 1}, "description": "one key into empty"},
{
"args": [{}, {"b": 1}, {"a": 1}],
"expected": {"a": 1, "b": 1},
"description": "three merges",
},
]
for case in cases:
actual = dbt.utils.deep_merge(*case['args'])
actual = dbt.utils.deep_merge(*case["args"])
self.assertEqual(
case['expected'], actual,
'failed on {} (actual {}, expected {})'.format(
case['description'], actual, case['expected']))
case["expected"],
actual,
"failed on {} (actual {}, expected {})".format(
case["description"], actual, case["expected"]
),
)
class TestDeepMap(unittest.TestCase):
def setUp(self):
self.input_value = {
'foo': {
'bar': 'hello',
'baz': [1, 90.5, '990', '89.9'],
"foo": {
"bar": "hello",
"baz": [1, 90.5, "990", "89.9"],
},
'nested': [
"nested": [
{
'test': '90',
'other_test': None,
"test": "90",
"other_test": None,
},
{
'test': 400,
'other_test': 4.7e9,
"test": 400,
"other_test": 4.7e9,
},
],
}
@@ -72,18 +76,18 @@ class TestDeepMap(unittest.TestCase):
def test__simple_cases(self):
expected = {
'foo': {
'bar': -1,
'baz': [1, 90, 990, -1],
"foo": {
"bar": -1,
"baz": [1, 90, 990, -1],
},
'nested': [
"nested": [
{
'test': 90,
'other_test': -1,
"test": 90,
"other_test": -1,
},
{
'test': 400,
'other_test': 4700000000,
"test": 400,
"other_test": 4700000000,
},
],
}
@@ -96,26 +100,26 @@ class TestDeepMap(unittest.TestCase):
@staticmethod
def special_keypath(value, keypath):
if tuple(keypath) == ('foo', 'baz', 1):
return 'hello'
if tuple(keypath) == ("foo", "baz", 1):
return "hello"
else:
return value
def test__keypath(self):
expected = {
'foo': {
'bar': 'hello',
"foo": {
"bar": "hello",
# the only change from input is the second entry here
'baz': [1, 'hello', '990', '89.9'],
"baz": [1, "hello", "990", "89.9"],
},
'nested': [
"nested": [
{
'test': '90',
'other_test': None,
"test": "90",
"other_test": None,
},
{
'test': 400,
'other_test': 4.7e9,
"test": 400,
"other_test": 4.7e9,
},
],
}
@@ -130,52 +134,53 @@ class TestDeepMap(unittest.TestCase):
self.assertEqual(actual, self.input_value)
def test_trivial(self):
cases = [[], {}, 1, 'abc', None, True]
cases = [[], {}, 1, "abc", None, True]
for case in cases:
result = dbt.utils.deep_map_render(lambda x, _: x, case)
self.assertEqual(result, case)
with self.assertRaises(dbt.exceptions.DbtConfigError):
dbt.utils.deep_map_render(lambda x, _: x, {'foo': object()})
dbt.utils.deep_map_render(lambda x, _: x, {"foo": object()})
class TestMultiDict(unittest.TestCase):
def test_one_member(self):
dct = {'a': 1, 'b': 2, 'c': 3}
dct = {"a": 1, "b": 2, "c": 3}
md = dbt.utils.MultiDict([dct])
assert len(md) == 3
for key in 'abc':
for key in "abc":
assert key in md
assert md['a'] == 1
assert md['b'] == 2
assert md['c'] == 3
assert md["a"] == 1
assert md["b"] == 2
assert md["c"] == 3
def test_two_members_no_overlap(self):
first = {'a': 1, 'b': 2, 'c': 3}
second = {'d': 1, 'e': 2, 'f': 3}
first = {"a": 1, "b": 2, "c": 3}
second = {"d": 1, "e": 2, "f": 3}
md = dbt.utils.MultiDict([first, second])
assert len(md) == 6
for key in 'abcdef':
for key in "abcdef":
assert key in md
assert md['a'] == 1
assert md['b'] == 2
assert md['c'] == 3
assert md['d'] == 1
assert md['e'] == 2
assert md['f'] == 3
assert md["a"] == 1
assert md["b"] == 2
assert md["c"] == 3
assert md["d"] == 1
assert md["e"] == 2
assert md["f"] == 3
def test_two_members_overlap(self):
first = {'a': 1, 'b': 2, 'c': 3}
second = {'c': 1, 'd': 2, 'e': 3}
first = {"a": 1, "b": 2, "c": 3}
second = {"c": 1, "d": 2, "e": 3}
md = dbt.utils.MultiDict([first, second])
assert len(md) == 5
for key in 'abcde':
for key in "abcde":
assert key in md
assert md['a'] == 1
assert md['b'] == 2
assert md['c'] == 1
assert md['d'] == 2
assert md['e'] == 3
assert md["a"] == 1
assert md["b"] == 2
assert md["c"] == 1
assert md["d"] == 2
assert md["e"] == 3
class TestHumanizeExecutionTime(unittest.TestCase):
def test_humanzing_execution_time_with_integer(self):

View File

@@ -1,19 +1,16 @@
import unittest
import dbt.exceptions
import dbt.utils
from dbt.parser.schema_renderer import SchemaYamlRenderer
class TestYamlRendering(unittest.TestCase):
def test__models(self):
context = {
"test_var": "1234",
"alt_var": "replaced",
}
renderer = SchemaYamlRenderer(context, 'models')
renderer = SchemaYamlRenderer(context, "models")
# Verify description is not rendered and misc attribute is rendered
dct = {
@@ -31,18 +28,18 @@ class TestYamlRendering(unittest.TestCase):
# Verify description in columns is not rendered
dct = {
'name': 'my_test',
'attribute': "{{ test_var }}",
'columns': [
{'description': "{{ test_var }}", 'name': 'id'},
]
"name": "my_test",
"attribute": "{{ test_var }}",
"columns": [
{"description": "{{ test_var }}", "name": "id"},
],
}
expected = {
'name': 'my_test',
'attribute': "1234",
'columns': [
{'description': "{{ test_var }}", 'name': 'id'},
]
"name": "my_test",
"attribute": "1234",
"columns": [
{"description": "{{ test_var }}", "name": "id"},
],
}
dct = renderer.render_data(dct)
self.assertEqual(expected, dct)
@@ -53,7 +50,7 @@ class TestYamlRendering(unittest.TestCase):
"test_var": "1234",
"alt_var": "replaced",
}
renderer = SchemaYamlRenderer(context, 'sources')
renderer = SchemaYamlRenderer(context, "sources")
# Only descriptions have jinja, none should be rendered
dct = {
@@ -68,9 +65,9 @@ class TestYamlRendering(unittest.TestCase):
"name": "id",
"description": "{{ alt_var }}",
}
]
],
}
]
],
}
rendered = renderer.render_data(dct)
self.assertEqual(dct, rendered)
@@ -81,31 +78,29 @@ class TestYamlRendering(unittest.TestCase):
"test_var": "1234",
"alt_var": "replaced",
}
renderer = SchemaYamlRenderer(context, 'macros')
renderer = SchemaYamlRenderer(context, "macros")
# Look for description in arguments
dct = {
"name": "my_macro",
"arguments": [
{"name": "my_arg", "attr": "{{ alt_var }}"},
{"name": "an_arg", "description": "{{ alt_var}}"}
]
{"name": "an_arg", "description": "{{ alt_var}}"},
],
}
expected = {
"name": "my_macro",
"arguments": [
{"name": "my_arg", "attr": "replaced"},
{"name": "an_arg", "description": "{{ alt_var}}"}
]
{"name": "an_arg", "description": "{{ alt_var}}"},
],
}
dct = renderer.render_data(dct)
self.assertEqual(dct, expected)
def test__metrics(self):
context = {
"my_time_grains": "[day]"
}
renderer = SchemaYamlRenderer(context, 'metrics')
context = {"my_time_grains": "[day]"}
renderer = SchemaYamlRenderer(context, "metrics")
dct = {
"name": "my_source",
@@ -119,7 +114,7 @@ class TestYamlRendering(unittest.TestCase):
"name": "my_source",
"description": "{{ docs('my_doc') }}",
"expression": "select {{ var('my_var') }} from my_table",
"time_grains": "[day]"
"time_grains": "[day]",
}
dct = renderer.render_data(dct)
self.assertEqual(dct, expected)

View File

@@ -28,22 +28,22 @@ def normalize(path):
class Obj:
which = 'blah'
which = "blah"
single_threaded = False
def mock_connection(name, state='open'):
def mock_connection(name, state="open"):
conn = mock.MagicMock()
conn.name = name
conn.state = state
return conn
def profile_from_dict(profile, profile_name, cli_vars='{}'):
def profile_from_dict(profile, profile_name, cli_vars="{}"):
from dbt.config import Profile
from dbt.config.renderer import ProfileRenderer
from dbt.context.base import generate_base_context
from dbt.config.utils import parse_cli_vars
if not isinstance(cli_vars, dict):
cli_vars = parse_cli_vars(cli_vars)
@@ -55,15 +55,16 @@ def profile_from_dict(profile, profile_name, cli_vars='{}'):
)
def project_from_dict(project, profile, packages=None, selectors=None, cli_vars='{}'):
def project_from_dict(project, profile, packages=None, selectors=None, cli_vars="{}"):
from dbt.config.renderer import DbtProjectYamlRenderer
from dbt.config.utils import parse_cli_vars
if not isinstance(cli_vars, dict):
cli_vars = parse_cli_vars(cli_vars)
renderer = DbtProjectYamlRenderer(profile, cli_vars)
project_root = project.pop('project-root', os.getcwd())
project_root = project.pop("project-root", os.getcwd())
partial = PartialProject.from_dicts(
project_root=project_root,
@@ -74,7 +75,6 @@ def project_from_dict(project, profile, packages=None, selectors=None, cli_vars=
return partial.render(renderer)
def config_from_parts_or_dicts(project, profile, packages=None, selectors=None, cli_vars={}):
from dbt.config import Project, Profile, RuntimeConfig
from copy import deepcopy
@@ -82,7 +82,7 @@ def config_from_parts_or_dicts(project, profile, packages=None, selectors=None,
if isinstance(project, Project):
profile_name = project.profile_name
else:
profile_name = project.get('profile')
profile_name = project.get("profile")
if not isinstance(profile, Profile):
profile = profile_from_dict(
@@ -102,16 +102,13 @@ def config_from_parts_or_dicts(project, profile, packages=None, selectors=None,
args = Obj()
args.vars = cli_vars
args.profile_dir = '/dev/null'
return RuntimeConfig.from_parts(
project=project,
profile=profile,
args=args
)
args.profile_dir = "/dev/null"
return RuntimeConfig.from_parts(project=project, profile=profile, args=args)
def inject_plugin(plugin):
from dbt.adapters.factory import FACTORY
key = plugin.adapter.type()
FACTORY.plugins[key] = plugin
@@ -119,6 +116,7 @@ def inject_plugin(plugin):
def inject_plugin_for(config):
# from dbt.adapters.postgres import Plugin, PostgresAdapter
from dbt.adapters.factory import FACTORY
FACTORY.load_plugin(config.credentials.type)
adapter = FACTORY.get_adapter(config)
return adapter
@@ -130,12 +128,14 @@ def inject_adapter(value, plugin):
"""
inject_plugin(plugin)
from dbt.adapters.factory import FACTORY
key = value.type()
FACTORY.adapters[key] = value
def clear_plugin(plugin):
from dbt.adapters.factory import FACTORY
key = plugin.adapter.type()
FACTORY.plugins.pop(key, None)
FACTORY.adapters.pop(key, None)
@@ -155,7 +155,7 @@ class ContractTestCase(TestCase):
if cls is None:
cls = self.ContractType
cls.validate(dct)
self.assertEqual(cls.from_dict(dct), obj)
self.assertEqual(cls.from_dict(dct), obj)
def assert_symmetric(self, obj, dct, cls=None):
self.assert_to_dict(obj, dct)
@@ -178,7 +178,7 @@ def compare_dicts(dict1, dict2):
common_keys = set(first_set).intersection(set(second_set))
found_differences = False
for key in common_keys:
if dict1[key] != dict2[key] :
if dict1[key] != dict2[key]:
print(f"--- --- first dict: {key}: {str(dict1[key])}")
print(f"--- --- second dict: {key}: {str(dict2[key])}")
found_differences = True
@@ -195,7 +195,7 @@ def assert_from_dict(obj, dct, cls=None):
obj_from_dict = cls.from_dict(dct)
if hasattr(obj, 'created_at'):
if hasattr(obj, "created_at"):
obj_from_dict.created_at = 1
obj.created_at = 1
@@ -204,10 +204,10 @@ def assert_from_dict(obj, dct, cls=None):
def assert_to_dict(obj, dct):
obj_to_dict = obj.to_dict(omit_none=True)
if 'created_at' in obj_to_dict:
obj_to_dict['created_at'] = 1
if 'created_at' in dct:
dct['created_at'] = 1
if "created_at" in obj_to_dict:
obj_to_dict["created_at"] = 1
if "created_at" in dct:
dct["created_at"] = 1
if obj_to_dict != dct:
compare_dicts(obj_to_dict, dct)
assert obj_to_dict == dct
@@ -227,24 +227,25 @@ def assert_fails_validation(dct, cls):
def generate_name_macros(package):
from dbt.contracts.graph.nodes import Macro
from dbt.node_types import NodeType
name_sql = {}
for component in ('database', 'schema', 'alias'):
if component == 'alias':
source = 'node.name'
for component in ("database", "schema", "alias"):
if component == "alias":
source = "node.name"
else:
source = f'target.{component}'
name = f'generate_{component}_name'
sql = f'{{% macro {name}(value, node) %}} {{% if value %}} {{{{ value }}}} {{% else %}} {{{{ {source} }}}} {{% endif %}} {{% endmacro %}}'
source = f"target.{component}"
name = f"generate_{component}_name"
sql = f"{{% macro {name}(value, node) %}} {{% if value %}} {{{{ value }}}} {{% else %}} {{{{ {source} }}}} {{% endif %}} {{% endmacro %}}"
name_sql[name] = sql
for name, sql in name_sql.items():
pm = Macro(
name=name,
resource_type=NodeType.Macro,
unique_id=f'macro.{package}.{name}',
unique_id=f"macro.{package}.{name}",
package_name=package,
original_file_path=normalize('macros/macro.sql'),
path=normalize('macros/macro.sql'),
original_file_path=normalize("macros/macro.sql"),
path=normalize("macros/macro.sql"),
macro_sql=sql,
)
yield pm
@@ -253,6 +254,7 @@ def generate_name_macros(package):
class TestAdapterConversions(TestCase):
def _get_tester_for(self, column_type):
from dbt.clients import agate_helper
if column_type is agate.TimeDelta: # dbt never makes this!
return agate.TimeDelta()
@@ -260,10 +262,10 @@ class TestAdapterConversions(TestCase):
if isinstance(instance, column_type): # include child types
return instance
raise ValueError(f'no tester for {column_type}')
raise ValueError(f"no tester for {column_type}")
def _make_table_of(self, rows, column_types):
column_names = list(string.ascii_letters[:len(rows[0])])
column_names = list(string.ascii_letters[: len(rows[0])])
if isinstance(column_types, type):
column_types = [self._get_tester_for(column_types) for _ in column_names]
else:
@@ -272,50 +274,48 @@ class TestAdapterConversions(TestCase):
return table
def MockMacro(package, name='my_macro', **kwargs):
def MockMacro(package, name="my_macro", **kwargs):
from dbt.contracts.graph.nodes import Macro
from dbt.node_types import NodeType
mock_kwargs = dict(
resource_type=NodeType.Macro,
package_name=package,
unique_id=f'macro.{package}.{name}',
original_file_path='/dev/null',
unique_id=f"macro.{package}.{name}",
original_file_path="/dev/null",
)
mock_kwargs.update(kwargs)
macro = mock.MagicMock(
spec=Macro,
**mock_kwargs
)
macro = mock.MagicMock(spec=Macro, **mock_kwargs)
macro.name = name
return macro
def MockMaterialization(package, name='my_materialization', adapter_type=None, **kwargs):
def MockMaterialization(package, name="my_materialization", adapter_type=None, **kwargs):
if adapter_type is None:
adapter_type = 'default'
kwargs['adapter_type'] = adapter_type
return MockMacro(package, f'materialization_{name}_{adapter_type}', **kwargs)
adapter_type = "default"
kwargs["adapter_type"] = adapter_type
return MockMacro(package, f"materialization_{name}_{adapter_type}", **kwargs)
def MockGenerateMacro(package, component='some_component', **kwargs):
name = f'generate_{component}_name'
def MockGenerateMacro(package, component="some_component", **kwargs):
name = f"generate_{component}_name"
return MockMacro(package, name=name, **kwargs)
def MockSource(package, source_name, name, **kwargs):
from dbt.node_types import NodeType
from dbt.contracts.graph.nodes import SourceDefinition
src = mock.MagicMock(
__class__=SourceDefinition,
resource_type=NodeType.Source,
source_name=source_name,
package_name=package,
unique_id=f'source.{package}.{source_name}.{name}',
search_name=f'{source_name}.{name}',
**kwargs
unique_id=f"source.{package}.{source_name}.{name}",
search_name=f"{source_name}.{name}",
**kwargs,
)
src.name = name
return src
@@ -324,6 +324,7 @@ def MockSource(package, source_name, name, **kwargs):
def MockNode(package, name, resource_type=None, **kwargs):
from dbt.node_types import NodeType
from dbt.contracts.graph.nodes import ModelNode, SeedNode
if resource_type is None:
resource_type = NodeType.Model
if resource_type == NodeType.Model:
@@ -331,14 +332,14 @@ def MockNode(package, name, resource_type=None, **kwargs):
elif resource_type == NodeType.Seed:
cls = SeedNode
else:
raise ValueError(f'I do not know how to handle {resource_type}')
raise ValueError(f"I do not know how to handle {resource_type}")
node = mock.MagicMock(
__class__=cls,
resource_type=resource_type,
package_name=package,
unique_id=f'{str(resource_type)}.{package}.{name}',
unique_id=f"{str(resource_type)}.{package}.{name}",
search_name=name,
**kwargs
**kwargs,
)
node.name = name
return node
@@ -347,22 +348,23 @@ def MockNode(package, name, resource_type=None, **kwargs):
def MockDocumentation(package, name, **kwargs):
from dbt.node_types import NodeType
from dbt.contracts.graph.nodes import Documentation
doc = mock.MagicMock(
__class__=Documentation,
resource_type=NodeType.Documentation,
package_name=package,
search_name=name,
unique_id=f'{package}.{name}',
**kwargs
unique_id=f"{package}.{name}",
**kwargs,
)
doc.name = name
return doc
def load_internal_manifest_macros(config, macro_hook = lambda m: None):
def load_internal_manifest_macros(config, macro_hook=lambda m: None):
from dbt.parser.manifest import ManifestLoader
return ManifestLoader.load_macros(config, macro_hook)
return ManifestLoader.load_macros(config, macro_hook)
def dict_replace(dct, **kwargs):
@@ -376,4 +378,3 @@ def replace_config(n, **kwargs):
config=n.config.replace(**kwargs),
unrendered_config=dict_replace(n.unrendered_config, **kwargs),
)