mirror of https://github.com/dbt-labs/dbt-core
synced 2025-12-20 15:41:27 +00:00

Compare commits: adding-sem ... ct-236-ada (3 commits)

| Author | SHA1 | Date |
|---|---|---|
| | b5c79f9e09 | |
| | 424024ea9e | |
| | 0c71b44e6b | |
.changes/unreleased/Features-20220309-142645.yaml (new file, 7 lines)

@@ -0,0 +1,7 @@
kind: Features
body: Testing framework for dbt adapter testing
time: 2022-03-09T14:26:45.828295-05:00
custom:
  Author: gshank
  Issue: "4730"
  PR: "4846"
@@ -218,3 +218,29 @@ class SQLAdapter(BaseAdapter):
         kwargs = {"information_schema": information_schema, "schema": schema}
         results = self.execute_macro(CHECK_SCHEMA_EXISTS_MACRO_NAME, kwargs=kwargs)
         return results[0][0] > 0
+
+    # This is for use in the test suite
+    def run_sql_for_tests(self, sql, fetch, conn):
+        cursor = conn.handle.cursor()
+        try:
+            cursor.execute(sql)
+            if hasattr(conn.handle, "commit"):
+                conn.handle.commit()
+            if fetch == "one":
+                if hasattr(cursor, "fetchone"):
+                    return cursor.fetchone()
+                else:
+                    # for spark
+                    return cursor.fetchall()[0]
+            elif fetch == "all":
+                return cursor.fetchall()
+            else:
+                return
+        except BaseException as e:
+            if conn.handle and not getattr(conn.handle, "closed", True):
+                conn.handle.rollback()
+            print(sql)
+            print(e)
+            raise
+        finally:
+            conn.transaction_open = False
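The point of routing through a method on the adapter class is that run_sql_with_adapter (changed later in this diff) can delegate to adapter.run_sql_for_tests, so a plugin whose driver behaves differently only overrides this one hook. A minimal sketch of such an override, assuming a hypothetical plugin adapter (the MyAdapter name and the autocommit behavior are illustrative, not part of this diff):

    # Hypothetical adapter plugin overriding the test-suite hook added above.
    from dbt.adapters.sql import SQLAdapter


    class MyAdapter(SQLAdapter):
        def run_sql_for_tests(self, sql, fetch, conn):
            # This driver autocommits and has no transaction state to reset,
            # so the base class's commit/rollback handling is unnecessary.
            cursor = conn.handle.cursor()
            cursor.execute(sql)
            if fetch == "one":
                return cursor.fetchone()
            elif fetch == "all":
                return cursor.fetchall()
            return None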
@@ -1 +0,0 @@
# dbt.tests directory
@@ -1,212 +0,0 @@
from itertools import chain, repeat
from dbt.context import providers
from unittest.mock import patch

# These functions were extracted from the dbt-adapter-tests spec_file.py.
# They are used in the 'adapter' tests directory. At some point they
# might be moved to dbt.tests.util if they are of general purpose use,
# but leaving here for now to keep the adapter work more contained.
# We may want to consolidate in the future since some of this is kind
# of duplicative of the functionality in dbt.tests.tables.


class TestProcessingException(Exception):
    pass


def relation_from_name(adapter, name: str):
    """reverse-engineer a relation (including quoting) from a given name and
    the adapter. Assumes that relations are split by the '.' character.
    """

    # Different adapters have different Relation classes
    cls = adapter.Relation
    credentials = adapter.config.credentials
    quote_policy = cls.get_default_quote_policy().to_dict()
    include_policy = cls.get_default_include_policy().to_dict()
    kwargs = {}  # This will contain database, schema, identifier

    parts = name.split(".")
    names = ["database", "schema", "identifier"]
    defaults = [credentials.database, credentials.schema, None]
    values = chain(repeat(None, 3 - len(parts)), parts)
    for name, value, default in zip(names, values, defaults):
        # no quote policy -> use the default
        if value is None:
            if default is None:
                include_policy[name] = False
            value = default
        else:
            include_policy[name] = True
            # if we have a value, we can figure out the quote policy.
            trimmed = value[1:-1]
            if adapter.quote(trimmed) == value:
                quote_policy[name] = True
                value = trimmed
            else:
                quote_policy[name] = False
        kwargs[name] = value

    relation = cls.create(
        include_policy=include_policy,
        quote_policy=quote_policy,
        **kwargs,
    )
    return relation


def check_relation_types(adapter, relation_to_type):
    """
    Relation name to table/view
    {
        "base": "table",
        "other": "view",
    }
    """

    expected_relation_values = {}
    found_relations = []
    schemas = set()

    for key, value in relation_to_type.items():
        relation = relation_from_name(adapter, key)
        expected_relation_values[relation] = value
        schemas.add(relation.without_identifier())

    with patch.object(providers, "get_adapter", return_value=adapter):
        with adapter.connection_named("__test"):
            for schema in schemas:
                found_relations.extend(adapter.list_relations_without_caching(schema))

    for key, value in relation_to_type.items():
        for relation in found_relations:
            # this might be too broad
            if relation.identifier == key:
                assert relation.type == value, (
                    f"Got an unexpected relation type of {relation.type} "
                    f"for relation {key}, expected {value}"
                )


def check_relations_equal(adapter, relation_names):
    if len(relation_names) < 2:
        raise TestProcessingException(
            "Not enough relations to compare",
        )
    relations = [relation_from_name(adapter, name) for name in relation_names]

    with patch.object(providers, "get_adapter", return_value=adapter):
        with adapter.connection_named("_test"):
            basis, compares = relations[0], relations[1:]
            columns = [c.name for c in adapter.get_columns_in_relation(basis)]

            for relation in compares:
                sql = adapter.get_rows_different_sql(basis, relation, column_names=columns)
                _, tbl = adapter.execute(sql, fetch=True)
                num_rows = len(tbl)
                assert (
                    num_rows == 1
                ), f"Invalid sql query from get_rows_different_sql: incorrect number of rows ({num_rows})"
                num_cols = len(tbl[0])
                assert (
                    num_cols == 2
                ), f"Invalid sql query from get_rows_different_sql: incorrect number of cols ({num_cols})"
                row_count_difference = tbl[0][0]
                assert (
                    row_count_difference == 0
                ), f"Got {row_count_difference} difference in row count between {basis} and {relation}"
                rows_mismatched = tbl[0][1]
                assert (
                    rows_mismatched == 0
                ), f"Got {rows_mismatched} different rows between {basis} and {relation}"


def get_unique_ids_in_results(results):
    unique_ids = []
    for result in results:
        unique_ids.append(result.node.unique_id)
    return unique_ids


def check_result_nodes_by_name(results, names):
    result_names = []
    for result in results:
        result_names.append(result.node.name)
    assert set(names) == set(result_names)


def check_result_nodes_by_unique_id(results, unique_ids):
    result_unique_ids = []
    for result in results:
        result_unique_ids.append(result.node.unique_id)
    assert set(unique_ids) == set(result_unique_ids)


def update_rows(adapter, update_rows_config):
    """
    {
        "name": "base",
        "dst_col": "some_date",
        "clause": {
            "type": "add_timestamp",
            "src_col": "some_date",
        },
        "where": "id > 10",
    }
    """
    for key in ["name", "dst_col", "clause"]:
        if key not in update_rows_config:
            raise TestProcessingException(f"Invalid update_rows: no {key}")

    clause = update_rows_config["clause"]
    clause = generate_update_clause(adapter, clause)

    where = None
    if "where" in update_rows_config:
        where = update_rows_config["where"]

    name = update_rows_config["name"]
    dst_col = update_rows_config["dst_col"]
    relation = relation_from_name(adapter, name)

    with patch.object(providers, "get_adapter", return_value=adapter):
        with adapter.connection_named("_test"):
            sql = adapter.update_column_sql(
                dst_name=str(relation),
                dst_column=dst_col,
                clause=clause,
                where_clause=where,
            )
            print(f"--- update_rows sql: {sql}")
            adapter.execute(sql, auto_begin=True)
            adapter.commit_if_has_connection()


def generate_update_clause(adapter, clause) -> str:
    """
    Called by the update_rows function. Expects the "clause" dictionary
    documented in 'update_rows'.
    """

    if "type" not in clause or clause["type"] not in ["add_timestamp", "add_string"]:
        raise TestProcessingException("invalid update_rows clause: type missing or incorrect")
    clause_type = clause["type"]

    if clause_type == "add_timestamp":
        if "src_col" not in clause:
            raise TestProcessingException("Invalid update_rows clause: no src_col")
        add_to = clause["src_col"]
        kwargs = {k: v for k, v in clause.items() if k in ("interval", "number")}
        with patch.object(providers, "get_adapter", return_value=adapter):
            with adapter.connection_named("_test"):
                return adapter.timestamp_add_sql(add_to=add_to, **kwargs)
    elif clause_type == "add_string":
        for key in ["src_col", "value"]:
            if key not in clause:
                raise TestProcessingException(f"Invalid update_rows clause: no {key}")
        src_col = clause["src_col"]
        value = clause["value"]
        location = clause.get("location", "append")
        with patch.object(providers, "get_adapter", return_value=adapter):
            with adapter.connection_named("_test"):
                return adapter.string_add_sql(src_col, value, location)
    return ""
core/dbt/tests/fixtures/project.py (57 changed lines)
@@ -10,7 +10,7 @@ import dbt.flags as flags
 from dbt.config.runtime import RuntimeConfig
 from dbt.adapters.factory import get_adapter, register_adapter, reset_adapters
 from dbt.events.functions import setup_event_logger
-from dbt.tests.util import write_file, run_sql_with_adapter
+from dbt.tests.util import write_file, run_sql_with_adapter, TestProcessingException
 
 
 # These are the fixtures that are used in dbt core functional tests
@@ -63,6 +63,19 @@ def test_data_dir(request):
     return os.path.join(request.fspath.dirname, "data")
 
 
+@pytest.fixture(scope="class")
+def dbt_profile_target():
+    return {
+        "type": "postgres",
+        "threads": 4,
+        "host": "localhost",
+        "port": int(os.getenv("POSTGRES_TEST_PORT", 5432)),
+        "user": os.getenv("POSTGRES_TEST_USER", "root"),
+        "pass": os.getenv("POSTGRES_TEST_PASS", "password"),
+        "dbname": os.getenv("POSTGRES_TEST_DATABASE", "dbt"),
+    }
+
+
 # This fixture can be overridden in a project
 @pytest.fixture(scope="class")
 def profiles_config_update():
@@ -71,35 +84,19 @@ def profiles_config_update():
 
 # The profile dictionary, used to write out profiles.yml
 @pytest.fixture(scope="class")
-def dbt_profile_data(unique_schema, profiles_config_update):
+def dbt_profile_data(unique_schema, dbt_profile_target, profiles_config_update):
     profile = {
         "config": {"send_anonymous_usage_stats": False},
         "test": {
             "outputs": {
-                "default": {
-                    "type": "postgres",
-                    "threads": 4,
-                    "host": "localhost",
-                    "port": int(os.getenv("POSTGRES_TEST_PORT", 5432)),
-                    "user": os.getenv("POSTGRES_TEST_USER", "root"),
-                    "pass": os.getenv("POSTGRES_TEST_PASS", "password"),
-                    "dbname": os.getenv("POSTGRES_TEST_DATABASE", "dbt"),
-                    "schema": unique_schema,
-                },
-                "other_schema": {
-                    "type": "postgres",
-                    "threads": 4,
-                    "host": "localhost",
-                    "port": int(os.getenv("POSTGRES_TEST_PORT", 5432)),
-                    "user": "noaccess",
-                    "pass": "password",
-                    "dbname": os.getenv("POSTGRES_TEST_DATABASE", "dbt"),
-                    "schema": unique_schema + "_alt",  # Should this be the same unique_schema?
-                },
+                "default": {},
             },
             "target": "default",
         },
     }
+    target = dbt_profile_target
+    target["schema"] = unique_schema
+    profile["test"]["outputs"]["default"] = target
 
     if profiles_config_update:
         profile.update(profiles_config_update)
@@ -199,6 +196,8 @@ def write_project_files(project_root, dir_name, file_dict):
 
 # Write files out from file_dict. Can be nested directories...
 def write_project_files_recursively(path, file_dict):
+    if type(file_dict) is not dict:
+        raise TestProcessingException(f"Error creating {path}. Did you forget the file extension?")
     for name, value in file_dict.items():
         if name.endswith(".sql") or name.endswith(".csv") or name.endswith(".md"):
             write_file(value, path, name)
@@ -276,6 +275,7 @@ class TestProjInfo:
         test_data_dir,
         test_schema,
         database,
+        test_config,
     ):
         self.project_root = project_root
         self.profiles_dir = profiles_dir
@@ -285,6 +285,7 @@ class TestProjInfo:
         self.test_data_dir = test_data_dir
         self.test_schema = test_schema
         self.database = database
+        self.test_config = test_config
 
     # Run sql from a path
    def run_sql_file(self, sql_path, fetch=None):
@@ -313,6 +314,13 @@ class TestProjInfo:
         return {model_name: materialization for (model_name, materialization) in result}
 
 
+# This fixture is for customizing tests that need overrides in adapter
+# repos. Example in dbt.tests.adapter.basic.test_base.
+@pytest.fixture(scope="class")
+def test_config():
+    return {}
+
+
 @pytest.fixture(scope="class")
 def project(
     project_root,
@@ -328,6 +336,7 @@ def project(
     shared_data_dir,
     test_data_dir,
     logs_dir,
+    test_config,
 ):
     # Logbook warnings are ignored so we don't have to fork logbook to support python 3.10.
     # This _only_ works for tests in `tests/` that use the project fixture.
@@ -345,8 +354,8 @@ def project(
         shared_data_dir=shared_data_dir,
         test_data_dir=test_data_dir,
         test_schema=unique_schema,
-        # the following feels kind of fragile. TODO: better way of getting database
-        database=profiles_yml["test"]["outputs"]["default"]["dbname"],
+        database=adapter.config.credentials.database,
+        test_config=test_config,
     )
     project.run_sql("drop schema if exists {schema} cascade")
     project.run_sql("create schema {schema}")
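Since dbt_profile_target and test_config are ordinary class-scoped pytest fixtures, an adapter repo points the whole fixture chain at its own warehouse by redefining them, typically in a conftest.py; dbt_profile_data then injects the unique test schema into whatever target dict comes back. A sketch of such an override, assuming a hypothetical adapter named myadapter (the type value and environment-variable names are illustrative, not from this diff):

    # conftest.py in a hypothetical adapter plugin's test suite.
    import os

    import pytest


    @pytest.fixture(scope="class")
    def dbt_profile_target():
        # Replaces the postgres default above; "schema" is filled in later
        # by the dbt_profile_data fixture.
        return {
            "type": "myadapter",
            "threads": 4,
            "host": os.getenv("MYADAPTER_TEST_HOST", "localhost"),
            "user": os.getenv("MYADAPTER_TEST_USER", "root"),
            "pass": os.getenv("MYADAPTER_TEST_PASS", "password"),
            "dbname": os.getenv("MYADAPTER_TEST_DATABASE", "dbt"),
        }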
@@ -46,7 +46,7 @@ def run_dbt_and_capture(args: List[str] = None, expect_pass=True):
 
 # Used in test cases to get the manifest from the partial parsing file
 def get_manifest(project_root):
-    path = project_root.join("target", "partial_parse.msgpack")
+    path = os.path.join(project_root, "target", "partial_parse.msgpack")
     if os.path.exists(path):
         with open(path, "rb") as fp:
             manifest_mp = fp.read()
@@ -122,6 +122,8 @@ def run_sql_with_adapter(adapter, sql, fetch=None):
     }
     sql = sql.format(**kwargs)
 
+    msg = f'test connection "__test" executing: {sql}'
+    fire_event(IntegrationTestDebug(msg=msg))
     # Since the 'adapter' in dbt.adapters.factory may have been replaced by execution
     # of dbt commands since the test 'adapter' was created, we patch the 'get_adapter' call in
     # dbt.context.providers, so that macros that are called refer to this test adapter.
@@ -130,24 +132,196 @@ def run_sql_with_adapter(adapter, sql, fetch=None):
     with patch.object(providers, "get_adapter", return_value=adapter):
         with adapter.connection_named("__test"):
             conn = adapter.connections.get_thread_connection()
-            msg = f'test connection "{conn.name}" executing: {sql}'
-            fire_event(IntegrationTestDebug(msg=msg))
-            with conn.handle.cursor() as cursor:
-                try:
-                    cursor.execute(sql)
-                    conn.handle.commit()
-                    if fetch == "one":
-                        return cursor.fetchone()
-                    elif fetch == "all":
-                        return cursor.fetchall()
-                    else:
-                        return
-                except BaseException as e:
-                    if conn.handle and not getattr(conn.handle, "closed", True):
-                        conn.handle.rollback()
-                    print(sql)
-                    print(e)
-                    raise
-                finally:
-                    conn.transaction_open = False
+            return adapter.run_sql_for_tests(sql, fetch, conn)
+
+
+class TestProcessingException(Exception):
+    pass
+
+
+def relation_from_name(adapter, name: str):
+    """reverse-engineer a relation from a given name and
+    the adapter. The relation name is split by the '.' character.
+    """
+
+    # Different adapters have different Relation classes
+    cls = adapter.Relation
+    credentials = adapter.config.credentials
+    quote_policy = cls.get_default_quote_policy().to_dict()
+    include_policy = cls.get_default_include_policy().to_dict()
+
+    # Make sure we have database/schema/identifier parts, even if
+    # only identifier was supplied.
+    relation_parts = name.split(".")
+    if len(relation_parts) == 1:
+        relation_parts.insert(0, credentials.schema)
+    if len(relation_parts) == 2:
+        relation_parts.insert(0, credentials.database)
+    kwargs = {
+        "database": relation_parts[0],
+        "schema": relation_parts[1],
+        "identifier": relation_parts[2],
+    }
+
+    relation = cls.create(
+        include_policy=include_policy,
+        quote_policy=quote_policy,
+        **kwargs,
+    )
+    return relation
+
+
+def check_relation_types(adapter, relation_to_type):
+    """
+    Relation name to table/view
+    {
+        "base": "table",
+        "other": "view",
+    }
+    """
+
+    expected_relation_values = {}
+    found_relations = []
+    schemas = set()
+
+    for key, value in relation_to_type.items():
+        relation = relation_from_name(adapter, key)
+        expected_relation_values[relation] = value
+        schemas.add(relation.without_identifier())
+
+    with patch.object(providers, "get_adapter", return_value=adapter):
+        with adapter.connection_named("__test"):
+            for schema in schemas:
+                found_relations.extend(adapter.list_relations_without_caching(schema))
+
+    for key, value in relation_to_type.items():
+        for relation in found_relations:
+            # this might be too broad
+            if relation.identifier == key:
+                assert relation.type == value, (
+                    f"Got an unexpected relation type of {relation.type} "
+                    f"for relation {key}, expected {value}"
+                )
+
+
+def check_relations_equal(adapter, relation_names):
+    if len(relation_names) < 2:
+        raise TestProcessingException(
+            "Not enough relations to compare",
+        )
+    relations = [relation_from_name(adapter, name) for name in relation_names]
+
+    with patch.object(providers, "get_adapter", return_value=adapter):
+        with adapter.connection_named("_test"):
+            basis, compares = relations[0], relations[1:]
+            columns = [c.name for c in adapter.get_columns_in_relation(basis)]
+
+            for relation in compares:
+                sql = adapter.get_rows_different_sql(basis, relation, column_names=columns)
+                _, tbl = adapter.execute(sql, fetch=True)
+                num_rows = len(tbl)
+                assert (
+                    num_rows == 1
+                ), f"Invalid sql query from get_rows_different_sql: incorrect number of rows ({num_rows})"
+                num_cols = len(tbl[0])
+                assert (
+                    num_cols == 2
+                ), f"Invalid sql query from get_rows_different_sql: incorrect number of cols ({num_cols})"
+                row_count_difference = tbl[0][0]
+                assert (
+                    row_count_difference == 0
+                ), f"Got {row_count_difference} difference in row count between {basis} and {relation}"
+                rows_mismatched = tbl[0][1]
+                assert (
+                    rows_mismatched == 0
+                ), f"Got {rows_mismatched} different rows between {basis} and {relation}"
+
+
+def get_unique_ids_in_results(results):
+    unique_ids = []
+    for result in results:
+        unique_ids.append(result.node.unique_id)
+    return unique_ids
+
+
+def check_result_nodes_by_name(results, names):
+    result_names = []
+    for result in results:
+        result_names.append(result.node.name)
+    assert set(names) == set(result_names)
+
+
+def check_result_nodes_by_unique_id(results, unique_ids):
+    result_unique_ids = []
+    for result in results:
+        result_unique_ids.append(result.node.unique_id)
+    assert set(unique_ids) == set(result_unique_ids)
+
+
+def update_rows(adapter, update_rows_config):
+    """
+    {
+        "name": "base",
+        "dst_col": "some_date",
+        "clause": {
+            "type": "add_timestamp",
+            "src_col": "some_date",
+        },
+        "where": "id > 10",
+    }
+    """
+    for key in ["name", "dst_col", "clause"]:
+        if key not in update_rows_config:
+            raise TestProcessingException(f"Invalid update_rows: no {key}")
+
+    clause = update_rows_config["clause"]
+    clause = generate_update_clause(adapter, clause)
+
+    where = None
+    if "where" in update_rows_config:
+        where = update_rows_config["where"]
+
+    name = update_rows_config["name"]
+    dst_col = update_rows_config["dst_col"]
+    relation = relation_from_name(adapter, name)
+
+    with patch.object(providers, "get_adapter", return_value=adapter):
+        with adapter.connection_named("_test"):
+            sql = adapter.update_column_sql(
+                dst_name=str(relation),
+                dst_column=dst_col,
+                clause=clause,
+                where_clause=where,
+            )
+            adapter.execute(sql, auto_begin=True)
+            adapter.commit_if_has_connection()
+
+
+def generate_update_clause(adapter, clause) -> str:
+    """
+    Called by the update_rows function. Expects the "clause" dictionary
+    documented in 'update_rows'.
+    """
+
+    if "type" not in clause or clause["type"] not in ["add_timestamp", "add_string"]:
+        raise TestProcessingException("invalid update_rows clause: type missing or incorrect")
+    clause_type = clause["type"]
+
+    if clause_type == "add_timestamp":
+        if "src_col" not in clause:
+            raise TestProcessingException("Invalid update_rows clause: no src_col")
+        add_to = clause["src_col"]
+        kwargs = {k: v for k, v in clause.items() if k in ("interval", "number")}
+        with patch.object(providers, "get_adapter", return_value=adapter):
+            with adapter.connection_named("_test"):
+                return adapter.timestamp_add_sql(add_to=add_to, **kwargs)
+    elif clause_type == "add_string":
+        for key in ["src_col", "value"]:
+            if key not in clause:
+                raise TestProcessingException(f"Invalid update_rows clause: no {key}")
+        src_col = clause["src_col"]
+        value = clause["value"]
+        location = clause.get("location", "append")
+        with patch.object(providers, "get_adapter", return_value=adapter):
+            with adapter.connection_named("_test"):
+                return adapter.string_add_sql(src_col, value, location)
+    return ""
@@ -1,2 +1,3 @@
 -e ./core
 -e ./plugins/postgres
+-e ./tests/adapter
tests/adapter/README.md (new file, 22 lines)

@@ -0,0 +1,22 @@
# dbt-tests-adapter

This is where we store the adapter tests that will be used by
plugin repos. It should be included in the dbt-core and plugin
repos by installing it with pip.

Tests in this repo will be packaged as classes, with a base class
that can be imported by adapter test repositories. Tests that might
need to be overridden by a plugin will be separated into separate
test methods or test classes as they are discovered.

This plugin is installed in the dbt-core repo by pip install -e tests/adapter,
which is included in the editable-requirements.txt file.

The dbt.tests.adapter.basic tests originally came from the earlier
dbt-adapter-tests repository. Additional test directories will be
added as they are converted from dbt-core integration tests, so that
they can be used in adapter test suites without copying and pasting.

This is packaged as a plugin using a python namespace package, so it
cannot have an __init__.py file in the part of the hierarchy to which it
needs to be attached.
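The reuse pattern the README describes is visible in the test files below: each module defines a Base* class holding the fixtures and one test method, plus a trivial Test* subclass that runs it against postgres in this repo. A plugin repo would do the same subclassing on its side; a sketch, assuming a hypothetical adapter test module (the subclass name and fixture override are illustrative; BaseSimpleMaterializations and require_full_refresh come from test_base.py later in this diff):

    # tests/functional/adapter/test_basic.py in a hypothetical plugin repo.
    import pytest

    from dbt.tests.adapter.basic.test_base import BaseSimpleMaterializations


    class TestSimpleMaterializationsMyAdapter(BaseSimpleMaterializations):
        # Fixtures and test_base are inherited as-is; override test_config
        # only if the adapter needs special handling (BigQuery-style full
        # refresh is the case test_base.py checks for).
        @pytest.fixture(scope="class")
        def test_config(self):
            return {"require_full_refresh": True}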
@@ -53,7 +53,7 @@ sources:
         identifier: "{{ var('seed_name', 'base') }}"
 """
 
-schema_test_seed_yml = """
+generic_test_seed_yml = """
 version: 2
 models:
   - name: base
@@ -63,7 +63,7 @@ models:
           - not_null
 """
 
-schema_test_view_yml = """
+generic_test_view_yml = """
 version: 2
 models:
   - name: view_model
@@ -73,7 +73,7 @@ models:
           - not_null
 """
 
-schema_test_table_yml = """
+generic_test_table_yml = """
 version: 2
 models:
   - name: table_model
tests/adapter/dbt/tests/adapter/basic/test_base.py (new file, 110 lines)

@@ -0,0 +1,110 @@
import pytest
from dbt.tests.util import (
    run_dbt,
    check_result_nodes_by_name,
    relation_from_name,
    check_relation_types,
    check_relations_equal,
)
from dbt.tests.adapter.basic.files import (
    seeds_base_csv,
    base_view_sql,
    base_table_sql,
    base_materialized_var_sql,
    schema_base_yml,
)


class BaseSimpleMaterializations:
    @pytest.fixture(scope="class")
    def models(self):
        return {
            "view_model.sql": base_view_sql,
            "table_model.sql": base_table_sql,
            "swappable.sql": base_materialized_var_sql,
            "schema.yml": schema_base_yml,
        }

    @pytest.fixture(scope="class")
    def seeds(self):
        return {
            "base.csv": seeds_base_csv,
        }

    @pytest.fixture(scope="class")
    def project_config_update(self):
        return {
            "name": "base",
        }

    def test_base(self, project):

        # seed command
        results = run_dbt(["seed"])
        # seed result length
        assert len(results) == 1

        # run command
        results = run_dbt()
        # run result length
        assert len(results) == 3

        # names exist in result nodes
        check_result_nodes_by_name(results, ["view_model", "table_model", "swappable"])

        # check relation types
        expected = {
            "base": "table",
            "view_model": "view",
            "table_model": "table",
            "swappable": "table",
        }
        check_relation_types(project.adapter, expected)

        # base table rowcount
        relation = relation_from_name(project.adapter, "base")
        result = project.run_sql(f"select count(*) as num_rows from {relation}", fetch="one")
        assert result[0] == 10

        # relations_equal
        check_relations_equal(project.adapter, ["base", "view_model", "table_model", "swappable"])

        # check relations in catalog
        catalog = run_dbt(["docs", "generate"])
        assert len(catalog.nodes) == 4
        assert len(catalog.sources) == 1

        # run_dbt changing materialized_var to view
        if project.test_config.get("require_full_refresh", False):  # required for BigQuery
            results = run_dbt(
                ["run", "--full-refresh", "-m", "swappable", "--vars", "materialized_var: view"]
            )
        else:
            results = run_dbt(["run", "-m", "swappable", "--vars", "materialized_var: view"])
        assert len(results) == 1

        # check relation types, swappable is view
        expected = {
            "base": "table",
            "view_model": "view",
            "table_model": "table",
            "swappable": "view",
        }
        check_relation_types(project.adapter, expected)

        # run_dbt changing materialized_var to incremental
        results = run_dbt(["run", "-m", "swappable", "--vars", "materialized_var: incremental"])
        assert len(results) == 1

        # check relation types, swappable is table
        expected = {
            "base": "table",
            "view_model": "view",
            "table_model": "table",
            "swappable": "table",
        }
        check_relation_types(project.adapter, expected)


class TestSimpleMaterializations(BaseSimpleMaterializations):
    pass
tests/adapter/dbt/tests/adapter/basic/test_empty.py (new file, 29 lines)

@@ -0,0 +1,29 @@
from dbt.tests.util import run_dbt
import os


class BaseEmpty:
    def test_empty(self, project):
        # check seed
        results = run_dbt(["seed"])
        assert len(results) == 0
        run_results_path = os.path.join(project.project_root, "target", "run_results.json")
        assert os.path.exists(run_results_path)

        # check run
        results = run_dbt(["run"])
        assert len(results) == 0

        catalog_path = os.path.join(project.project_root, "target", "catalog.json")
        assert not os.path.exists(catalog_path)

        # check catalog
        catalog = run_dbt(["docs", "generate"])
        assert os.path.exists(run_results_path)
        assert os.path.exists(catalog_path)
        assert len(catalog.nodes) == 0
        assert len(catalog.sources) == 0


class TestEmpty(BaseEmpty):
    pass
tests/adapter/dbt/tests/adapter/basic/test_ephemeral.py (new file, 70 lines)

@@ -0,0 +1,70 @@
import pytest
import os
from dbt.tests.util import (
    run_dbt,
    get_manifest,
    check_relations_equal,
    check_result_nodes_by_name,
    relation_from_name,
)
from dbt.tests.adapter.basic.files import (
    seeds_base_csv,
    base_ephemeral_sql,
    ephemeral_view_sql,
    ephemeral_table_sql,
    schema_base_yml,
)


class BaseEphemeral:
    @pytest.fixture(scope="class")
    def project_config_update(self):
        return {"name": "ephemeral"}

    @pytest.fixture(scope="class")
    def seeds(self):
        return {"base.csv": seeds_base_csv}

    @pytest.fixture(scope="class")
    def models(self):
        return {
            "ephemeral.sql": base_ephemeral_sql,
            "view_model.sql": ephemeral_view_sql,
            "table_model.sql": ephemeral_table_sql,
            "schema.yml": schema_base_yml,
        }

    def test_ephemeral(self, project):
        # seed command
        results = run_dbt(["seed"])
        assert len(results) == 1
        check_result_nodes_by_name(results, ["base"])

        # run command
        results = run_dbt(["run"])
        assert len(results) == 2
        check_result_nodes_by_name(results, ["view_model", "table_model"])

        # base table rowcount
        relation = relation_from_name(project.adapter, "base")
        result = project.run_sql(f"select count(*) as num_rows from {relation}", fetch="one")
        assert result[0] == 10

        # relations equal
        check_relations_equal(project.adapter, ["base", "view_model", "table_model"])

        # catalog node count
        catalog = run_dbt(["docs", "generate"])
        catalog_path = os.path.join(project.project_root, "target", "catalog.json")
        assert os.path.exists(catalog_path)
        assert len(catalog.nodes) == 3
        assert len(catalog.sources) == 1

        # manifest (not in original)
        manifest = get_manifest(project.project_root)
        assert len(manifest.nodes) == 4
        assert len(manifest.sources) == 1


class TestEphemeral(BaseEphemeral):
    pass
tests/adapter/dbt/tests/adapter/basic/test_generic_tests.py (new file, 54 lines)

@@ -0,0 +1,54 @@
import pytest
from dbt.tests.util import run_dbt
from dbt.tests.adapter.basic.files import (
    seeds_base_csv,
    generic_test_seed_yml,
    base_view_sql,
    base_table_sql,
    schema_base_yml,
    generic_test_view_yml,
    generic_test_table_yml,
)


class BaseGenericTests:
    @pytest.fixture(scope="class")
    def project_config_update(self):
        return {"name": "generic_tests"}

    @pytest.fixture(scope="class")
    def seeds(self):
        return {
            "base.csv": seeds_base_csv,
            "schema.yml": generic_test_seed_yml,
        }

    @pytest.fixture(scope="class")
    def models(self):
        return {
            "view_model.sql": base_view_sql,
            "table_model.sql": base_table_sql,
            "schema.yml": schema_base_yml,
            "schema_view.yml": generic_test_view_yml,
            "schema_table.yml": generic_test_table_yml,
        }

    def test_generic_tests(self, project):
        # seed command
        results = run_dbt(["seed"])

        # test command selecting base model
        results = run_dbt(["test", "-m", "base"])
        assert len(results) == 1

        # run command
        results = run_dbt(["run"])
        assert len(results) == 2

        # test command, all tests
        results = run_dbt(["test"])
        assert len(results) == 3


class TestGenericTests(BaseGenericTests):
    pass
tests/adapter/dbt/tests/adapter/basic/test_incremental.py (new file, 62 lines)

@@ -0,0 +1,62 @@
import pytest
from dbt.tests.util import run_dbt, check_relations_equal, relation_from_name
from dbt.tests.adapter.basic.files import (
    seeds_base_csv,
    seeds_added_csv,
    schema_base_yml,
    incremental_sql,
)


class BaseIncremental:
    @pytest.fixture(scope="class")
    def project_config_update(self):
        return {"name": "incremental"}

    @pytest.fixture(scope="class")
    def models(self):
        return {"incremental.sql": incremental_sql, "schema.yml": schema_base_yml}

    @pytest.fixture(scope="class")
    def seeds(self):
        return {"base.csv": seeds_base_csv, "added.csv": seeds_added_csv}

    def test_incremental(self, project):
        # seed command
        results = run_dbt(["seed"])
        assert len(results) == 2

        # base table rowcount
        relation = relation_from_name(project.adapter, "base")
        result = project.run_sql(f"select count(*) as num_rows from {relation}", fetch="one")
        assert result[0] == 10

        # added table rowcount
        relation = relation_from_name(project.adapter, "added")
        result = project.run_sql(f"select count(*) as num_rows from {relation}", fetch="one")
        assert result[0] == 20

        # run command
        # the "seed_name" var changes the seed identifier in the schema file
        results = run_dbt(["run", "--vars", "seed_name: base"])
        assert len(results) == 1

        # check relations equal
        check_relations_equal(project.adapter, ["base", "incremental"])

        # change seed_name var
        # the "seed_name" var changes the seed identifier in the schema file
        results = run_dbt(["run", "--vars", "seed_name: added"])
        assert len(results) == 1

        # check relations equal
        check_relations_equal(project.adapter, ["added", "incremental"])

        # get catalog from docs generate
        catalog = run_dbt(["docs", "generate"])
        assert len(catalog.nodes) == 3
        assert len(catalog.sources) == 1


class Testincremental(BaseIncremental):
    pass
tests/adapter/dbt/tests/adapter/basic/test_singular_tests.py (new file, 38 lines)

@@ -0,0 +1,38 @@
import pytest
from dbt.tests.adapter.basic.files import (
    test_passing_sql,
    test_failing_sql,
)
from dbt.tests.util import check_result_nodes_by_name, run_dbt


class BaseSingularTests:
    @pytest.fixture(scope="class")
    def tests(self):
        return {
            "passing.sql": test_passing_sql,
            "failing.sql": test_failing_sql,
        }

    @pytest.fixture(scope="class")
    def project_config_update(self):
        return {"name": "singular_tests"}

    def test_singular_tests(self, project):
        # test command
        results = run_dbt(["test"])
        assert len(results) == 2

        # We have the right result nodes
        check_result_nodes_by_name(results, ["passing", "failing"])

        # Check result status
        for result in results:
            if result.node.name == "passing":
                assert result.status == "pass"
            elif result.node.name == "failing":
                assert result.status == "fail"


class TestSingularTests(BaseSingularTests):
    pass
@@ -0,0 +1,67 @@
import pytest

from dbt.tests.util import run_dbt, check_result_nodes_by_name
from dbt.tests.adapter.basic.files import (
    seeds_base_csv,
    ephemeral_with_cte_sql,
    test_ephemeral_passing_sql,
    test_ephemeral_failing_sql,
    schema_base_yml,
)


class BaseSingularTestsEphemeral:
    @pytest.fixture(scope="class")
    def seeds(self):
        return {
            "base.csv": seeds_base_csv,
        }

    @pytest.fixture(scope="class")
    def models(self):
        return {
            "ephemeral.sql": ephemeral_with_cte_sql,
            "passing_model.sql": test_ephemeral_passing_sql,
            "failing_model.sql": test_ephemeral_failing_sql,
            "schema.yml": schema_base_yml,
        }

    @pytest.fixture(scope="class")
    def tests(self):
        return {
            "passing.sql": test_ephemeral_passing_sql,
            "failing.sql": test_ephemeral_failing_sql,
        }

    @pytest.fixture(scope="class")
    def project_config_update(self):
        return {
            "name": "singular_tests_ephemeral",
        }

    def test_singular_tests_ephemeral(self, project):
        # check results from seed command
        results = run_dbt(["seed"])
        assert len(results) == 1
        check_result_nodes_by_name(results, ["base"])

        # Check results from test command
        results = run_dbt(["test"])
        assert len(results) == 2
        check_result_nodes_by_name(results, ["passing", "failing"])

        # Check result status
        for result in results:
            if result.node.name == "passing":
                assert result.status == "pass"
            elif result.node.name == "failing":
                assert result.status == "fail"

        # check results from run command
        results = run_dbt()
        assert len(results) == 2
        check_result_nodes_by_name(results, ["failing_model", "passing_model"])


class TestSingularTestsEphemeral(BaseSingularTestsEphemeral):
    pass
@@ -0,0 +1,112 @@
import pytest
from dbt.tests.util import run_dbt, update_rows, relation_from_name
from dbt.tests.adapter.basic.files import (
    seeds_base_csv,
    seeds_added_csv,
    cc_all_snapshot_sql,
    cc_date_snapshot_sql,
    cc_name_snapshot_sql,
)


def check_relation_rows(project, snapshot_name, count):
    relation = relation_from_name(project.adapter, snapshot_name)
    result = project.run_sql(f"select count(*) as num_rows from {relation}", fetch="one")
    assert result[0] == count


class BaseSnapshotCheckCols:
    @pytest.fixture(scope="class")
    def project_config_update(self):
        return {"name": "snapshot_strategy_check_cols"}

    @pytest.fixture(scope="class")
    def seeds(self):
        return {
            "base.csv": seeds_base_csv,
            "added.csv": seeds_added_csv,
        }

    @pytest.fixture(scope="class")
    def snapshots(self):
        return {
            "cc_all_snapshot.sql": cc_all_snapshot_sql,
            "cc_date_snapshot.sql": cc_date_snapshot_sql,
            "cc_name_snapshot.sql": cc_name_snapshot_sql,
        }

    def test_snapshot_check_cols(self, project):
        # seed command
        results = run_dbt(["seed"])
        assert len(results) == 2

        # snapshot command
        results = run_dbt(["snapshot"])
        for result in results:
            assert result.status == "success"

        # check rowcounts for all snapshots
        check_relation_rows(project, "cc_all_snapshot", 10)
        check_relation_rows(project, "cc_name_snapshot", 10)
        check_relation_rows(project, "cc_date_snapshot", 10)

        relation = relation_from_name(project.adapter, "cc_all_snapshot")
        result = project.run_sql(f"select * from {relation}", fetch="all")

        # point at the "added" seed so the snapshot sees 10 new rows
        results = run_dbt(["--no-partial-parse", "snapshot", "--vars", "seed_name: added"])
        for result in results:
            assert result.status == "success"

        # check rowcounts for all snapshots
        check_relation_rows(project, "cc_all_snapshot", 20)
        check_relation_rows(project, "cc_name_snapshot", 20)
        check_relation_rows(project, "cc_date_snapshot", 20)

        # update some timestamps in the "added" seed so the snapshot sees 10 more new rows
        update_rows_config = {
            "name": "added",
            "dst_col": "some_date",
            "clause": {"src_col": "some_date", "type": "add_timestamp"},
            "where": "id > 10 and id < 21",
        }
        update_rows(project.adapter, update_rows_config)

        # re-run snapshots, using "added"
        results = run_dbt(["snapshot", "--vars", "seed_name: added"])
        for result in results:
            assert result.status == "success"

        # check rowcounts for all snapshots
        check_relation_rows(project, "cc_all_snapshot", 30)
        check_relation_rows(project, "cc_date_snapshot", 30)
        # unchanged: only the timestamp changed
        check_relation_rows(project, "cc_name_snapshot", 20)

        # Update the name column
        update_rows_config = {
            "name": "added",
            "dst_col": "name",
            "clause": {
                "src_col": "name",
                "type": "add_string",
                "value": "_updated",
            },
            "where": "id < 11",
        }
        update_rows(project.adapter, update_rows_config)

        # re-run snapshots, using "added"
        results = run_dbt(["snapshot", "--vars", "seed_name: added"])
        for result in results:
            assert result.status == "success"

        # check rowcounts for all snapshots
        check_relation_rows(project, "cc_all_snapshot", 40)
        check_relation_rows(project, "cc_name_snapshot", 30)
        # does not see name updates
        check_relation_rows(project, "cc_date_snapshot", 30)


class TestSnapshotCheckCols(BaseSnapshotCheckCols):
    pass
@@ -0,0 +1,90 @@
import pytest
from dbt.tests.util import run_dbt, relation_from_name, update_rows
from dbt.tests.adapter.basic.files import (
    seeds_base_csv,
    seeds_newcolumns_csv,
    seeds_added_csv,
    ts_snapshot_sql,
)


def check_relation_rows(project, snapshot_name, count):
    relation = relation_from_name(project.adapter, snapshot_name)
    result = project.run_sql(f"select count(*) as num_rows from {relation}", fetch="one")
    assert result[0] == count


class BaseSnapshotTimestamp:
    @pytest.fixture(scope="class")
    def seeds(self):
        return {
            "base.csv": seeds_base_csv,
            "newcolumns.csv": seeds_newcolumns_csv,
            "added.csv": seeds_added_csv,
        }

    @pytest.fixture(scope="class")
    def snapshots(self):
        return {
            "ts_snapshot.sql": ts_snapshot_sql,
        }

    @pytest.fixture(scope="class")
    def project_config_update(self):
        return {"name": "snapshot_strategy_timestamp"}

    def test_snapshot_timestamp(self, project):
        # seed command
        results = run_dbt(["seed"])
        assert len(results) == 3

        # snapshot command
        results = run_dbt(["snapshot"])
        assert len(results) == 1

        # snapshot has 10 rows
        check_relation_rows(project, "ts_snapshot", 10)

        # point at the "added" seed so the snapshot sees 10 new rows
        results = run_dbt(["snapshot", "--vars", "seed_name: added"])

        # snapshot now has 20 rows
        check_relation_rows(project, "ts_snapshot", 20)

        # update some timestamps in the "added" seed so the snapshot sees 10 more new rows
        update_rows_config = {
            "name": "added",
            "dst_col": "some_date",
            "clause": {
                "src_col": "some_date",
                "type": "add_timestamp",
            },
            "where": "id > 10 and id < 21",
        }
        update_rows(project.adapter, update_rows_config)

        results = run_dbt(["snapshot", "--vars", "seed_name: added"])

        # snapshot now has 30 rows
        check_relation_rows(project, "ts_snapshot", 30)

        update_rows_config = {
            "name": "added",
            "dst_col": "name",
            "clause": {
                "src_col": "name",
                "type": "add_string",
                "value": "_updated",
            },
            "where": "id < 11",
        }
        update_rows(project.adapter, update_rows_config)

        results = run_dbt(["snapshot", "--vars", "seed_name: added"])

        # snapshot still has 30 rows because timestamp not updated
        check_relation_rows(project, "ts_snapshot", 30)


class TestSnapshotTimestamp(BaseSnapshotTimestamp):
    pass
tests/adapter/setup.py (new file, 79 lines)

@@ -0,0 +1,79 @@
#!/usr/bin/env python
import os
import sys

if sys.version_info < (3, 7):
    print("Error: dbt does not support this version of Python.")
    print("Please upgrade to Python 3.7 or higher.")
    sys.exit(1)


from setuptools import setup

try:
    from setuptools import find_namespace_packages
except ImportError:
    # the user has a downlevel version of setuptools.
    print("Error: dbt requires setuptools v40.1.0 or higher.")
    print('Please upgrade setuptools with "pip install --upgrade setuptools" ' "and try again")
    sys.exit(1)


PSYCOPG2_MESSAGE = """
No package name override was set.
Using 'psycopg2-binary' package to satisfy 'psycopg2'

If you experience segmentation faults, silent crashes, or installation errors,
consider retrying with the 'DBT_PSYCOPG2_NAME' environment variable set to
'psycopg2'. It may require a compiler toolchain and development libraries!
""".strip()


def _dbt_psycopg2_name():
    # if the user chose something, use that
    package_name = os.getenv("DBT_PSYCOPG2_NAME", "")
    if package_name:
        return package_name

    # default to psycopg2-binary for all OSes/versions
    print(PSYCOPG2_MESSAGE)
    return "psycopg2-binary"


package_name = "dbt-tests-adapter"
package_version = "1.0.1"
description = """The dbt adapter tests for adapter plugins"""

this_directory = os.path.abspath(os.path.dirname(__file__))
with open(os.path.join(this_directory, "README.md")) as f:
    long_description = f.read()

DBT_PSYCOPG2_NAME = _dbt_psycopg2_name()

setup(
    name=package_name,
    version=package_version,
    description=description,
    long_description=long_description,
    long_description_content_type="text/markdown",
    author="dbt Labs",
    author_email="info@dbtlabs.com",
    url="https://github.com/dbt-labs/dbt-tests-adapter",
    packages=find_namespace_packages(include=["dbt", "dbt.*"]),
    install_requires=[
        "dbt-core=={}".format(package_version),
        "{}~=2.8".format(DBT_PSYCOPG2_NAME),
    ],
    zip_safe=False,
    classifiers=[
        "Development Status :: 5 - Production/Stable",
        "License :: OSI Approved :: Apache Software License",
        "Operating System :: Microsoft :: Windows",
        "Operating System :: MacOS :: MacOS X",
        "Operating System :: POSIX :: Linux",
        "Programming Language :: Python :: 3.7",
        "Programming Language :: Python :: 3.8",
        "Programming Language :: Python :: 3.9",
    ],
    python_requires=">=3.7",
)
@@ -1,99 +0,0 @@
import pytest
from dbt.tests.util import run_dbt
from dbt.tests.adapter import check_result_nodes_by_name
from dbt.tests.adapter import relation_from_name, check_relation_types, check_relations_equal
from tests.functional.adapter.files import (
    seeds_base_csv,
    base_view_sql,
    base_table_sql,
    base_materialized_var_sql,
    schema_base_yml,
)


@pytest.fixture(scope="class")
def models():
    return {
        "view_model.sql": base_view_sql,
        "table_model.sql": base_table_sql,
        "swappable.sql": base_materialized_var_sql,
        "schema.yml": schema_base_yml,
    }


@pytest.fixture(scope="class")
def seeds():
    return {
        "base.csv": seeds_base_csv,
    }


@pytest.fixture(scope="class")
def project_config_update():
    return {
        "name": "base",
    }


def test_base(project):

    # seed command
    results = run_dbt(["seed"])
    # seed result length
    assert len(results) == 1

    # run command
    results = run_dbt()
    # run result length
    assert len(results) == 3

    # names exist in result nodes
    check_result_nodes_by_name(results, ["view_model", "table_model", "swappable"])

    # check relation types
    expected = {
        "base": "table",
        "view_model": "view",
        "table_model": "table",
        "swappable": "table",
    }
    check_relation_types(project.adapter, expected)

    # base table rowcount
    relation = relation_from_name(project.adapter, "base")
    result = project.run_sql(f"select count(*) as num_rows from {relation}", fetch="one")
    assert result[0] == 10

    # relations_equal
    check_relations_equal(project.adapter, ["base", "view_model", "table_model", "swappable"])

    # check relations in catalog
    catalog = run_dbt(["docs", "generate"])
    assert len(catalog.nodes) == 4
    assert len(catalog.sources) == 1

    # run_dbt changing materialized_var to view
    results = run_dbt(["run", "-m", "swappable", "--vars", "materialized_var: view"])
    assert len(results) == 1

    # check relation types, swappable is view
    expected = {
        "base": "table",
        "view_model": "view",
        "table_model": "table",
        "swappable": "view",
    }
    check_relation_types(project.adapter, expected)

    # run_dbt changing materialized_var to incremental
    results = run_dbt(["run", "-m", "swappable", "--vars", "materialized_var: incremental"])
    assert len(results) == 1

    # check relation types, swappable is table
    expected = {
        "base": "table",
        "view_model": "view",
        "table_model": "table",
        "swappable": "table",
    }
    check_relation_types(project.adapter, expected)
@@ -1,67 +0,0 @@
import pytest

from dbt.tests.util import run_dbt
from dbt.tests.adapter import check_result_nodes_by_name
from tests.functional.adapter.files import (
    seeds_base_csv,
    ephemeral_with_cte_sql,
    test_ephemeral_passing_sql,
    test_ephemeral_failing_sql,
    schema_base_yml,
)


@pytest.fixture(scope="class")
def seeds():
    return {
        "base.csv": seeds_base_csv,
    }


@pytest.fixture(scope="class")
def models():
    return {
        "ephemeral.sql": ephemeral_with_cte_sql,
        "passing_model.sql": test_ephemeral_passing_sql,
        "failing_model.sql": test_ephemeral_failing_sql,
        "schema.yml": schema_base_yml,
    }


@pytest.fixture(scope="class")
def tests():
    return {
        "passing.sql": test_ephemeral_passing_sql,
        "failing.sql": test_ephemeral_failing_sql,
    }


@pytest.fixture(scope="class")
def project_config_update():
    return {
        "name": "data_test_ephemeral_models",
    }


def test_data_test_ephemerals(project):
    # check results from seed command
    results = run_dbt(["seed"])
    assert len(results) == 1
    check_result_nodes_by_name(results, ["base"])

    # Check results from test command
    results = run_dbt(["test"])
    assert len(results) == 2
    check_result_nodes_by_name(results, ["passing", "failing"])

    # Check result status
    for result in results:
        if result.node.name == "passing":
            assert result.status == "pass"
        elif result.node.name == "failing":
            assert result.status == "fail"

    # check results from run command
    results = run_dbt()
    assert len(results) == 2
    check_result_nodes_by_name(results, ["failing_model", "passing_model"])
@@ -1,36 +0,0 @@
import pytest
from tests.functional.adapter.files import (
    test_passing_sql,
    test_failing_sql,
)
from dbt.tests.adapter import check_result_nodes_by_name
from dbt.tests.util import run_dbt


@pytest.fixture(scope="class")
def tests():
    return {
        "passing.sql": test_passing_sql,
        "failing.sql": test_failing_sql,
    }


@pytest.fixture(scope="class")
def project_config_update():
    return {"name": "data_tests"}


def test_data_tests(project):
    # test command
    results = run_dbt(["test"])
    assert len(results) == 2

    # We have the right result nodes
    check_result_nodes_by_name(results, ["passing", "failing"])

    # Check result status
    for result in results:
        if result.node.name == "passing":
            assert result.status == "pass"
        elif result.node.name == "failing":
            assert result.status == "fail"
@@ -1,24 +0,0 @@
from dbt.tests.util import run_dbt
import os


def test_empty(project):
    # check seed
    results = run_dbt(["seed"])
    assert len(results) == 0
    run_results_path = os.path.join(project.project_root, "target", "run_results.json")
    assert os.path.exists(run_results_path)

    # check run
    results = run_dbt(["run"])
    assert len(results) == 0

    catalog_path = os.path.join(project.project_root, "target", "catalog.json")
    assert not os.path.exists(catalog_path)

    # check catalog
    catalog = run_dbt(["docs", "generate"])
    assert os.path.exists(run_results_path)
    assert os.path.exists(catalog_path)
    assert len(catalog.nodes) == 0
    assert len(catalog.sources) == 0
@@ -1,63 +0,0 @@
import pytest
import os
from dbt.tests.util import run_dbt, get_manifest
from dbt.tests.adapter import check_relations_equal, check_result_nodes_by_name, relation_from_name
from tests.functional.adapter.files import (
    seeds_base_csv,
    base_ephemeral_sql,
    ephemeral_view_sql,
    ephemeral_table_sql,
    schema_base_yml,
)


@pytest.fixture(scope="class")
def project_config_update():
    return {"name": "ephemeral"}


@pytest.fixture(scope="class")
def seeds():
    return {"base.csv": seeds_base_csv}


@pytest.fixture(scope="class")
def models():
    return {
        "ephemeral.sql": base_ephemeral_sql,
        "view_model.sql": ephemeral_view_sql,
        "table_model.sql": ephemeral_table_sql,
        "schema.yml": schema_base_yml,
    }


def test_ephemeral(project):
    # seed command
    results = run_dbt(["seed"])
    assert len(results) == 1
    check_result_nodes_by_name(results, ["base"])

    # run command
    results = run_dbt(["run"])
    assert len(results) == 2
    check_result_nodes_by_name(results, ["view_model", "table_model"])

    # base table rowcount
    relation = relation_from_name(project.adapter, "base")
    result = project.run_sql(f"select count(*) as num_rows from {relation}", fetch="one")
    assert result[0] == 10

    # relations equal
    check_relations_equal(project.adapter, ["base", "view_model", "table_model"])

    # catalog node count
    catalog = run_dbt(["docs", "generate"])
    catalog_path = os.path.join(project.project_root, "target", "catalog.json")
    assert os.path.exists(catalog_path)
    assert len(catalog.nodes) == 3
    assert len(catalog.sources) == 1

    # manifest (not in original)
    manifest = get_manifest(project.project_root)
    assert len(manifest.nodes) == 4
    assert len(manifest.sources) == 1
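The ephemeral fixtures imported above are not shown in this diff. A hedged sketch of what they might contain, assuming the ephemeral model reads from the seed-backed source and the downstream models ref it (the source and ref names are guesses):

# Hypothetical fixture contents. Ephemeral models are never built in the
# warehouse; dbt inlines them as CTEs into the models that ref them, which is
# why "run" above builds only view_model and table_model.
base_ephemeral_sql = """
{{ config(materialized="ephemeral") }}
select * from {{ source("raw", "seed") }}
"""

ephemeral_view_sql = """
{{ config(materialized="view") }}
select * from {{ ref("ephemeral") }}
"""

ephemeral_table_sql = """
{{ config(materialized="table") }}
select * from {{ ref("ephemeral") }}
"""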
@@ -1,61 +0,0 @@
import pytest
from dbt.tests.util import run_dbt
from tests.functional.adapter.files import (
    seeds_base_csv,
    seeds_added_csv,
    schema_base_yml,
    incremental_sql,
)
from dbt.tests.adapter import check_relations_equal, relation_from_name


@pytest.fixture(scope="class")
def project_config_update():
    return {"name": "incremental"}


@pytest.fixture(scope="class")
def models():
    return {"incremental.sql": incremental_sql, "schema.yml": schema_base_yml}


@pytest.fixture(scope="class")
def seeds():
    return {"base.csv": seeds_base_csv, "added.csv": seeds_added_csv}


def test_incremental(project):
    # seed command
    results = run_dbt(["seed"])
    assert len(results) == 2

    # base table rowcount
    relation = relation_from_name(project.adapter, "base")
    result = project.run_sql(f"select count(*) as num_rows from {relation}", fetch="one")
    assert result[0] == 10

    # added table rowcount
    relation = relation_from_name(project.adapter, "added")
    result = project.run_sql(f"select count(*) as num_rows from {relation}", fetch="one")
    assert result[0] == 20

    # run command
    # the "seed_name" var changes the seed identifier in the schema file
    results = run_dbt(["run", "--vars", "seed_name: base"])
    assert len(results) == 1

    # check relations equal
    check_relations_equal(project.adapter, ["base", "incremental"])

    # change seed_name var
    # the "seed_name" var changes the seed identifier in the schema file
    results = run_dbt(["run", "--vars", "seed_name: added"])
    assert len(results) == 1

    # check relations equal
    check_relations_equal(project.adapter, ["added", "incremental"])

    # get catalog from docs generate
    catalog = run_dbt(["docs", "generate"])
    assert len(catalog.nodes) == 3
    assert len(catalog.sources) == 1
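The incremental_sql fixture is defined elsewhere (tests/functional/adapter/files.py). A plausible sketch consistent with the assertions above, assuming an id-based high-watermark filter (the source name is a guess):

# Hypothetical fixture contents: full load on the first run, and on later
# runs only rows beyond the current max(id) in the existing table.
incremental_sql = """
{{ config(materialized="incremental") }}
select * from {{ source("raw", "seed") }}
{% if is_incremental() %}
where id > (select max(id) from {{ this }})
{% endif %}
"""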
@@ -1,52 +0,0 @@
import pytest
from dbt.tests.util import run_dbt
from tests.functional.adapter.files import (
    seeds_base_csv,
    schema_test_seed_yml,
    base_view_sql,
    base_table_sql,
    schema_base_yml,
    schema_test_view_yml,
    schema_test_table_yml,
)


@pytest.fixture(scope="class")
def project_config_update():
    return {"name": "schema_test"}


@pytest.fixture(scope="class")
def seeds():
    return {
        "base.csv": seeds_base_csv,
        "schema.yml": schema_test_seed_yml,
    }


@pytest.fixture(scope="class")
def models():
    return {
        "view_model.sql": base_view_sql,
        "table_model.sql": base_table_sql,
        "schema.yml": schema_base_yml,
        "schema_view.yml": schema_test_view_yml,
        "schema_table.yml": schema_test_table_yml,
    }


def test_schema_tests(project):
    # seed command
    results = run_dbt(["seed"])

    # test command selecting base model
    results = run_dbt(["test", "-m", "base"])
    assert len(results) == 1

    # run command
    results = run_dbt(["run"])
    assert len(results) == 2

    # test command, all tests
    results = run_dbt(["test"])
    assert len(results) == 3
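The schema test fixtures are likewise defined in files.py. A hypothetical sketch that would line up with the expected counts above (one test on the seed, one per model, three in total):

# Hypothetical fixture contents; one generic test per node.
schema_test_seed_yml = """
version: 2
seeds:
  - name: base
    columns:
      - name: id
        tests:
          - unique
"""

schema_test_view_yml = """
version: 2
models:
  - name: view_model
    columns:
      - name: id
        tests:
          - unique
"""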
@@ -1,111 +0,0 @@
import pytest
from dbt.tests.util import run_dbt
from tests.functional.adapter.files import (
    seeds_base_csv,
    seeds_added_csv,
    cc_all_snapshot_sql,
    cc_date_snapshot_sql,
    cc_name_snapshot_sql,
)
from dbt.tests.adapter import update_rows, relation_from_name


@pytest.fixture(scope="class")
def project_config_update():
    return {"name": "snapshot_strategy_check_cols"}


@pytest.fixture(scope="class")
def seeds():
    return {
        "base.csv": seeds_base_csv,
        "added.csv": seeds_added_csv,
    }


@pytest.fixture(scope="class")
def snapshots():
    return {
        "cc_all_snapshot.sql": cc_all_snapshot_sql,
        "cc_date_snapshot.sql": cc_date_snapshot_sql,
        "cc_name_snapshot.sql": cc_name_snapshot_sql,
    }


def check_relation_rows(project, snapshot_name, count):
    relation = relation_from_name(project.adapter, snapshot_name)
    result = project.run_sql(f"select count(*) as num_rows from {relation}", fetch="one")
    assert result[0] == count


def test_snapshot_check_cols(project):
    # seed command
    results = run_dbt(["seed"])
    assert len(results) == 2

    # snapshot command
    results = run_dbt(["snapshot"])
    for result in results:
        assert result.status == "success"

    # check rowcounts for all snapshots
    check_relation_rows(project, "cc_all_snapshot", 10)
    check_relation_rows(project, "cc_name_snapshot", 10)
    check_relation_rows(project, "cc_date_snapshot", 10)

    relation = relation_from_name(project.adapter, "cc_all_snapshot")
    result = project.run_sql(f"select * from {relation}", fetch="all")

    # point at the "added" seed so the snapshot sees 10 new rows
    results = run_dbt(["--no-partial-parse", "snapshot", "--vars", "seed_name: added"])
    for result in results:
        assert result.status == "success"

    # check rowcounts for all snapshots
    check_relation_rows(project, "cc_all_snapshot", 20)
    check_relation_rows(project, "cc_name_snapshot", 20)
    check_relation_rows(project, "cc_date_snapshot", 20)

    # update some timestamps in the "added" seed so the snapshot sees 10 more new rows
    update_rows_config = {
        "name": "added",
        "dst_col": "some_date",
        "clause": {"src_col": "some_date", "type": "add_timestamp"},
        "where": "id > 10 and id < 21",
    }
    update_rows(project.adapter, update_rows_config)

    # re-run snapshots, using "added"
    results = run_dbt(["snapshot", "--vars", "seed_name: added"])
    for result in results:
        assert result.status == "success"

    # check rowcounts for all snapshots
    check_relation_rows(project, "cc_all_snapshot", 30)
    check_relation_rows(project, "cc_date_snapshot", 30)
    # unchanged: only the timestamp changed
    check_relation_rows(project, "cc_name_snapshot", 20)

    # update the name column
    update_rows_config = {
        "name": "added",
        "dst_col": "name",
        "clause": {
            "src_col": "name",
            "type": "add_string",
            "value": "_updated",
        },
        "where": "id < 11",
    }
    update_rows(project.adapter, update_rows_config)

    # re-run snapshots, using "added"
    results = run_dbt(["snapshot", "--vars", "seed_name: added"])
    for result in results:
        assert result.status == "success"

    # check rowcounts for all snapshots
    check_relation_rows(project, "cc_all_snapshot", 40)
    check_relation_rows(project, "cc_name_snapshot", 30)
    # does not see name updates
    check_relation_rows(project, "cc_date_snapshot", 30)
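The check_cols snapshot fixtures are not part of this comparison. A hedged sketch of the single-column variant, which would explain why cc_name_snapshot ignores timestamp-only updates while cc_all_snapshot records them (the select and var usage are guesses):

# Hypothetical fixture contents: a "check" strategy snapshot that records a
# new row only when the listed columns change.
cc_name_snapshot_sql = """
{% snapshot cc_name_snapshot %}
    {{ config(
        strategy="check",
        unique_key="id",
        check_cols=["name"],
        target_database=database,
        target_schema=schema,
    ) }}
    select * from {{ ref(var("seed_name", "base")) }}
{% endsnapshot %}
"""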
@@ -1,89 +0,0 @@
import pytest
from dbt.tests.util import run_dbt
from tests.functional.adapter.files import (
    seeds_base_csv,
    seeds_newcolumns_csv,
    seeds_added_csv,
    ts_snapshot_sql,
)
from dbt.tests.adapter import relation_from_name, update_rows


@pytest.fixture(scope="class")
def seeds():
    return {
        "base.csv": seeds_base_csv,
        "newcolumns.csv": seeds_newcolumns_csv,
        "added.csv": seeds_added_csv,
    }


@pytest.fixture(scope="class")
def snapshots():
    return {
        "ts_snapshot.sql": ts_snapshot_sql,
    }


@pytest.fixture(scope="class")
def project_config_update():
    return {"name": "snapshot_strategy_timestamp"}


def check_relation_rows(project, snapshot_name, count):
    relation = relation_from_name(project.adapter, snapshot_name)
    result = project.run_sql(f"select count(*) as num_rows from {relation}", fetch="one")
    assert result[0] == count


def test_snapshot_timestamp(project):
    # seed command
    results = run_dbt(["seed"])
    assert len(results) == 3

    # snapshot command
    results = run_dbt(["snapshot"])
    assert len(results) == 1

    # snapshot has 10 rows
    check_relation_rows(project, "ts_snapshot", 10)

    # point at the "added" seed so the snapshot sees 10 new rows
    results = run_dbt(["snapshot", "--vars", "seed_name: added"])

    # snapshot now has 20 rows
    check_relation_rows(project, "ts_snapshot", 20)

    # update some timestamps in the "added" seed so the snapshot sees 10 more new rows
    update_rows_config = {
        "name": "added",
        "dst_col": "some_date",
        "clause": {
            "src_col": "some_date",
            "type": "add_timestamp",
        },
        "where": "id > 10 and id < 21",
    }
    update_rows(project.adapter, update_rows_config)

    results = run_dbt(["snapshot", "--vars", "seed_name: added"])

    # snapshot now has 30 rows
    check_relation_rows(project, "ts_snapshot", 30)

    # update the name column; the timestamp strategy should not pick this up
    update_rows_config = {
        "name": "added",
        "dst_col": "name",
        "clause": {
            "src_col": "name",
            "type": "add_string",
            "value": "_updated",
        },
        "where": "id < 11",
    }
    update_rows(project.adapter, update_rows_config)

    results = run_dbt(["snapshot", "--vars", "seed_name: added"])

    # snapshot still has 30 rows because the timestamp was not updated
    check_relation_rows(project, "ts_snapshot", 30)
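For completeness, a hedged sketch of what ts_snapshot_sql might look like, consistent with the behavior above (new rows only when some_date moves forward; again, the select is a guess):

# Hypothetical fixture contents: a "timestamp" strategy snapshot keyed on the
# some_date column, so name-only updates do not create new snapshot rows.
ts_snapshot_sql = """
{% snapshot ts_snapshot %}
    {{ config(
        strategy="timestamp",
        unique_key="id",
        updated_at="some_date",
        target_database=database,
        target_schema=schema,
    ) }}
    select * from {{ ref(var("seed_name", "base")) }}
{% endsnapshot %}
"""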
@@ -158,7 +158,7 @@ class TestTagSelection(SelectionFixtures):
            "unique_users_rollup_gender",
        ]

-    def test_select_tag_in_model_with_project_config_parents_children_selectors(project):
+    def test_select_tag_in_model_with_project_config_parents_children_selectors(self, project):
        results = run_dbt(["run", "--selector", "user_tagged_childrens_parents"])
        assert len(results) == 4
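The user_tagged_childrens_parents selector referenced here is defined in the test's fixture data, which this hunk does not show. A hypothetical selectors.yml entry using the childrens_parents (@) operator might look like this (the tag value is a placeholder):

selectors_yml = """
selectors:
  - name: user_tagged_childrens_parents
    definition:
      method: tag
      value: users
      childrens_parents: true
"""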
1
tox.ini
@@ -18,6 +18,7 @@ passenv = DBT_* POSTGRES_TEST_* PYTEST_ADDOPTS
commands =
    {envpython} -m pytest --cov=core -m profile_postgres {posargs} test/integration
    {envpython} -m pytest --cov=core {posargs} tests/functional
    {envpython} -m pytest --cov=core {posargs} tests/adapter

deps =
    -rdev-requirements.txt