Compare commits

...

5 Commits

Author SHA1 Message Date
Nathaniel May
edfca9b425 some changes I forgot to commit 2021-12-13 13:27:49 -05:00
Nathaniel May
f6a4bc8dc6 Any test types 2021-12-06 11:36:41 -05:00
Nathaniel May
97971f804d get_manifest 2021-12-06 11:22:46 -05:00
Nathaniel May
edf4142583 mypy: run strict. run everywhere. 2021-12-06 10:49:34 -05:00
Nathaniel May
b87be582ad update mypy to latest 2021-12-06 10:45:23 -05:00
4 changed files with 31 additions and 30 deletions

View File

@@ -3,7 +3,7 @@ flake8
flaky
freezegun==0.3.12
ipdb
mypy==0.782
mypy==0.910
pip-tools
pytest
pytest-dotenv

View File

@@ -11,6 +11,7 @@ import unittest
from contextlib import contextmanager
from datetime import datetime
from functools import wraps
from typing import Optional
import pytest
import yaml
@@ -88,7 +89,8 @@ class TestArgs:
self.__dict__.update(kwargs)
def _profile_from_test_name(test_name):
# TODO since we only ever test postgres now we don't really need this
def _profile_from_test_name(test_name: str) -> str:
adapter_names = ('postgres', 'presto')
adapters_in_name = sum(x in test_name for x in adapter_names)
if adapters_in_name != 1:
@@ -555,7 +557,7 @@ class DBTIntegrationTest(unittest.TestCase):
fire_event(IntegrationTestInfo(msg=msg))
return dbt.handle_and_check(final_args)
def run_sql_file(self, path, kwargs=None):
def run_sql_file(self, path: str, kwargs=None):
with open(path, 'r') as f:
statements = f.read().split(";")
for statement in statements:
@@ -601,7 +603,7 @@ class DBTIntegrationTest(unittest.TestCase):
conn.handle.commit()
conn.transaction_open = False
def run_sql_common(self, sql, fetch, conn):
def run_sql_common(self, sql, fetch, conn) -> None:
with conn.handle.cursor() as cursor:
try:
cursor.execute(sql)
@@ -621,7 +623,7 @@ class DBTIntegrationTest(unittest.TestCase):
finally:
conn.transaction_open = False
def run_sql(self, query, fetch='None', kwargs=None, connection_name=None):
def run_sql(self, query, fetch='None', kwargs=None, connection_name=None) -> None:
if connection_name is None:
connection_name = '__test'
@@ -735,14 +737,14 @@ class DBTIntegrationTest(unittest.TestCase):
conn = self.adapter.connections.get_thread_connection()
yield conn
def get_relation_columns(self, relation):
def get_relation_columns(self, relation) -> List[Tuple[str, str, int]]:
with self.get_connection():
columns = self.adapter.get_columns_in_relation(relation)
return sorted(((c.name, c.dtype, c.char_size) for c in columns),
key=lambda x: x[0])
def get_table_columns(self, table, schema=None, database=None):
def get_table_columns(self, table, schema=None, database=None): # type: ignore
schema = self.unique_schema() if schema is None else schema
database = self.default_database if database is None else database
relation = self.adapter.Relation.create(
@@ -1012,23 +1014,23 @@ class DBTIntegrationTest(unittest.TestCase):
)
)
def assertTableDoesNotExist(self, table, schema=None, database=None):
columns = self.get_table_columns(table, schema, database)
def assertTableDoesNotExist(self, table, schema=None, database=None): # type: ignore
columns = self.get_table_columns(table, schema, database) # type: ignore
self.assertEqual(
len(columns),
0
)
def assertTableDoesExist(self, table, schema=None, database=None):
columns = self.get_table_columns(table, schema, database)
def assertTableDoesExist(self, table, schema=None, database=None): # type: ignore
columns = self.get_table_columns(table, schema, database) # type: ignore
self.assertGreater(
len(columns),
0
)
def _assertTableColumnsEqual(self, relation_a, relation_b):
def _assertTableColumnsEqual(self, relation_a, relation_b): # type: ignore
table_a_result = self.get_relation_columns(relation_a)
table_b_result = self.get_relation_columns(relation_b)
@@ -1057,11 +1059,11 @@ class DBTIntegrationTest(unittest.TestCase):
relation_a, relation_b, a_name, a_size, b_size
))
def assertEquals(self, *args, **kwargs):
def assertEquals(self, *args, **kwargs) -> None: # type: ignore
# assertEquals is deprecated. This makes the warnings less chatty
self.assertEqual(*args, **kwargs)
def assertBetween(self, timestr, start, end=None):
def assertBetween(self, timestr: str, start: datetime, end: Optional[datetime]=None) -> None:
datefmt = '%Y-%m-%dT%H:%M:%S.%fZ'
if end is None:
end = datetime.utcnow()
@@ -1079,18 +1081,18 @@ class DBTIntegrationTest(unittest.TestCase):
end.strftime(datefmt))
)
def copy_file(self, src, dest) -> None:
def copy_file(self, src: str, dest: str) -> None:
# move files in the temp testing dir created
shutil.copyfile(
os.path.join(self.test_root_dir, src),
os.path.join(self.test_root_dir, dest),
)
def rm_file(self, src) -> None:
def rm_file(self, src: str) -> None:
os.remove(os.path.join(self.test_root_dir, src))
def use_profile(profile_name):
def use_profile(profile_name): # type: ignore
"""A decorator to declare a test method as using a particular profile.
Handles both setting the nose attr and calling self.use_profile.
@@ -1105,10 +1107,10 @@ def use_profile(profile_name):
def test_snowflake_thing(self):
self.assertEqual(self.adapter_type, 'snowflake')
"""
def outer(wrapped):
def outer(wrapped): # type: ignore
@getattr(pytest.mark, 'profile_'+profile_name)
@wraps(wrapped)
def func(self, *args, **kwargs):
def func(self, *args, **kwargs): # type: ignore
return wrapped(self, *args, **kwargs)
# sanity check at import time
assert _profile_from_test_name(wrapped.__name__) == profile_name
@@ -1119,22 +1121,22 @@ def use_profile(profile_name):
class AnyFloat:
"""Any float. Use this in assertEqual() calls to assert that it is a float.
"""
def __eq__(self, other):
def __eq__(self, other: object) -> bool:
return isinstance(other, float)
class AnyString:
"""Any string. Use this in assertEqual() calls to assert that it is a string.
"""
def __eq__(self, other):
def __eq__(self, other: object) -> bool:
return isinstance(other, str)
class AnyStringWith:
def __init__(self, contains=None):
def __init__(self, contains: Optional[str]=None) -> None:
self.contains = contains
def __eq__(self, other):
def __eq__(self, other: object) -> bool:
if not isinstance(other, str):
return False
@@ -1143,16 +1145,15 @@ class AnyStringWith:
return self.contains in other
def __repr__(self):
def __repr__(self) -> str:
return 'AnyStringWith<{!r}>'.format(self.contains)
def get_manifest():
def get_manifest() -> Optional[Manifest]:
path = './target/partial_parse.msgpack'
if os.path.exists(path):
with open(path, 'rb') as fp:
manifest_mp = fp.read()
manifest: Manifest = Manifest.from_msgpack(manifest_mp)
return manifest
return Manifest.from_msgpack(manifest_mp)
else:
return None

