Compare commits

...

5 Commits

Author SHA1 Message Date
Nathaniel May
edfca9b425 some changes I forgot to commit 2021-12-13 13:27:49 -05:00
Nathaniel May
f6a4bc8dc6 Any test types 2021-12-06 11:36:41 -05:00
Nathaniel May
97971f804d get_manifest 2021-12-06 11:22:46 -05:00
Nathaniel May
edf4142583 mypy: run strict. run everywhere. 2021-12-06 10:49:34 -05:00
Nathaniel May
b87be582ad update mypy to latest 2021-12-06 10:45:23 -05:00
4 changed files with 31 additions and 30 deletions

View File

@@ -3,7 +3,7 @@ flake8
flaky flaky
freezegun==0.3.12 freezegun==0.3.12
ipdb ipdb
mypy==0.782 mypy==0.910
pip-tools pip-tools
pytest pytest
pytest-dotenv pytest-dotenv

View File

@@ -11,6 +11,7 @@ import unittest
from contextlib import contextmanager from contextlib import contextmanager
from datetime import datetime from datetime import datetime
from functools import wraps from functools import wraps
from typing import Optional
import pytest import pytest
import yaml import yaml
@@ -88,7 +89,8 @@ class TestArgs:
self.__dict__.update(kwargs) self.__dict__.update(kwargs)
def _profile_from_test_name(test_name): # TODO since we only ever test postgres now we don't really need this
def _profile_from_test_name(test_name: str) -> str:
adapter_names = ('postgres', 'presto') adapter_names = ('postgres', 'presto')
adapters_in_name = sum(x in test_name for x in adapter_names) adapters_in_name = sum(x in test_name for x in adapter_names)
if adapters_in_name != 1: if adapters_in_name != 1:
@@ -555,7 +557,7 @@ class DBTIntegrationTest(unittest.TestCase):
fire_event(IntegrationTestInfo(msg=msg)) fire_event(IntegrationTestInfo(msg=msg))
return dbt.handle_and_check(final_args) return dbt.handle_and_check(final_args)
def run_sql_file(self, path, kwargs=None): def run_sql_file(self, path: str, kwargs=None):
with open(path, 'r') as f: with open(path, 'r') as f:
statements = f.read().split(";") statements = f.read().split(";")
for statement in statements: for statement in statements:
@@ -601,7 +603,7 @@ class DBTIntegrationTest(unittest.TestCase):
conn.handle.commit() conn.handle.commit()
conn.transaction_open = False conn.transaction_open = False
def run_sql_common(self, sql, fetch, conn): def run_sql_common(self, sql, fetch, conn) -> None:
with conn.handle.cursor() as cursor: with conn.handle.cursor() as cursor:
try: try:
cursor.execute(sql) cursor.execute(sql)
@@ -621,7 +623,7 @@ class DBTIntegrationTest(unittest.TestCase):
finally: finally:
conn.transaction_open = False conn.transaction_open = False
def run_sql(self, query, fetch='None', kwargs=None, connection_name=None): def run_sql(self, query, fetch='None', kwargs=None, connection_name=None) -> None:
if connection_name is None: if connection_name is None:
connection_name = '__test' connection_name = '__test'
@@ -735,14 +737,14 @@ class DBTIntegrationTest(unittest.TestCase):
conn = self.adapter.connections.get_thread_connection() conn = self.adapter.connections.get_thread_connection()
yield conn yield conn
def get_relation_columns(self, relation): def get_relation_columns(self, relation) -> List[Tuple[str, str, int]]:
with self.get_connection(): with self.get_connection():
columns = self.adapter.get_columns_in_relation(relation) columns = self.adapter.get_columns_in_relation(relation)
return sorted(((c.name, c.dtype, c.char_size) for c in columns), return sorted(((c.name, c.dtype, c.char_size) for c in columns),
key=lambda x: x[0]) key=lambda x: x[0])
def get_table_columns(self, table, schema=None, database=None): def get_table_columns(self, table, schema=None, database=None): # type: ignore
schema = self.unique_schema() if schema is None else schema schema = self.unique_schema() if schema is None else schema
database = self.default_database if database is None else database database = self.default_database if database is None else database
relation = self.adapter.Relation.create( relation = self.adapter.Relation.create(
@@ -1012,23 +1014,23 @@ class DBTIntegrationTest(unittest.TestCase):
) )
) )
def assertTableDoesNotExist(self, table, schema=None, database=None): def assertTableDoesNotExist(self, table, schema=None, database=None): # type: ignore
columns = self.get_table_columns(table, schema, database) columns = self.get_table_columns(table, schema, database) # type: ignore
self.assertEqual( self.assertEqual(
len(columns), len(columns),
0 0
) )
def assertTableDoesExist(self, table, schema=None, database=None): def assertTableDoesExist(self, table, schema=None, database=None): # type: ignore
columns = self.get_table_columns(table, schema, database) columns = self.get_table_columns(table, schema, database) # type: ignore
self.assertGreater( self.assertGreater(
len(columns), len(columns),
0 0
) )
def _assertTableColumnsEqual(self, relation_a, relation_b): def _assertTableColumnsEqual(self, relation_a, relation_b): # type: ignore
table_a_result = self.get_relation_columns(relation_a) table_a_result = self.get_relation_columns(relation_a)
table_b_result = self.get_relation_columns(relation_b) table_b_result = self.get_relation_columns(relation_b)
@@ -1057,11 +1059,11 @@ class DBTIntegrationTest(unittest.TestCase):
relation_a, relation_b, a_name, a_size, b_size relation_a, relation_b, a_name, a_size, b_size
)) ))
def assertEquals(self, *args, **kwargs): def assertEquals(self, *args, **kwargs) -> None: # type: ignore
# assertEquals is deprecated. This makes the warnings less chatty # assertEquals is deprecated. This makes the warnings less chatty
self.assertEqual(*args, **kwargs) self.assertEqual(*args, **kwargs)
def assertBetween(self, timestr, start, end=None): def assertBetween(self, timestr: str, start: datetime, end: Optional[datetime]=None) -> None:
datefmt = '%Y-%m-%dT%H:%M:%S.%fZ' datefmt = '%Y-%m-%dT%H:%M:%S.%fZ'
if end is None: if end is None:
end = datetime.utcnow() end = datetime.utcnow()
@@ -1079,18 +1081,18 @@ class DBTIntegrationTest(unittest.TestCase):
end.strftime(datefmt)) end.strftime(datefmt))
) )
def copy_file(self, src, dest) -> None: def copy_file(self, src: str, dest: str) -> None:
# move files in the temp testing dir created # move files in the temp testing dir created
shutil.copyfile( shutil.copyfile(
os.path.join(self.test_root_dir, src), os.path.join(self.test_root_dir, src),
os.path.join(self.test_root_dir, dest), os.path.join(self.test_root_dir, dest),
) )
def rm_file(self, src) -> None: def rm_file(self, src: str) -> None:
os.remove(os.path.join(self.test_root_dir, src)) os.remove(os.path.join(self.test_root_dir, src))
def use_profile(profile_name): def use_profile(profile_name): # type: ignore
"""A decorator to declare a test method as using a particular profile. """A decorator to declare a test method as using a particular profile.
Handles both setting the nose attr and calling self.use_profile. Handles both setting the nose attr and calling self.use_profile.
@@ -1105,10 +1107,10 @@ def use_profile(profile_name):
def test_snowflake_thing(self): def test_snowflake_thing(self):
self.assertEqual(self.adapter_type, 'snowflake') self.assertEqual(self.adapter_type, 'snowflake')
""" """
def outer(wrapped): def outer(wrapped): # type: ignore
@getattr(pytest.mark, 'profile_'+profile_name) @getattr(pytest.mark, 'profile_'+profile_name)
@wraps(wrapped) @wraps(wrapped)
def func(self, *args, **kwargs): def func(self, *args, **kwargs): # type: ignore
return wrapped(self, *args, **kwargs) return wrapped(self, *args, **kwargs)
# sanity check at import time # sanity check at import time
assert _profile_from_test_name(wrapped.__name__) == profile_name assert _profile_from_test_name(wrapped.__name__) == profile_name
@@ -1119,22 +1121,22 @@ def use_profile(profile_name):
class AnyFloat: class AnyFloat:
"""Any float. Use this in assertEqual() calls to assert that it is a float. """Any float. Use this in assertEqual() calls to assert that it is a float.
""" """
def __eq__(self, other): def __eq__(self, other: object) -> bool:
return isinstance(other, float) return isinstance(other, float)
class AnyString: class AnyString:
"""Any string. Use this in assertEqual() calls to assert that it is a string. """Any string. Use this in assertEqual() calls to assert that it is a string.
""" """
def __eq__(self, other): def __eq__(self, other: object) -> bool:
return isinstance(other, str) return isinstance(other, str)
class AnyStringWith: class AnyStringWith:
def __init__(self, contains=None): def __init__(self, contains: Optional[str]=None) -> None:
self.contains = contains self.contains = contains
def __eq__(self, other): def __eq__(self, other: object) -> bool:
if not isinstance(other, str): if not isinstance(other, str):
return False return False
@@ -1143,16 +1145,15 @@ class AnyStringWith:
return self.contains in other return self.contains in other
def __repr__(self): def __repr__(self) -> str:
return 'AnyStringWith<{!r}>'.format(self.contains) return 'AnyStringWith<{!r}>'.format(self.contains)
def get_manifest(): def get_manifest() -> Optional[Manifest]:
path = './target/partial_parse.msgpack' path = './target/partial_parse.msgpack'
if os.path.exists(path): if os.path.exists(path):
with open(path, 'rb') as fp: with open(path, 'rb') as fp:
manifest_mp = fp.read() manifest_mp = fp.read()
manifest: Manifest = Manifest.from_msgpack(manifest_mp) return Manifest.from_msgpack(manifest_mp)
return manifest
else: else:
return None return None

