Compare commits

...

3 Commits

Author SHA1 Message Date
Gerda Shank
b5c79f9e09 Merge branch 'main' into ct-236-adapter_tests 2022-03-17 17:31:37 -04:00
Gerda Shank
424024ea9e Add 'test_config' to support adapter test customization 2022-03-17 14:35:22 -04:00
Gerda Shank
0c71b44e6b Setup adapter tests for easy integration into adapter repos 2022-03-16 10:43:14 -04:00
31 changed files with 1001 additions and 865 deletions

View File

@@ -0,0 +1,7 @@
kind: Features
body: Testing framework for dbt adapter testing
time: 2022-03-09T14:26:45.828295-05:00
custom:
Author: gshank
Issue: "4730"
PR: "4846"

View File

@@ -218,3 +218,29 @@ class SQLAdapter(BaseAdapter):
kwargs = {"information_schema": information_schema, "schema": schema}
results = self.execute_macro(CHECK_SCHEMA_EXISTS_MACRO_NAME, kwargs=kwargs)
return results[0][0] > 0
# This is for use in the test suite
def run_sql_for_tests(self, sql, fetch, conn):
cursor = conn.handle.cursor()
try:
cursor.execute(sql)
if hasattr(conn.handle, "commit"):
conn.handle.commit()
if fetch == "one":
if hasattr(cursor, "fetchone"): # for spark
return cursor.fetchone()
else:
# for spark
return cursor.fetchall()[0]
elif fetch == "all":
return cursor.fetchall()
else:
return
except BaseException as e:
if conn.handle and not getattr(conn.handle, "closed", True):
conn.handle.rollback()
print(sql)
print(e)
raise
finally:
conn.transaction_open = False

View File

@@ -1 +0,0 @@
# dbt.tests directory

View File

@@ -1,212 +0,0 @@
from itertools import chain, repeat
from dbt.context import providers
from unittest.mock import patch
# These functions were extracted from the dbt-adapter-tests spec_file.py.
# They are used in the 'adapter' tests directory. At some point they
# might be moved to dbts.tests.util if they are of general purpose use,
# but leaving here for now to keep the adapter work more contained.
# We may want to consolidate in the future since some of this is kind
# of duplicative of the functionality in dbt.tests.tables.
class TestProcessingException(Exception):
pass
def relation_from_name(adapter, name: str):
"""reverse-engineer a relation (including quoting) from a given name and
the adapter. Assumes that relations are split by the '.' character.
"""
# Different adapters have different Relation classes
cls = adapter.Relation
credentials = adapter.config.credentials
quote_policy = cls.get_default_quote_policy().to_dict()
include_policy = cls.get_default_include_policy().to_dict()
kwargs = {} # This will contain database, schema, identifier
parts = name.split(".")
names = ["database", "schema", "identifier"]
defaults = [credentials.database, credentials.schema, None]
values = chain(repeat(None, 3 - len(parts)), parts)
for name, value, default in zip(names, values, defaults):
# no quote policy -> use the default
if value is None:
if default is None:
include_policy[name] = False
value = default
else:
include_policy[name] = True
# if we have a value, we can figure out the quote policy.
trimmed = value[1:-1]
if adapter.quote(trimmed) == value:
quote_policy[name] = True
value = trimmed
else:
quote_policy[name] = False
kwargs[name] = value
relation = cls.create(
include_policy=include_policy,
quote_policy=quote_policy,
**kwargs,
)
return relation
def check_relation_types(adapter, relation_to_type):
"""
Relation name to table/view
{
"base": "table",
"other": "view",
}
"""
expected_relation_values = {}
found_relations = []
schemas = set()
for key, value in relation_to_type.items():
relation = relation_from_name(adapter, key)
expected_relation_values[relation] = value
schemas.add(relation.without_identifier())
with patch.object(providers, "get_adapter", return_value=adapter):
with adapter.connection_named("__test"):
for schema in schemas:
found_relations.extend(adapter.list_relations_without_caching(schema))
for key, value in relation_to_type.items():
for relation in found_relations:
# this might be too broad
if relation.identifier == key:
assert relation.type == value, (
f"Got an unexpected relation type of {relation.type} "
f"for relation {key}, expected {value}"
)
def check_relations_equal(adapter, relation_names):
if len(relation_names) < 2:
raise TestProcessingException(
"Not enough relations to compare",
)
relations = [relation_from_name(adapter, name) for name in relation_names]
with patch.object(providers, "get_adapter", return_value=adapter):
with adapter.connection_named("_test"):
basis, compares = relations[0], relations[1:]
columns = [c.name for c in adapter.get_columns_in_relation(basis)]
for relation in compares:
sql = adapter.get_rows_different_sql(basis, relation, column_names=columns)
_, tbl = adapter.execute(sql, fetch=True)
num_rows = len(tbl)
assert (
num_rows == 1
), f"Invalid sql query from get_rows_different_sql: incorrect number of rows ({num_rows})"
num_cols = len(tbl[0])
assert (
num_cols == 2
), f"Invalid sql query from get_rows_different_sql: incorrect number of cols ({num_cols})"
row_count_difference = tbl[0][0]
assert (
row_count_difference == 0
), f"Got {row_count_difference} difference in row count betwen {basis} and {relation}"
rows_mismatched = tbl[0][1]
assert (
rows_mismatched == 0
), f"Got {rows_mismatched} different rows between {basis} and {relation}"
def get_unique_ids_in_results(results):
unique_ids = []
for result in results:
unique_ids.append(result.node.unique_id)
return unique_ids
def check_result_nodes_by_name(results, names):
result_names = []
for result in results:
result_names.append(result.node.name)
assert set(names) == set(result_names)
def check_result_nodes_by_unique_id(results, unique_ids):
result_unique_ids = []
for result in results:
result_unique_ids.append(result.node.unique_id)
assert set(unique_ids) == set(result_unique_ids)
def update_rows(adapter, update_rows_config):
"""
{
"name": "base",
"dst_col": "some_date"
"clause": {
"type": "add_timestamp",
"src_col": "some_date",
"where" "id > 10"
}
"""
for key in ["name", "dst_col", "clause"]:
if key not in update_rows_config:
raise TestProcessingException(f"Invalid update_rows: no {key}")
clause = update_rows_config["clause"]
clause = generate_update_clause(adapter, clause)
where = None
if "where" in update_rows_config:
where = update_rows_config["where"]
name = update_rows_config["name"]
dst_col = update_rows_config["dst_col"]
relation = relation_from_name(adapter, name)
with patch.object(providers, "get_adapter", return_value=adapter):
with adapter.connection_named("_test"):
sql = adapter.update_column_sql(
dst_name=str(relation),
dst_column=dst_col,
clause=clause,
where_clause=where,
)
print(f"--- update_rows sql: {sql}")
adapter.execute(sql, auto_begin=True)
adapter.commit_if_has_connection()
def generate_update_clause(adapter, clause) -> str:
"""
Called by update_rows function. Expects the "clause" dictionary
documented in 'update_rows.
"""
if "type" not in clause or clause["type"] not in ["add_timestamp", "add_string"]:
raise TestProcessingException("invalid update_rows clause: type missing or incorrect")
clause_type = clause["type"]
if clause_type == "add_timestamp":
if "src_col" not in clause:
raise TestProcessingException("Invalid update_rows clause: no src_col")
add_to = clause["src_col"]
kwargs = {k: v for k, v in clause.items() if k in ("interval", "number")}
with patch.object(providers, "get_adapter", return_value=adapter):
with adapter.connection_named("_test"):
return adapter.timestamp_add_sql(add_to=add_to, **kwargs)
elif clause_type == "add_string":
for key in ["src_col", "value"]:
if key not in clause:
raise TestProcessingException(f"Invalid update_rows clause: no {key}")
src_col = clause["src_col"]
value = clause["value"]
location = clause.get("location", "append")
with patch.object(providers, "get_adapter", return_value=adapter):
with adapter.connection_named("_test"):
return adapter.string_add_sql(src_col, value, location)
return ""