View File

@@ -5,9 +5,9 @@ DEFAULT_DICT_PARAMS: Any
EncodedData = Union[str, bytes, bytearray]
Encoder = Callable[[Dict], EncodedData]
Decoder = Callable[[EncodedData], Dict]
T = TypeVar('T', bound='DataClassMessagePackMixin')
T_DataClassMessagePackMixin = TypeVar('T', bound='DataClassMessagePackMixin')
class DataClassMessagePackMixin(DataClassDictMixin):
def to_msgpack(self, encoder: Encoder=..., dict_params: Mapping=..., **encoder_kwargs: Any) -> EncodedData: ...
@classmethod
def from_msgpack(cls: Type[T], data: EncodedData, decoder: Decoder=..., dict_params: Mapping=..., **decoder_kwargs: Any) -> DataClassDictMixin: ...
def from_msgpack(cls: Type[T], data: EncodedData, decoder: Decoder=..., dict_params: Mapping=..., **decoder_kwargs: Any) -> T_DataClassMessagePackMixin: ...

View File

@@ -16,7 +16,7 @@ deps =
description = mypy static type checking
basepython = python3.8
skip_install = true
commands = mypy core/dbt
commands = mypy --strict --warn-unreachable ./
deps =
-rdev-requirements.txt
-reditable-requirements.txt