|
|
|
|
@@ -11,6 +11,7 @@ import unittest
|
|
|
|
|
from contextlib import contextmanager
|
|
|
|
|
from datetime import datetime
|
|
|
|
|
from functools import wraps
|
|
|
|
|
from typing import Optional
|
|
|
|
|
|
|
|
|
|
import pytest
|
|
|
|
|
import yaml
|
|
|
|
|
@@ -88,7 +89,8 @@ class TestArgs:
|
|
|
|
|
self.__dict__.update(kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _profile_from_test_name(test_name):
|
|
|
|
|
# TODO since we only ever test postgres now we don't really need this
|
|
|
|
|
def _profile_from_test_name(test_name: str) -> str:
|
|
|
|
|
adapter_names = ('postgres', 'presto')
|
|
|
|
|
adapters_in_name = sum(x in test_name for x in adapter_names)
|
|
|
|
|
if adapters_in_name != 1:
|
|
|
|
|
@@ -555,7 +557,7 @@ class DBTIntegrationTest(unittest.TestCase):
|
|
|
|
|
fire_event(IntegrationTestInfo(msg=msg))
|
|
|
|
|
return dbt.handle_and_check(final_args)
|
|
|
|
|
|
|
|
|
|
def run_sql_file(self, path, kwargs=None):
|
|
|
|
|
def run_sql_file(self, path: str, kwargs=None):
|
|
|
|
|
with open(path, 'r') as f:
|
|
|
|
|
statements = f.read().split(";")
|
|
|
|
|
for statement in statements:
|
|
|
|
|
@@ -601,7 +603,7 @@ class DBTIntegrationTest(unittest.TestCase):
|
|
|
|
|
conn.handle.commit()
|
|
|
|
|
conn.transaction_open = False
|
|
|
|
|
|
|
|
|
|
def run_sql_common(self, sql, fetch, conn):
|
|
|
|
|
def run_sql_common(self, sql, fetch, conn) -> None:
|
|
|
|
|
with conn.handle.cursor() as cursor:
|
|
|
|
|
try:
|
|
|
|
|
cursor.execute(sql)
|
|
|
|
|
@@ -621,7 +623,7 @@ class DBTIntegrationTest(unittest.TestCase):
|
|
|
|
|
finally:
|
|
|
|
|
conn.transaction_open = False
|
|
|
|
|
|
|
|
|
|
def run_sql(self, query, fetch='None', kwargs=None, connection_name=None):
|
|
|
|
|
def run_sql(self, query, fetch='None', kwargs=None, connection_name=None) -> None:
|
|
|
|
|
if connection_name is None:
|
|
|
|
|
connection_name = '__test'
|
|
|
|
|
|
|
|
|
|
@@ -735,14 +737,14 @@ class DBTIntegrationTest(unittest.TestCase):
|
|
|
|
|
conn = self.adapter.connections.get_thread_connection()
|
|
|
|
|
yield conn
|
|
|
|
|
|
|
|
|
|
def get_relation_columns(self, relation):
|
|
|
|
|
def get_relation_columns(self, relation) -> List[Tuple[str, str, int]]:
|
|
|
|
|
with self.get_connection():
|
|
|
|
|
columns = self.adapter.get_columns_in_relation(relation)
|
|
|
|
|
|
|
|
|
|
return sorted(((c.name, c.dtype, c.char_size) for c in columns),
|
|
|
|
|
key=lambda x: x[0])
|
|
|
|
|
|
|
|
|
|
def get_table_columns(self, table, schema=None, database=None):
|
|
|
|
|
def get_table_columns(self, table, schema=None, database=None): # type: ignore
|
|
|
|
|
schema = self.unique_schema() if schema is None else schema
|
|
|
|
|
database = self.default_database if database is None else database
|
|
|
|
|
relation = self.adapter.Relation.create(
|
|
|
|
|
@@ -1012,23 +1014,23 @@ class DBTIntegrationTest(unittest.TestCase):
|
|
|
|
|
)
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
def assertTableDoesNotExist(self, table, schema=None, database=None):
|
|
|
|
|
columns = self.get_table_columns(table, schema, database)
|
|
|
|
|
def assertTableDoesNotExist(self, table, schema=None, database=None): # type: ignore
|
|
|
|
|
columns = self.get_table_columns(table, schema, database) # type: ignore
|
|
|
|
|
|
|
|
|
|
self.assertEqual(
|
|
|
|
|
len(columns),
|
|
|
|
|
0
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
def assertTableDoesExist(self, table, schema=None, database=None):
|
|
|
|
|
columns = self.get_table_columns(table, schema, database)
|
|
|
|
|
def assertTableDoesExist(self, table, schema=None, database=None): # type: ignore
|
|
|
|
|
columns = self.get_table_columns(table, schema, database) # type: ignore
|
|
|
|
|
|
|
|
|
|
self.assertGreater(
|
|
|
|
|
len(columns),
|
|
|
|
|
0
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
def _assertTableColumnsEqual(self, relation_a, relation_b):
|
|
|
|
|
def _assertTableColumnsEqual(self, relation_a, relation_b): # type: ignore
|
|
|
|
|
table_a_result = self.get_relation_columns(relation_a)
|
|
|
|
|
table_b_result = self.get_relation_columns(relation_b)
|
|
|
|
|
|
|
|
|
|
@@ -1057,11 +1059,11 @@ class DBTIntegrationTest(unittest.TestCase):
|
|
|
|
|
relation_a, relation_b, a_name, a_size, b_size
|
|
|
|
|
))
|
|
|
|
|
|
|
|
|
|
def assertEquals(self, *args, **kwargs):
|
|
|
|
|
def assertEquals(self, *args, **kwargs) -> None: # type: ignore
|
|
|
|
|
# assertEquals is deprecated. This makes the warnings less chatty
|
|
|
|
|
self.assertEqual(*args, **kwargs)
|
|
|
|
|
|
|
|
|
|
def assertBetween(self, timestr, start, end=None):
|
|
|
|
|
def assertBetween(self, timestr: str, start: datetime, end: Optional[datetime]=None) -> None:
|
|
|
|
|
datefmt = '%Y-%m-%dT%H:%M:%S.%fZ'
|
|
|
|
|
if end is None:
|
|
|
|
|
end = datetime.utcnow()
|
|
|
|
|
@@ -1079,18 +1081,18 @@ class DBTIntegrationTest(unittest.TestCase):
|
|
|
|
|
end.strftime(datefmt))
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
def copy_file(self, src, dest) -> None:
|
|
|
|
|
def copy_file(self, src: str, dest: str) -> None:
|
|
|
|
|
# move files in the temp testing dir created
|
|
|
|
|
shutil.copyfile(
|
|
|
|
|
os.path.join(self.test_root_dir, src),
|
|
|
|
|
os.path.join(self.test_root_dir, dest),
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
def rm_file(self, src) -> None:
|
|
|
|
|
def rm_file(self, src: str) -> None:
|
|
|
|
|
os.remove(os.path.join(self.test_root_dir, src))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def use_profile(profile_name):
|
|
|
|
|
def use_profile(profile_name): # type: ignore
|
|
|
|
|
"""A decorator to declare a test method as using a particular profile.
|
|
|
|
|
Handles both setting the nose attr and calling self.use_profile.
|
|
|
|
|
|
|
|
|
|
@@ -1105,10 +1107,10 @@ def use_profile(profile_name):
|
|
|
|
|
def test_snowflake_thing(self):
|
|
|
|
|
self.assertEqual(self.adapter_type, 'snowflake')
|
|
|
|
|
"""
|
|
|
|
|
def outer(wrapped):
|
|
|
|
|
def outer(wrapped): # type: ignore
|
|
|
|
|
@getattr(pytest.mark, 'profile_'+profile_name)
|
|
|
|
|
@wraps(wrapped)
|
|
|
|
|
def func(self, *args, **kwargs):
|
|
|
|
|
def func(self, *args, **kwargs): # type: ignore
|
|
|
|
|
return wrapped(self, *args, **kwargs)
|
|
|
|
|
# sanity check at import time
|
|
|
|
|
assert _profile_from_test_name(wrapped.__name__) == profile_name
|
|
|
|
|
@@ -1119,22 +1121,22 @@ def use_profile(profile_name):
|
|
|
|
|
class AnyFloat:
|
|
|
|
|
"""Any float. Use this in assertEqual() calls to assert that it is a float.
|
|
|
|
|
"""
|
|
|
|
|
def __eq__(self, other):
|
|
|
|
|
def __eq__(self, other: object) -> bool:
|
|
|
|
|
return isinstance(other, float)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class AnyString:
|
|
|
|
|
"""Any string. Use this in assertEqual() calls to assert that it is a string.
|
|
|
|
|
"""
|
|
|
|
|
def __eq__(self, other):
|
|
|
|
|
def __eq__(self, other: object) -> bool:
|
|
|
|
|
return isinstance(other, str)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class AnyStringWith:
|
|
|
|
|
def __init__(self, contains=None):
|
|
|
|
|
def __init__(self, contains: Optional[str]=None) -> None:
|
|
|
|
|
self.contains = contains
|
|
|
|
|
|
|
|
|
|
def __eq__(self, other):
|
|
|
|
|
def __eq__(self, other: object) -> bool:
|
|
|
|
|
if not isinstance(other, str):
|
|
|
|
|
return False
|
|
|
|
|
|
|
|
|
|
@@ -1143,16 +1145,15 @@ class AnyStringWith:
|
|
|
|
|
|
|
|
|
|
return self.contains in other
|
|
|
|
|
|
|
|
|
|
def __repr__(self):
|
|
|
|
|
def __repr__(self) -> str:
|
|
|
|
|
return 'AnyStringWith<{!r}>'.format(self.contains)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_manifest():
|
|
|
|
|
def get_manifest() -> Optional[Manifest]:
|
|
|
|
|
path = './target/partial_parse.msgpack'
|
|
|
|
|
if os.path.exists(path):
|
|
|
|
|
with open(path, 'rb') as fp:
|
|
|
|
|
manifest_mp = fp.read()
|
|
|
|
|
manifest: Manifest = Manifest.from_msgpack(manifest_mp)
|
|
|
|
|
return manifest
|
|
|
|
|
return Manifest.from_msgpack(manifest_mp)
|
|
|
|
|
else:
|
|
|
|
|
return None
|
|
|
|
|
|