View File

@@ -10,7 +10,7 @@ import dbt.flags as flags
from dbt.config.runtime import RuntimeConfig
from dbt.adapters.factory import get_adapter, register_adapter, reset_adapters
from dbt.events.functions import setup_event_logger
from dbt.tests.util import write_file, run_sql_with_adapter
from dbt.tests.util import write_file, run_sql_with_adapter, TestProcessingException
# These are the fixtures that are used in dbt core functional tests
@@ -63,6 +63,19 @@ def test_data_dir(request):
return os.path.join(request.fspath.dirname, "data")
@pytest.fixture(scope="class")
def dbt_profile_target():
return {
"type": "postgres",
"threads": 4,
"host": "localhost",
"port": int(os.getenv("POSTGRES_TEST_PORT", 5432)),
"user": os.getenv("POSTGRES_TEST_USER", "root"),
"pass": os.getenv("POSTGRES_TEST_PASS", "password"),
"dbname": os.getenv("POSTGRES_TEST_DATABASE", "dbt"),
}
# This fixture can be overridden in a project
@pytest.fixture(scope="class")
def profiles_config_update():
@@ -71,35 +84,19 @@ def profiles_config_update():
# The profile dictionary, used to write out profiles.yml
@pytest.fixture(scope="class")
def dbt_profile_data(unique_schema, profiles_config_update):
def dbt_profile_data(unique_schema, dbt_profile_target, profiles_config_update):
profile = {
"config": {"send_anonymous_usage_stats": False},
"test": {
"outputs": {
"default": {
"type": "postgres",
"threads": 4,
"host": "localhost",
"port": int(os.getenv("POSTGRES_TEST_PORT", 5432)),
"user": os.getenv("POSTGRES_TEST_USER", "root"),
"pass": os.getenv("POSTGRES_TEST_PASS", "password"),
"dbname": os.getenv("POSTGRES_TEST_DATABASE", "dbt"),
"schema": unique_schema,
},
"other_schema": {
"type": "postgres",
"threads": 4,
"host": "localhost",
"port": int(os.getenv("POSTGRES_TEST_PORT", 5432)),
"user": "noaccess",
"pass": "password",
"dbname": os.getenv("POSTGRES_TEST_DATABASE", "dbt"),
"schema": unique_schema + "_alt", # Should this be the same unique_schema?
},
"default": {},
},
"target": "default",
},
}
target = dbt_profile_target
target["schema"] = unique_schema
profile["test"]["outputs"]["default"] = target
if profiles_config_update:
profile.update(profiles_config_update)
@@ -199,6 +196,8 @@ def write_project_files(project_root, dir_name, file_dict):
# Write files out from file_dict. Can be nested directories...
def write_project_files_recursively(path, file_dict):
if type(file_dict) is not dict:
raise TestProcessingException(f"Error creating {path}. Did you forget the file extension?")
for name, value in file_dict.items():
if name.endswith(".sql") or name.endswith(".csv") or name.endswith(".md"):
write_file(value, path, name)
@@ -276,6 +275,7 @@ class TestProjInfo:
test_data_dir,
test_schema,
database,
test_config,
):
self.project_root = project_root
self.profiles_dir = profiles_dir
@@ -285,6 +285,7 @@ class TestProjInfo:
self.test_data_dir = test_data_dir
self.test_schema = test_schema
self.database = database
self.test_config = test_config
# Run sql from a path
def run_sql_file(self, sql_path, fetch=None):
@@ -313,6 +314,13 @@ class TestProjInfo:
return {model_name: materialization for (model_name, materialization) in result}
# This fixture is for customizing tests that need overrides in adapter
# repos. Example in dbt.tests.adapter.basic.test_base.
@pytest.fixture(scope="class")
def test_config():
return {}
@pytest.fixture(scope="class")
def project(
project_root,
@@ -328,6 +336,7 @@ def project(
shared_data_dir,
test_data_dir,
logs_dir,
test_config,
):
# Logbook warnings are ignored so we don't have to fork logbook to support python 3.10.
# This _only_ works for tests in `tests/` that use the project fixture.
@@ -345,8 +354,8 @@ def project(
shared_data_dir=shared_data_dir,
test_data_dir=test_data_dir,
test_schema=unique_schema,
# the following feels kind of fragile. TODO: better way of getting database
database=profiles_yml["test"]["outputs"]["default"]["dbname"],
database=adapter.config.credentials.database,
test_config=test_config,
)
project.run_sql("drop schema if exists {schema} cascade")
project.run_sql("create schema {schema}")

View File

@@ -46,7 +46,7 @@ def run_dbt_and_capture(args: List[str] = None, expect_pass=True):
# Used in test cases to get the manifest from the partial parsing file
def get_manifest(project_root):
path = project_root.join("target", "partial_parse.msgpack")
path = os.path.join(project_root, "target", "partial_parse.msgpack")
if os.path.exists(path):
with open(path, "rb") as fp:
manifest_mp = fp.read()
@@ -122,6 +122,8 @@ def run_sql_with_adapter(adapter, sql, fetch=None):
}
sql = sql.format(**kwargs)
msg = f'test connection "__test" executing: {sql}'
fire_event(IntegrationTestDebug(msg=msg))
# Since the 'adapter' in dbt.adapters.factory may have been replaced by execution
# of dbt commands since the test 'adapter' was created, we patch the 'get_adapter' call in
# dbt.context.providers, so that macros that are called refer to this test adapter.
@@ -130,24 +132,196 @@ def run_sql_with_adapter(adapter, sql, fetch=None):
with patch.object(providers, "get_adapter", return_value=adapter):
with adapter.connection_named("__test"):
conn = adapter.connections.get_thread_connection()
msg = f'test connection "{conn.name}" executing: {sql}'
fire_event(IntegrationTestDebug(msg=msg))
with conn.handle.cursor() as cursor:
try:
cursor.execute(sql)
conn.handle.commit()
conn.handle.commit()
if fetch == "one":
return cursor.fetchone()
elif fetch == "all":
return cursor.fetchall()
else:
return
except BaseException as e:
if conn.handle and not getattr(conn.handle, "closed", True):
conn.handle.rollback()
print(sql)
print(e)
raise
finally:
conn.transaction_open = False
return adapter.run_sql_for_tests(sql, fetch, conn)
class TestProcessingException(Exception):
pass
def relation_from_name(adapter, name: str):
"""reverse-engineer a relation from a given name and
the adapter. The relation name is split by the '.' character.
"""
# Different adapters have different Relation classes
cls = adapter.Relation
credentials = adapter.config.credentials
quote_policy = cls.get_default_quote_policy().to_dict()
include_policy = cls.get_default_include_policy().to_dict()
# Make sure we have database/schema/identifier parts, even if
# only identifier was supplied.
relation_parts = name.split(".")
if len(relation_parts) == 1:
relation_parts.insert(0, credentials.schema)
if len(relation_parts) == 2:
relation_parts.insert(0, credentials.database)
kwargs = {
"database": relation_parts[0],
"schema": relation_parts[1],
"identifier": relation_parts[2],
}
relation = cls.create(
include_policy=include_policy,
quote_policy=quote_policy,
**kwargs,
)
return relation
def check_relation_types(adapter, relation_to_type):
"""
Relation name to table/view
{
"base": "table",
"other": "view",
}
"""
expected_relation_values = {}
found_relations = []
schemas = set()
for key, value in relation_to_type.items():
relation = relation_from_name(adapter, key)
expected_relation_values[relation] = value
schemas.add(relation.without_identifier())
with patch.object(providers, "get_adapter", return_value=adapter):
with adapter.connection_named("__test"):
for schema in schemas:
found_relations.extend(adapter.list_relations_without_caching(schema))
for key, value in relation_to_type.items():
for relation in found_relations:
# this might be too broad
if relation.identifier == key:
assert relation.type == value, (
f"Got an unexpected relation type of {relation.type} "
f"for relation {key}, expected {value}"
)
def check_relations_equal(adapter, relation_names):
if len(relation_names) < 2:
raise TestProcessingException(
"Not enough relations to compare",
)
relations = [relation_from_name(adapter, name) for name in relation_names]
with patch.object(providers, "get_adapter", return_value=adapter):
with adapter.connection_named("_test"):
basis, compares = relations[0], relations[1:]
columns = [c.name for c in adapter.get_columns_in_relation(basis)]
for relation in compares:
sql = adapter.get_rows_different_sql(basis, relation, column_names=columns)
_, tbl = adapter.execute(sql, fetch=True)
num_rows = len(tbl)
assert (
num_rows == 1
), f"Invalid sql query from get_rows_different_sql: incorrect number of rows ({num_rows})"
num_cols = len(tbl[0])
assert (
num_cols == 2
), f"Invalid sql query from get_rows_different_sql: incorrect number of cols ({num_cols})"
row_count_difference = tbl[0][0]
assert (
row_count_difference == 0
), f"Got {row_count_difference} difference in row count betwen {basis} and {relation}"
rows_mismatched = tbl[0][1]
assert (
rows_mismatched == 0
), f"Got {rows_mismatched} different rows between {basis} and {relation}"
def get_unique_ids_in_results(results):
unique_ids = []
for result in results:
unique_ids.append(result.node.unique_id)
return unique_ids
def check_result_nodes_by_name(results, names):
result_names = []
for result in results:
result_names.append(result.node.name)
assert set(names) == set(result_names)
def check_result_nodes_by_unique_id(results, unique_ids):
result_unique_ids = []
for result in results:
result_unique_ids.append(result.node.unique_id)
assert set(unique_ids) == set(result_unique_ids)
def update_rows(adapter, update_rows_config):
"""
{
"name": "base",
"dst_col": "some_date"
"clause": {
"type": "add_timestamp",
"src_col": "some_date",
"where" "id > 10"
}
"""
for key in ["name", "dst_col", "clause"]:
if key not in update_rows_config:
raise TestProcessingException(f"Invalid update_rows: no {key}")
clause = update_rows_config["clause"]
clause = generate_update_clause(adapter, clause)
where = None
if "where" in update_rows_config:
where = update_rows_config["where"]
name = update_rows_config["name"]
dst_col = update_rows_config["dst_col"]
relation = relation_from_name(adapter, name)
with patch.object(providers, "get_adapter", return_value=adapter):
with adapter.connection_named("_test"):
sql = adapter.update_column_sql(
dst_name=str(relation),
dst_column=dst_col,
clause=clause,
where_clause=where,
)
adapter.execute(sql, auto_begin=True)
adapter.commit_if_has_connection()
def generate_update_clause(adapter, clause) -> str:
"""
Called by update_rows function. Expects the "clause" dictionary
documented in 'update_rows.
"""
if "type" not in clause or clause["type"] not in ["add_timestamp", "add_string"]:
raise TestProcessingException("invalid update_rows clause: type missing or incorrect")
clause_type = clause["type"]
if clause_type == "add_timestamp":
if "src_col" not in clause:
raise TestProcessingException("Invalid update_rows clause: no src_col")
add_to = clause["src_col"]
kwargs = {k: v for k, v in clause.items() if k in ("interval", "number")}
with patch.object(providers, "get_adapter", return_value=adapter):
with adapter.connection_named("_test"):
return adapter.timestamp_add_sql(add_to=add_to, **kwargs)
elif clause_type == "add_string":
for key in ["src_col", "value"]:
if key not in clause:
raise TestProcessingException(f"Invalid update_rows clause: no {key}")
src_col = clause["src_col"]
value = clause["value"]
location = clause.get("location", "append")
with patch.object(providers, "get_adapter", return_value=adapter):
with adapter.connection_named("_test"):
return adapter.string_add_sql(src_col, value, location)
return ""