View File

@@ -5,9 +5,9 @@ DEFAULT_DICT_PARAMS: Any
EncodedData = Union[str, bytes, bytearray] EncodedData = Union[str, bytes, bytearray]
Encoder = Callable[[Dict], EncodedData] Encoder = Callable[[Dict], EncodedData]
Decoder = Callable[[EncodedData], Dict] Decoder = Callable[[EncodedData], Dict]
T = TypeVar('T', bound='DataClassMessagePackMixin') T_DataClassMessagePackMixin = TypeVar('T', bound='DataClassMessagePackMixin')
class DataClassMessagePackMixin(DataClassDictMixin): class DataClassMessagePackMixin(DataClassDictMixin):
def to_msgpack(self, encoder: Encoder=..., dict_params: Mapping=..., **encoder_kwargs: Any) -> EncodedData: ... def to_msgpack(self, encoder: Encoder=..., dict_params: Mapping=..., **encoder_kwargs: Any) -> EncodedData: ...
@classmethod @classmethod
def from_msgpack(cls: Type[T], data: EncodedData, decoder: Decoder=..., dict_params: Mapping=..., **decoder_kwargs: Any) -> DataClassDictMixin: ... def from_msgpack(cls: Type[T], data: EncodedData, decoder: Decoder=..., dict_params: Mapping=..., **decoder_kwargs: Any) -> T_DataClassMessagePackMixin: ...

View File

@@ -16,7 +16,7 @@ deps =
description = mypy static type checking description = mypy static type checking
basepython = python3.8 basepython = python3.8
skip_install = true skip_install = true
commands = mypy core/dbt commands = mypy --strict --warn-unreachable ./
deps = deps =
-rdev-requirements.txt -rdev-requirements.txt
-reditable-requirements.txt -reditable-requirements.txt