View File

@@ -1,2 +1,3 @@
-e ./core
-e ./plugins/postgres
-e ./tests/adapter

22
tests/adapter/README.md Normal file
View File

@@ -0,0 +1,22 @@
# dbt-tests-adapter
This is where we store the adapter tests that will be used by
plugin repos. It should be included in the dbt-core and plugin
repos by installing using pip.
Tests in this repo will be packaged as classes, with a base class
that can be imported by adapter test repositories. Tests that might
need to be overridden by a plugin will be separated into separate
test methods or test classes as they are discovered.
This plugin is installed in the dbt-core repo by pip install -e tests/adapter,
which is included in the editable-requirements.txt file.
The dbt.tests.adapter.basic tests originally came from the earlier
dbt-adapter-tests repository. Additional test directories will be
added as they are converted from dbt-core integration tests, so that
they can be used in adapter test suites without copying and pasting.
This is packaged as a plugin using a python namespace package so it
cannot have an __init__.py file in the part of the hierarchy to which it
needs to be attached.

View File

@@ -53,7 +53,7 @@ sources:
identifier: "{{ var('seed_name', 'base') }}"
"""
schema_test_seed_yml = """
generic_test_seed_yml = """
version: 2
models:
- name: base
@@ -63,7 +63,7 @@ models:
- not_null
"""
schema_test_view_yml = """
generic_test_view_yml = """
version: 2
models:
- name: view_model
@@ -73,7 +73,7 @@ models:
- not_null
"""
schema_test_table_yml = """
generic_test_table_yml = """
version: 2
models:
- name: table_model

View File

@@ -0,0 +1,110 @@
import pytest
from dbt.tests.util import (
run_dbt,
check_result_nodes_by_name,
relation_from_name,
check_relation_types,
check_relations_equal,
)
from dbt.tests.adapter.basic.files import (
seeds_base_csv,
base_view_sql,
base_table_sql,
base_materialized_var_sql,
schema_base_yml,
)
class BaseSimpleMaterializations:
@pytest.fixture(scope="class")
def models(self):
return {
"view_model.sql": base_view_sql,
"table_model.sql": base_table_sql,
"swappable.sql": base_materialized_var_sql,
"schema.yml": schema_base_yml,
}
@pytest.fixture(scope="class")
def seeds(self):
return {
"base.csv": seeds_base_csv,
}
@pytest.fixture(scope="class")
def project_config_update(self):
return {
"name": "base",
}
def test_base(self, project):
# seed command
results = run_dbt(["seed"])
# seed result length
assert len(results) == 1
# run command
results = run_dbt()
# run result length
assert len(results) == 3
# names exist in result nodes
check_result_nodes_by_name(results, ["view_model", "table_model", "swappable"])
# check relation types
expected = {
"base": "table",
"view_model": "view",
"table_model": "table",
"swappable": "table",
}
check_relation_types(project.adapter, expected)
# base table rowcount
relation = relation_from_name(project.adapter, "base")
result = project.run_sql(f"select count(*) as num_rows from {relation}", fetch="one")
assert result[0] == 10
# relations_equal
check_relations_equal(project.adapter, ["base", "view_model", "table_model", "swappable"])
# check relations in catalog
catalog = run_dbt(["docs", "generate"])
assert len(catalog.nodes) == 4
assert len(catalog.sources) == 1
# run_dbt changing materialized_var to view
if project.test_config.get("require_full_refresh", False): # required for BigQuery
results = run_dbt(
["run", "--full-refresh", "-m", "swappable", "--vars", "materialized_var: view"]
)
else:
results = run_dbt(["run", "-m", "swappable", "--vars", "materialized_var: view"])
assert len(results) == 1
# check relation types, swappable is view
expected = {
"base": "table",
"view_model": "view",
"table_model": "table",
"swappable": "view",
}
check_relation_types(project.adapter, expected)
# run_dbt changing materialized_var to incremental
results = run_dbt(["run", "-m", "swappable", "--vars", "materialized_var: incremental"])
assert len(results) == 1
# check relation types, swappable is table
expected = {
"base": "table",
"view_model": "view",
"table_model": "table",
"swappable": "table",
}
check_relation_types(project.adapter, expected)
class TestSimpleMaterializations(BaseSimpleMaterializations):
pass

View File

@@ -0,0 +1,29 @@
from dbt.tests.util import run_dbt
import os
class BaseEmpty:
def test_empty(self, project):
# check seed
results = run_dbt(["seed"])
assert len(results) == 0
run_results_path = os.path.join(project.project_root, "target", "run_results.json")
assert os.path.exists(run_results_path)
# check run
results = run_dbt(["run"])
assert len(results) == 0
catalog_path = os.path.join(project.project_root, "target", "catalog.json")
assert not os.path.exists(catalog_path)
# check catalog
catalog = run_dbt(["docs", "generate"])
assert os.path.exists(run_results_path)
assert os.path.exists(catalog_path)
assert len(catalog.nodes) == 0
assert len(catalog.sources) == 0
class TestEmpty(BaseEmpty):
pass

View File

@@ -0,0 +1,70 @@
import pytest
import os
from dbt.tests.util import (
run_dbt,
get_manifest,
check_relations_equal,
check_result_nodes_by_name,
relation_from_name,
)
from dbt.tests.adapter.basic.files import (
seeds_base_csv,
base_ephemeral_sql,
ephemeral_view_sql,
ephemeral_table_sql,
schema_base_yml,
)
class BaseEphemeral:
@pytest.fixture(scope="class")
def project_config_update(self):
return {"name": "ephemeral"}
@pytest.fixture(scope="class")
def seeds(self):
return {"base.csv": seeds_base_csv}
@pytest.fixture(scope="class")
def models(self):
return {
"ephemeral.sql": base_ephemeral_sql,
"view_model.sql": ephemeral_view_sql,
"table_model.sql": ephemeral_table_sql,
"schema.yml": schema_base_yml,
}
def test_ephemeral(self, project):
# seed command
results = run_dbt(["seed"])
assert len(results) == 1
check_result_nodes_by_name(results, ["base"])
# run command
results = run_dbt(["run"])
assert len(results) == 2
check_result_nodes_by_name(results, ["view_model", "table_model"])
# base table rowcount
relation = relation_from_name(project.adapter, "base")
result = project.run_sql(f"select count(*) as num_rows from {relation}", fetch="one")
assert result[0] == 10
# relations equal
check_relations_equal(project.adapter, ["base", "view_model", "table_model"])
# catalog node count
catalog = run_dbt(["docs", "generate"])
catalog_path = os.path.join(project.project_root, "target", "catalog.json")
assert os.path.exists(catalog_path)
assert len(catalog.nodes) == 3
assert len(catalog.sources) == 1
# manifest (not in original)
manifest = get_manifest(project.project_root)
assert len(manifest.nodes) == 4
assert len(manifest.sources) == 1
class TestEphemeral(BaseEphemeral):
pass

View File

@@ -0,0 +1,54 @@
import pytest
from dbt.tests.util import run_dbt
from dbt.tests.adapter.basic.files import (
seeds_base_csv,
generic_test_seed_yml,
base_view_sql,
base_table_sql,
schema_base_yml,
generic_test_view_yml,
generic_test_table_yml,
)
class BaseGenericTests:
@pytest.fixture(scope="class")
def project_config_update(self):
return {"name": "generic_tests"}
@pytest.fixture(scope="class")
def seeds(self):
return {
"base.csv": seeds_base_csv,
"schema.yml": generic_test_seed_yml,
}
@pytest.fixture(scope="class")
def models(self):
return {
"view_model.sql": base_view_sql,
"table_model.sql": base_table_sql,
"schema.yml": schema_base_yml,
"schema_view.yml": generic_test_view_yml,
"schema_table.yml": generic_test_table_yml,
}
def test_generic_tests(self, project):
# seed command
results = run_dbt(["seed"])
# test command selecting base model
results = run_dbt(["test", "-m", "base"])
assert len(results) == 1
# run command
results = run_dbt(["run"])
assert len(results) == 2
# test command, all tests
results = run_dbt(["test"])
assert len(results) == 3
class TestGenericTests(BaseGenericTests):
pass

View File

@@ -0,0 +1,62 @@
import pytest
from dbt.tests.util import run_dbt, check_relations_equal, relation_from_name
from dbt.tests.adapter.basic.files import (
seeds_base_csv,
seeds_added_csv,
schema_base_yml,
incremental_sql,
)
class BaseIncremental:
@pytest.fixture(scope="class")
def project_config_update(self):
return {"name": "incremental"}
@pytest.fixture(scope="class")
def models(self):
return {"incremental.sql": incremental_sql, "schema.yml": schema_base_yml}
@pytest.fixture(scope="class")
def seeds(self):
return {"base.csv": seeds_base_csv, "added.csv": seeds_added_csv}
def test_incremental(self, project):
# seed command
results = run_dbt(["seed"])
assert len(results) == 2
# base table rowcount
relation = relation_from_name(project.adapter, "base")
result = project.run_sql(f"select count(*) as num_rows from {relation}", fetch="one")
assert result[0] == 10
# added table rowcount
relation = relation_from_name(project.adapter, "added")
result = project.run_sql(f"select count(*) as num_rows from {relation}", fetch="one")
assert result[0] == 20
# run command
# the "seed_name" var changes the seed identifier in the schema file
results = run_dbt(["run", "--vars", "seed_name: base"])
assert len(results) == 1
# check relations equal
check_relations_equal(project.adapter, ["base", "incremental"])
# change seed_name var
# the "seed_name" var changes the seed identifier in the schema file
results = run_dbt(["run", "--vars", "seed_name: added"])
assert len(results) == 1
# check relations equal
check_relations_equal(project.adapter, ["added", "incremental"])
# get catalog from docs generate
catalog = run_dbt(["docs", "generate"])
assert len(catalog.nodes) == 3
assert len(catalog.sources) == 1
class Testincremental(BaseIncremental):
pass

View File

@@ -0,0 +1,38 @@
import pytest
from dbt.tests.adapter.basic.files import (
test_passing_sql,
test_failing_sql,
)
from dbt.tests.util import check_result_nodes_by_name, run_dbt
class BaseSingularTests:
@pytest.fixture(scope="class")
def tests(self):
return {
"passing.sql": test_passing_sql,
"failing.sql": test_failing_sql,
}
@pytest.fixture(scope="class")
def project_config_update(self):
return {"name": "singular_tests"}
def test_singular_tests(self, project):
# test command
results = run_dbt(["test"])
assert len(results) == 2
# We have the right result nodes
check_result_nodes_by_name(results, ["passing", "failing"])
# Check result status
for result in results:
if result.node.name == "passing":
assert result.status == "pass"
elif result.node.name == "failing":
assert result.status == "fail"
class TestSingularTests(BaseSingularTests):
pass

View File

@@ -0,0 +1,67 @@
import pytest
from dbt.tests.util import run_dbt, check_result_nodes_by_name
from dbt.tests.adapter.basic.files import (
seeds_base_csv,
ephemeral_with_cte_sql,
test_ephemeral_passing_sql,
test_ephemeral_failing_sql,
schema_base_yml,
)
class BaseSingularTestsEphemeral:
@pytest.fixture(scope="class")
def seeds(self):
return {
"base.csv": seeds_base_csv,
}
@pytest.fixture(scope="class")
def models(self):
return {
"ephemeral.sql": ephemeral_with_cte_sql,
"passing_model.sql": test_ephemeral_passing_sql,
"failing_model.sql": test_ephemeral_failing_sql,
"schema.yml": schema_base_yml,
}
@pytest.fixture(scope="class")
def tests(self):
return {
"passing.sql": test_ephemeral_passing_sql,
"failing.sql": test_ephemeral_failing_sql,
}
@pytest.fixture(scope="class")
def project_config_update(self):
return {
"name": "singular_tests_ephemeral",
}
def test_singular_tests_ephemeral(self, project):
# check results from seed command
results = run_dbt(["seed"])
assert len(results) == 1
check_result_nodes_by_name(results, ["base"])
# Check results from test command
results = run_dbt(["test"])
assert len(results) == 2
check_result_nodes_by_name(results, ["passing", "failing"])
# Check result status
for result in results:
if result.node.name == "passing":
assert result.status == "pass"
elif result.node.name == "failing":
assert result.status == "fail"
# check results from run command
results = run_dbt()
assert len(results) == 2
check_result_nodes_by_name(results, ["failing_model", "passing_model"])
class TestSingularTestsEphemeral(BaseSingularTestsEphemeral):
pass

View File

@@ -0,0 +1,112 @@
import pytest
from dbt.tests.util import run_dbt, update_rows, relation_from_name
from dbt.tests.adapter.basic.files import (
seeds_base_csv,
seeds_added_csv,
cc_all_snapshot_sql,
cc_date_snapshot_sql,
cc_name_snapshot_sql,
)
def check_relation_rows(project, snapshot_name, count):
relation = relation_from_name(project.adapter, snapshot_name)
result = project.run_sql(f"select count(*) as num_rows from {relation}", fetch="one")
assert result[0] == count
class BaseSnapshotCheckCols:
@pytest.fixture(scope="class")
def project_config_update(self):
return {"name": "snapshot_strategy_check_cols"}
@pytest.fixture(scope="class")
def seeds(self):
return {
"base.csv": seeds_base_csv,
"added.csv": seeds_added_csv,
}
@pytest.fixture(scope="class")
def snapshots(self):
return {
"cc_all_snapshot.sql": cc_all_snapshot_sql,
"cc_date_snapshot.sql": cc_date_snapshot_sql,
"cc_name_snapshot.sql": cc_name_snapshot_sql,
}
def test_snapshot_check_cols(self, project):
# seed command
results = run_dbt(["seed"])
assert len(results) == 2
# snapshot command
results = run_dbt(["snapshot"])
for result in results:
assert result.status == "success"
# check rowcounts for all snapshots
check_relation_rows(project, "cc_all_snapshot", 10)
check_relation_rows(project, "cc_name_snapshot", 10)
check_relation_rows(project, "cc_date_snapshot", 10)
relation = relation_from_name(project.adapter, "cc_all_snapshot")
result = project.run_sql(f"select * from {relation}", fetch="all")
# point at the "added" seed so the snapshot sees 10 new rows
results = run_dbt(["--no-partial-parse", "snapshot", "--vars", "seed_name: added"])
for result in results:
assert result.status == "success"
# check rowcounts for all snapshots
check_relation_rows(project, "cc_all_snapshot", 20)
check_relation_rows(project, "cc_name_snapshot", 20)
check_relation_rows(project, "cc_date_snapshot", 20)
# update some timestamps in the "added" seed so the snapshot sees 10 more new rows
update_rows_config = {
"name": "added",
"dst_col": "some_date",
"clause": {"src_col": "some_date", "type": "add_timestamp"},
"where": "id > 10 and id < 21",
}
update_rows(project.adapter, update_rows_config)
# re-run snapshots, using "added'
results = run_dbt(["snapshot", "--vars", "seed_name: added"])
for result in results:
assert result.status == "success"
# check rowcounts for all snapshots
check_relation_rows(project, "cc_all_snapshot", 30)
check_relation_rows(project, "cc_date_snapshot", 30)
# unchanged: only the timestamp changed
check_relation_rows(project, "cc_name_snapshot", 20)
# Update the name column
update_rows_config = {
"name": "added",
"dst_col": "name",
"clause": {
"src_col": "name",
"type": "add_string",
"value": "_updated",
},
"where": "id < 11",
}
update_rows(project.adapter, update_rows_config)
# re-run snapshots, using "added'
results = run_dbt(["snapshot", "--vars", "seed_name: added"])
for result in results:
assert result.status == "success"
# check rowcounts for all snapshots
check_relation_rows(project, "cc_all_snapshot", 40)
check_relation_rows(project, "cc_name_snapshot", 30)
# does not see name updates
check_relation_rows(project, "cc_date_snapshot", 30)
class TestSnapshotCheckCols(BaseSnapshotCheckCols):
pass

View File

@@ -0,0 +1,90 @@
import pytest
from dbt.tests.util import run_dbt, relation_from_name, update_rows
from dbt.tests.adapter.basic.files import (
seeds_base_csv,
seeds_newcolumns_csv,
seeds_added_csv,
ts_snapshot_sql,
)
def check_relation_rows(project, snapshot_name, count):
relation = relation_from_name(project.adapter, snapshot_name)
result = project.run_sql(f"select count(*) as num_rows from {relation}", fetch="one")
assert result[0] == count
class BaseSnapshotTimestamp:
@pytest.fixture(scope="class")
def seeds(self):
return {
"base.csv": seeds_base_csv,
"newcolumns.csv": seeds_newcolumns_csv,
"added.csv": seeds_added_csv,
}
@pytest.fixture(scope="class")
def snapshots(self):
return {
"ts_snapshot.sql": ts_snapshot_sql,
}
@pytest.fixture(scope="class")
def project_config_update(self):
return {"name": "snapshot_strategy_timestamp"}
def test_snapshot_timestamp(self, project):
# seed command
results = run_dbt(["seed"])
assert len(results) == 3
# snapshot command
results = run_dbt(["snapshot"])
assert len(results) == 1
# snapshot has 10 rows
check_relation_rows(project, "ts_snapshot", 10)
# point at the "added" seed so the snapshot sees 10 new rows
results = run_dbt(["snapshot", "--vars", "seed_name: added"])
# snapshot now has 20 rows
check_relation_rows(project, "ts_snapshot", 20)
# update some timestamps in the "added" seed so the snapshot sees 10 more new rows
update_rows_config = {
"name": "added",
"dst_col": "some_date",
"clause": {
"src_col": "some_date",
"type": "add_timestamp",
},
"where": "id > 10 and id < 21",
}
update_rows(project.adapter, update_rows_config)
results = run_dbt(["snapshot", "--vars", "seed_name: added"])
# snapshot now has 30 rows
check_relation_rows(project, "ts_snapshot", 30)
update_rows_config = {
"name": "added",
"dst_col": "name",
"clause": {
"src_col": "name",
"type": "add_string",
"value": "_updated",
},
"where": "id < 11",
}
update_rows(project.adapter, update_rows_config)
results = run_dbt(["snapshot", "--vars", "seed_name: added"])
# snapshot still has 30 rows because timestamp not updated
check_relation_rows(project, "ts_snapshot", 30)
class TestSnapshotTimestamp(BaseSnapshotTimestamp):
pass

79
tests/adapter/setup.py Normal file
View File

@@ -0,0 +1,79 @@
#!/usr/bin/env python
import os
import sys
if sys.version_info < (3, 7):
print("Error: dbt does not support this version of Python.")
print("Please upgrade to Python 3.7 or higher.")
sys.exit(1)
from setuptools import setup
try:
from setuptools import find_namespace_packages
except ImportError:
# the user has a downlevel version of setuptools.
print("Error: dbt requires setuptools v40.1.0 or higher.")
print('Please upgrade setuptools with "pip install --upgrade setuptools" ' "and try again")
sys.exit(1)
PSYCOPG2_MESSAGE = """
No package name override was set.
Using 'psycopg2-binary' package to satisfy 'psycopg2'
If you experience segmentation faults, silent crashes, or installation errors,
consider retrying with the 'DBT_PSYCOPG2_NAME' environment variable set to
'psycopg2'. It may require a compiler toolchain and development libraries!
""".strip()
def _dbt_psycopg2_name():
# if the user chose something, use that
package_name = os.getenv("DBT_PSYCOPG2_NAME", "")
if package_name:
return package_name
# default to psycopg2-binary for all OSes/versions
print(PSYCOPG2_MESSAGE)
return "psycopg2-binary"
package_name = "dbt-tests-adapter"
package_version = "1.0.1"
description = """The dbt adapter tests for adapter plugins"""
this_directory = os.path.abspath(os.path.dirname(__file__))
with open(os.path.join(this_directory, "README.md")) as f:
long_description = f.read()
DBT_PSYCOPG2_NAME = _dbt_psycopg2_name()
setup(
name=package_name,
version=package_version,
description=description,
long_description=long_description,
long_description_content_type="text/markdown",
author="dbt Labs",
author_email="info@dbtlabs.com",
url="https://github.com/dbt-labs/dbt-tests-adapter",
packages=find_namespace_packages(include=["dbt", "dbt.*"]),
install_requires=[
"dbt-core=={}".format(package_version),
"{}~=2.8".format(DBT_PSYCOPG2_NAME),
],
zip_safe=False,
classifiers=[
"Development Status :: 5 - Production/Stable",
"License :: OSI Approved :: Apache Software License",
"Operating System :: Microsoft :: Windows",
"Operating System :: MacOS :: MacOS X",
"Operating System :: POSIX :: Linux",
"Programming Language :: Python :: 3.7",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
],
python_requires=">=3.7",
)

View File

@@ -1,99 +0,0 @@
import pytest
from dbt.tests.util import run_dbt
from dbt.tests.adapter import check_result_nodes_by_name
from dbt.tests.adapter import relation_from_name, check_relation_types, check_relations_equal
from tests.functional.adapter.files import (
seeds_base_csv,
base_view_sql,
base_table_sql,
base_materialized_var_sql,
schema_base_yml,
)
@pytest.fixture(scope="class")
def models():
return {
"view_model.sql": base_view_sql,
"table_model.sql": base_table_sql,
"swappable.sql": base_materialized_var_sql,
"schema.yml": schema_base_yml,
}
@pytest.fixture(scope="class")
def seeds():
return {
"base.csv": seeds_base_csv,
}
@pytest.fixture(scope="class")
def project_config_update():
return {
"name": "base",
}
def test_base(project):
# seed command
results = run_dbt(["seed"])
# seed result length
assert len(results) == 1
# run command
results = run_dbt()
# run result length
assert len(results) == 3
# names exist in result nodes
check_result_nodes_by_name(results, ["view_model", "table_model", "swappable"])
# check relation types
expected = {
"base": "table",
"view_model": "view",
"table_model": "table",
"swappable": "table",
}
check_relation_types(project.adapter, expected)
# base table rowcount
relation = relation_from_name(project.adapter, "base")
result = project.run_sql(f"select count(*) as num_rows from {relation}", fetch="one")
assert result[0] == 10
# relations_equal
check_relations_equal(project.adapter, ["base", "view_model", "table_model", "swappable"])
# check relations in catalog
catalog = run_dbt(["docs", "generate"])
assert len(catalog.nodes) == 4
assert len(catalog.sources) == 1
# run_dbt changing materialized_var to view
results = run_dbt(["run", "-m", "swappable", "--vars", "materialized_var: view"])
assert len(results) == 1
# check relation types, swappable is view
expected = {
"base": "table",
"view_model": "view",
"table_model": "table",
"swappable": "view",
}
check_relation_types(project.adapter, expected)
# run_dbt changing materialized_var to incremental
results = run_dbt(["run", "-m", "swappable", "--vars", "materialized_var: incremental"])
assert len(results) == 1
# check relation types, swappable is table
expected = {
"base": "table",
"view_model": "view",
"table_model": "table",
"swappable": "table",
}
check_relation_types(project.adapter, expected)

View File

@@ -1,67 +0,0 @@
import pytest
from dbt.tests.util import run_dbt
from dbt.tests.adapter import check_result_nodes_by_name
from tests.functional.adapter.files import (
seeds_base_csv,
ephemeral_with_cte_sql,
test_ephemeral_passing_sql,
test_ephemeral_failing_sql,
schema_base_yml,
)
@pytest.fixture(scope="class")
def seeds():
return {
"base.csv": seeds_base_csv,
}
@pytest.fixture(scope="class")
def models():
return {
"ephemeral.sql": ephemeral_with_cte_sql,
"passing_model.sql": test_ephemeral_passing_sql,
"failing_model.sql": test_ephemeral_failing_sql,
"schema.yml": schema_base_yml,
}
@pytest.fixture(scope="class")
def tests():
return {
"passing.sql": test_ephemeral_passing_sql,
"failing.sql": test_ephemeral_failing_sql,
}
@pytest.fixture(scope="class")
def project_config_update():
return {
"name": "data_test_ephemeral_models",
}
def test_data_test_ephemerals(project):
# check results from seed command
results = run_dbt(["seed"])
assert len(results) == 1
check_result_nodes_by_name(results, ["base"])
# Check results from test command
results = run_dbt(["test"])
assert len(results) == 2
check_result_nodes_by_name(results, ["passing", "failing"])
# Check result status
for result in results:
if result.node.name == "passing":
assert result.status == "pass"
elif result.node.name == "failing":
assert result.status == "fail"
# check results from run command
results = run_dbt()
assert len(results) == 2
check_result_nodes_by_name(results, ["failing_model", "passing_model"])

View File

@@ -1,36 +0,0 @@
import pytest
from tests.functional.adapter.files import (
test_passing_sql,
test_failing_sql,
)
from dbt.tests.adapter import check_result_nodes_by_name
from dbt.tests.util import run_dbt
@pytest.fixture(scope="class")
def tests():
return {
"passing.sql": test_passing_sql,
"failing.sql": test_failing_sql,
}
@pytest.fixture(scope="class")
def project_config_update():
return {"name": "data_tests"}
def test_data_tests(project):
# test command
results = run_dbt(["test"])
assert len(results) == 2
# We have the right result nodes
check_result_nodes_by_name(results, ["passing", "failing"])
# Check result status
for result in results:
if result.node.name == "passing":
assert result.status == "pass"
elif result.node.name == "failing":
assert result.status == "fail"

View File

@@ -1,24 +0,0 @@
from dbt.tests.util import run_dbt
import os
def test_empty(project):
# check seed
results = run_dbt(["seed"])
assert len(results) == 0
run_results_path = os.path.join(project.project_root, "target", "run_results.json")
assert os.path.exists(run_results_path)
# check run
results = run_dbt(["run"])
assert len(results) == 0
catalog_path = os.path.join(project.project_root, "target", "catalog.json")
assert not os.path.exists(catalog_path)
# check catalog
catalog = run_dbt(["docs", "generate"])
assert os.path.exists(run_results_path)
assert os.path.exists(catalog_path)
assert len(catalog.nodes) == 0
assert len(catalog.sources) == 0

View File

@@ -1,63 +0,0 @@
import pytest
import os
from dbt.tests.util import run_dbt, get_manifest
from dbt.tests.adapter import check_relations_equal, check_result_nodes_by_name, relation_from_name
from tests.functional.adapter.files import (
seeds_base_csv,
base_ephemeral_sql,
ephemeral_view_sql,
ephemeral_table_sql,
schema_base_yml,
)
@pytest.fixture(scope="class")
def project_config_update():
return {"name": "ephemeral"}
@pytest.fixture(scope="class")
def seeds():
return {"base.csv": seeds_base_csv}
@pytest.fixture(scope="class")
def models():
return {
"ephemeral.sql": base_ephemeral_sql,
"view_model.sql": ephemeral_view_sql,
"table_model.sql": ephemeral_table_sql,
"schema.yml": schema_base_yml,
}
def test_ephemeral(project):
# seed command
results = run_dbt(["seed"])
assert len(results) == 1
check_result_nodes_by_name(results, ["base"])
# run command
results = run_dbt(["run"])
assert len(results) == 2
check_result_nodes_by_name(results, ["view_model", "table_model"])
# base table rowcount
relation = relation_from_name(project.adapter, "base")
result = project.run_sql(f"select count(*) as num_rows from {relation}", fetch="one")
assert result[0] == 10
# relations equal
check_relations_equal(project.adapter, ["base", "view_model", "table_model"])
# catalog node count
catalog = run_dbt(["docs", "generate"])
catalog_path = os.path.join(project.project_root, "target", "catalog.json")
assert os.path.exists(catalog_path)
assert len(catalog.nodes) == 3
assert len(catalog.sources) == 1
# manifest (not in original)
manifest = get_manifest(project.project_root)
assert len(manifest.nodes) == 4
assert len(manifest.sources) == 1

View File

@@ -1,61 +0,0 @@
import pytest
from dbt.tests.util import run_dbt
from tests.functional.adapter.files import (
seeds_base_csv,
seeds_added_csv,
schema_base_yml,
incremental_sql,
)
from dbt.tests.adapter import check_relations_equal, relation_from_name
@pytest.fixture(scope="class")
def project_config_update():
return {"name": "incremental"}
@pytest.fixture(scope="class")
def models():
return {"incremental.sql": incremental_sql, "schema.yml": schema_base_yml}
@pytest.fixture(scope="class")
def seeds():
return {"base.csv": seeds_base_csv, "added.csv": seeds_added_csv}
def test_incremental(project):
# seed command
results = run_dbt(["seed"])
assert len(results) == 2
# base table rowcount
relation = relation_from_name(project.adapter, "base")
result = project.run_sql(f"select count(*) as num_rows from {relation}", fetch="one")
assert result[0] == 10
# added table rowcount
relation = relation_from_name(project.adapter, "added")
result = project.run_sql(f"select count(*) as num_rows from {relation}", fetch="one")
assert result[0] == 20
# run command
# the "seed_name" var changes the seed identifier in the schema file
results = run_dbt(["run", "--vars", "seed_name: base"])
assert len(results) == 1
# check relations equal
check_relations_equal(project.adapter, ["base", "incremental"])
# change seed_name var
# the "seed_name" var changes the seed identifier in the schema file
results = run_dbt(["run", "--vars", "seed_name: added"])
assert len(results) == 1
# check relations equal
check_relations_equal(project.adapter, ["added", "incremental"])
# get catalog from docs generate
catalog = run_dbt(["docs", "generate"])
assert len(catalog.nodes) == 3
assert len(catalog.sources) == 1

View File

@@ -1,52 +0,0 @@
import pytest
from dbt.tests.util import run_dbt
from tests.functional.adapter.files import (
seeds_base_csv,
schema_test_seed_yml,
base_view_sql,
base_table_sql,
schema_base_yml,
schema_test_view_yml,
schema_test_table_yml,
)
@pytest.fixture(scope="class")
def project_config_update():
return {"name": "schema_test"}
@pytest.fixture(scope="class")
def seeds():
return {
"base.csv": seeds_base_csv,
"schema.yml": schema_test_seed_yml,
}
@pytest.fixture(scope="class")
def models():
return {
"view_model.sql": base_view_sql,
"table_model.sql": base_table_sql,
"schema.yml": schema_base_yml,
"schema_view.yml": schema_test_view_yml,
"schema_table.yml": schema_test_table_yml,
}
def test_schema_tests(project):
# seed command
results = run_dbt(["seed"])
# test command selecting base model
results = run_dbt(["test", "-m", "base"])
assert len(results) == 1
# run command
results = run_dbt(["run"])
assert len(results) == 2
# test command, all tests
results = run_dbt(["test"])
assert len(results) == 3

View File

@@ -1,111 +0,0 @@
import pytest
from dbt.tests.util import run_dbt
from tests.functional.adapter.files import (
seeds_base_csv,
seeds_added_csv,
cc_all_snapshot_sql,
cc_date_snapshot_sql,
cc_name_snapshot_sql,
)
from dbt.tests.adapter import update_rows, relation_from_name
@pytest.fixture(scope="class")
def project_config_update():
return {"name": "snapshot_strategy_check_cols"}
@pytest.fixture(scope="class")
def seeds():
return {
"base.csv": seeds_base_csv,
"added.csv": seeds_added_csv,
}
@pytest.fixture(scope="class")
def snapshots():
return {
"cc_all_snapshot.sql": cc_all_snapshot_sql,
"cc_date_snapshot.sql": cc_date_snapshot_sql,
"cc_name_snapshot.sql": cc_name_snapshot_sql,
}
def check_relation_rows(project, snapshot_name, count):
relation = relation_from_name(project.adapter, snapshot_name)
result = project.run_sql(f"select count(*) as num_rows from {relation}", fetch="one")
assert result[0] == count
def test_snapshot_check_cols(project):
# seed command
results = run_dbt(["seed"])
assert len(results) == 2
# snapshot command
results = run_dbt(["snapshot"])
for result in results:
assert result.status == "success"
# check rowcounts for all snapshots
check_relation_rows(project, "cc_all_snapshot", 10)
check_relation_rows(project, "cc_name_snapshot", 10)
check_relation_rows(project, "cc_date_snapshot", 10)
relation = relation_from_name(project.adapter, "cc_all_snapshot")
result = project.run_sql(f"select * from {relation}", fetch="all")
# point at the "added" seed so the snapshot sees 10 new rows
results = run_dbt(["--no-partial-parse", "snapshot", "--vars", "seed_name: added"])
for result in results:
assert result.status == "success"
# check rowcounts for all snapshots
check_relation_rows(project, "cc_all_snapshot", 20)
check_relation_rows(project, "cc_name_snapshot", 20)
check_relation_rows(project, "cc_date_snapshot", 20)
# update some timestamps in the "added" seed so the snapshot sees 10 more new rows
update_rows_config = {
"name": "added",
"dst_col": "some_date",
"clause": {"src_col": "some_date", "type": "add_timestamp"},
"where": "id > 10 and id < 21",
}
update_rows(project.adapter, update_rows_config)
# re-run snapshots, using "added'
results = run_dbt(["snapshot", "--vars", "seed_name: added"])
for result in results:
assert result.status == "success"
# check rowcounts for all snapshots
check_relation_rows(project, "cc_all_snapshot", 30)
check_relation_rows(project, "cc_date_snapshot", 30)
# unchanged: only the timestamp changed
check_relation_rows(project, "cc_name_snapshot", 20)
# Update the name column
update_rows_config = {
"name": "added",
"dst_col": "name",
"clause": {
"src_col": "name",
"type": "add_string",
"value": "_updated",
},
"where": "id < 11",
}
update_rows(project.adapter, update_rows_config)
# re-run snapshots, using "added'
results = run_dbt(["snapshot", "--vars", "seed_name: added"])
for result in results:
assert result.status == "success"
# check rowcounts for all snapshots
check_relation_rows(project, "cc_all_snapshot", 40)
check_relation_rows(project, "cc_name_snapshot", 30)
# does not see name updates
check_relation_rows(project, "cc_date_snapshot", 30)

View File

@@ -1,89 +0,0 @@
import pytest
from dbt.tests.util import run_dbt
from tests.functional.adapter.files import (
seeds_base_csv,
seeds_newcolumns_csv,
seeds_added_csv,
ts_snapshot_sql,
)
from dbt.tests.adapter import relation_from_name, update_rows
@pytest.fixture(scope="class")
def seeds():
return {
"base.csv": seeds_base_csv,
"newcolumns.csv": seeds_newcolumns_csv,
"added.csv": seeds_added_csv,
}
@pytest.fixture(scope="class")
def snapshots():
return {
"ts_snapshot.sql": ts_snapshot_sql,
}
@pytest.fixture(scope="class")
def project_config_update():
return {"name": "snapshot_strategy_timestamp"}
def check_relation_rows(project, snapshot_name, count):
relation = relation_from_name(project.adapter, snapshot_name)
result = project.run_sql(f"select count(*) as num_rows from {relation}", fetch="one")
assert result[0] == count
def test_snapshot_timestamp(project):
# seed command
results = run_dbt(["seed"])
assert len(results) == 3
# snapshot command
results = run_dbt(["snapshot"])
assert len(results) == 1
# snapshot has 10 rows
check_relation_rows(project, "ts_snapshot", 10)
# point at the "added" seed so the snapshot sees 10 new rows
results = run_dbt(["snapshot", "--vars", "seed_name: added"])
# snapshot now has 20 rows
check_relation_rows(project, "ts_snapshot", 20)
# update some timestamps in the "added" seed so the snapshot sees 10 more new rows
update_rows_config = {
"name": "added",
"dst_col": "some_date",
"clause": {
"src_col": "some_date",
"type": "add_timestamp",
},
"where": "id > 10 and id < 21",
}
update_rows(project.adapter, update_rows_config)
results = run_dbt(["snapshot", "--vars", "seed_name: added"])
# snapshot now has 30 rows
check_relation_rows(project, "ts_snapshot", 30)
update_rows_config = {
"name": "added",
"dst_col": "name",
"clause": {
"src_col": "name",
"type": "add_string",
"value": "_updated",
},
"where": "id < 11",
}
update_rows(project.adapter, update_rows_config)
results = run_dbt(["snapshot", "--vars", "seed_name: added"])
# snapshot still has 30 rows because timestamp not updated
check_relation_rows(project, "ts_snapshot", 30)

View File

@@ -158,7 +158,7 @@ class TestTagSelection(SelectionFixtures):
"unique_users_rollup_gender",
]
def test_select_tag_in_model_with_project_config_parents_children_selectors(project):
def test_select_tag_in_model_with_project_config_parents_children_selectors(self, project):
results = run_dbt(["run", "--selector", "user_tagged_childrens_parents"])
assert len(results) == 4

View File

@@ -18,6 +18,7 @@ passenv = DBT_* POSTGRES_TEST_* PYTEST_ADDOPTS
commands =
{envpython} -m pytest --cov=core -m profile_postgres {posargs} test/integration
{envpython} -m pytest --cov=core {posargs} tests/functional
{envpython} -m pytest --cov=core {posargs} tests/adapter
deps =
-rdev-requirements.txt