Compare commits

...

8 Commits

Author SHA1 Message Date
Nathaniel May
8c9701972a remove two init files. passes tests, fails mypy. 2021-12-16 17:22:43 -05:00
Nathaniel May
97202000ca split up long lines 2021-12-15 12:06:31 -05:00
Nathaniel May
15e373a8dd show error codes to encourage slimmer ignores 2021-12-15 12:03:33 -05:00
Nathaniel May
e6301bf4c7 address the rest of the mypy issues 2021-12-15 12:03:06 -05:00
Nathaniel May
982f7f797f fix process results 2021-12-15 11:13:28 -05:00
Nathaniel May
f649a5ffd1 ignore imports without type hints 2021-12-15 11:05:12 -05:00
Nathaniel May
d0057afd18 add missing init files 2021-12-15 10:56:51 -05:00
Nathaniel May
a106d5649c upgrade mypy 2021-12-14 16:06:06 -05:00
27 changed files with 65 additions and 54 deletions

View File

@@ -9,7 +9,7 @@ from typing import (
) )
import agate import agate
import pytz import pytz # type: ignore[import]
from dbt.exceptions import ( from dbt.exceptions import (
raise_database_error, raise_compiler_error, invalid_type_error, raise_database_error, raise_compiler_error, invalid_type_error,
@@ -275,8 +275,8 @@ class BaseAdapter(metaclass=AdapterMeta):
manifest = ManifestLoader.load_macros( manifest = ManifestLoader.load_macros(
self.config, self.connections.set_query_header self.config, self.connections.set_query_header
) )
self._macro_manifest_lazy = manifest self._macro_manifest_lazy = manifest # type: ignore[assignment]
return self._macro_manifest_lazy return self._macro_manifest_lazy # type: ignore[return-value]
def clear_macro_manifest(self): def clear_macro_manifest(self):
if self._macro_manifest_lazy is not None: if self._macro_manifest_lazy is not None:
@@ -959,10 +959,10 @@ class BaseAdapter(metaclass=AdapterMeta):
if context_override is None: if context_override is None:
context_override = {} context_override = {}
if manifest is None: # manifest has type Optional[Manifest]. working_manifest is not optional.
manifest = self._macro_manifest working_manifest: Union[Manifest, MacroManifest] = manifest or self._macro_manifest
macro = manifest.find_macro_by_name( macro = working_manifest.find_macro_by_name(
macro_name, self.config.project_name, project macro_name, self.config.project_name, project
) )
if macro is None: if macro is None:
@@ -981,7 +981,7 @@ class BaseAdapter(metaclass=AdapterMeta):
macro_context = generate_runtime_macro_context( macro_context = generate_runtime_macro_context(
macro=macro, macro=macro,
config=self.config, config=self.config,
manifest=manifest, manifest=working_manifest, # type: ignore[arg-type]
package_name=project package_name=project
) )
macro_context.update(context_override) macro_context.update(context_override)

View File

@@ -89,8 +89,11 @@ class BaseRelation(FakeAPIObject, Hashable):
if not self._is_exactish_match(k, v): if not self._is_exactish_match(k, v):
exact_match = False exact_match = False
if ( lowered_part: Optional[str] = self.path.get_lowered_part(k)
self.path.get_lowered_part(k).strip(self.quote_character) != if lowered_part is None:
approximate_match = False
elif (
lowered_part.strip(self.quote_character) !=
v.lower().strip(self.quote_character) v.lower().strip(self.quote_character)
): ):
approximate_match = False approximate_match = False

View File

@@ -79,7 +79,7 @@ Column_T = TypeVar(
Compiler_T = TypeVar('Compiler_T', bound=CompilerProtocol) Compiler_T = TypeVar('Compiler_T', bound=CompilerProtocol)
class AdapterProtocol( class AdapterProtocol( # type: ignore[misc]
Protocol, Protocol,
Generic[ Generic[
AdapterConfig_T, AdapterConfig_T,

View File

@@ -92,17 +92,17 @@ class SQLConnectionManager(BaseConnectionManager):
@classmethod @classmethod
def process_results( def process_results(
cls, cls,
column_names: Iterable[str], column_names: List[str],
rows: Iterable[Any] rows: Iterable[Any]
) -> List[Dict[str, Any]]: ) -> List[Dict[str, Any]]:
unique_col_names = dict() unique_col_names: Dict[str, int] = dict()
for idx in range(len(column_names)): for idx in range(len(column_names)):
col_name = column_names[idx] col_name = column_names[idx]
if col_name in unique_col_names: if col_name in unique_col_names:
unique_col_names[col_name] += 1 unique_col_names[col_name] += 1
column_names[idx] = f'{col_name}_{unique_col_names[col_name]}' column_names[idx] = f'{col_name}_{unique_col_names[col_name]}'
else: else:
unique_col_names[column_names[idx]] = 1 unique_col_names[col_name] = 1
return [dict(zip(column_names, row)) for row in rows] return [dict(zip(column_names, row)) for row in rows]
@classmethod @classmethod

View File

@@ -74,7 +74,7 @@ class SQLAdapter(BaseAdapter):
def convert_number_type( def convert_number_type(
cls, agate_table: agate.Table, col_idx: int cls, agate_table: agate.Table, col_idx: int
) -> str: ) -> str:
decimals = agate_table.aggregate(agate.MaxPrecision(col_idx)) decimals = agate_table.aggregate(agate.MaxPrecision(col_idx)) # type: ignore[attr-defined]
return "float8" if decimals else "integer" return "float8" if decimals else "integer"
@classmethod @classmethod

View File

@@ -12,12 +12,12 @@ from typing import (
Callable Callable
) )
import jinja2 import jinja2 # type: ignore[import]
import jinja2.ext import jinja2.ext # type: ignore[import]
import jinja2.nativetypes # type: ignore import jinja2.nativetypes # type: ignore[import]
import jinja2.nodes import jinja2.nodes # type: ignore[import]
import jinja2.parser import jinja2.parser # type: ignore[import]
import jinja2.sandbox import jinja2.sandbox # type: ignore[import]
from dbt.utils import ( from dbt.utils import (
get_dbt_macro_name, get_docs_macro_name, get_materialization_macro_name, get_dbt_macro_name, get_docs_macro_name, get_materialization_macro_name,

View File

@@ -1,4 +1,4 @@
import jinja2 import jinja2 # type: ignore[import]
from dbt.clients.jinja import get_environment from dbt.clients.jinja import get_environment
from dbt.exceptions import raise_compiler_error from dbt.exceptions import raise_compiler_error

View File

@@ -1,5 +1,5 @@
import functools import functools
import requests import requests # type: ignore[import]
from dbt.events.functions import fire_event from dbt.events.functions import fire_event
from dbt.events.types import ( from dbt.events.types import (
RegistryProgressMakingGETRequest, RegistryProgressMakingGETRequest,

View File

@@ -9,7 +9,7 @@ import shutil
import subprocess import subprocess
import sys import sys
import tarfile import tarfile
import requests import requests # type: ignore[import]
import stat import stat
from typing import ( from typing import (
Type, NoReturn, List, Optional, Dict, Any, Tuple, Callable, Union Type, NoReturn, List, Optional, Dict, Any, Tuple, Callable, Union

View File

@@ -1,6 +1,6 @@
import dbt.exceptions import dbt.exceptions
from typing import Any, Dict, Optional from typing import Any, Dict, Optional
import yaml import yaml # type: ignore[import]
# the C version is faster, but it doesn't always exist # the C version is faster, but it doesn't always exist
try: try:

View File

View File

@@ -21,7 +21,7 @@ from dbt.version import __version__ as dbt_version
# These modules are added to the context. Consider alternative # These modules are added to the context. Consider alternative
# approaches which will extend well to potentially many modules # approaches which will extend well to potentially many modules
import pytz import pytz # type: ignore[import]
import datetime import datetime
import re import re
@@ -82,7 +82,7 @@ def get_datetime_module_context() -> Dict[str, Any]:
def get_re_module_context() -> Dict[str, Any]: def get_re_module_context() -> Dict[str, Any]:
context_exports = re.__all__ context_exports = re.__all__ # type: ignore[attr-defined]
return { return {
name: getattr(re, name) for name in context_exports name: getattr(re, name) for name in context_exports

View File

@@ -150,7 +150,7 @@ class BaseContextConfigGenerator(Generic[T]):
result = self._update_from_config(result, fqn_config) result = self._update_from_config(result, fqn_config)
# this is mostly impactful in the snapshot config case # this is mostly impactful in the snapshot config case
return result return result # type: ignore[return-value]
@abstractmethod @abstractmethod
def calculate_node_config_dict( def calculate_node_config_dict(
@@ -227,7 +227,7 @@ class UnrenderedConfigGenerator(BaseContextConfigGenerator[Dict[str, Any]]):
base: bool, base: bool,
patch_config_dict: dict = None patch_config_dict: dict = None
) -> Dict[str, Any]: ) -> Dict[str, Any]:
return self.calculate_node_config( return self.calculate_node_config( # type: ignore[return-value]
config_call_dict=config_call_dict, config_call_dict=config_call_dict,
fqn=fqn, fqn=fqn,
resource_type=resource_type, resource_type=resource_type,
@@ -299,9 +299,9 @@ class ContextConfig:
patch_config_dict: dict = None patch_config_dict: dict = None
) -> Dict[str, Any]: ) -> Dict[str, Any]:
if rendered: if rendered:
src = ContextConfigGenerator(self._active_project) src = ContextConfigGenerator(self._active_project) # type: ignore[var-annotated]
else: else:
src = UnrenderedConfigGenerator(self._active_project) src = UnrenderedConfigGenerator(self._active_project) # type: ignore[assignment]
return src.calculate_node_config_dict( return src.calculate_node_config_dict(
config_call_dict=self._config_call_dict, config_call_dict=self._config_call_dict,

View File

@@ -68,7 +68,7 @@ class DocsRuntimeContext(SchemaYamlContext):
file_id = target_doc.file_id file_id = target_doc.file_id
if file_id in self.manifest.files: if file_id in self.manifest.files:
source_file = self.manifest.files[file_id] source_file = self.manifest.files[file_id]
source_file.add_node(self.node.unique_id) source_file.add_node(self.node.unique_id) # type: ignore[union-attr]
else: else:
doc_target_not_found(self.node, doc_name, doc_package_name) doc_target_not_found(self.node, doc_name, doc_package_name)

View File

@@ -37,10 +37,10 @@ class MacroNamespace(Mapping):
self.global_project_namespace: FlatNamespace = global_project_namespace self.global_project_namespace: FlatNamespace = global_project_namespace
def _search_order(self) -> Iterable[Union[FullNamespace, FlatNamespace]]: def _search_order(self) -> Iterable[Union[FullNamespace, FlatNamespace]]:
yield self.local_namespace # local package yield self.local_namespace # local package
yield self.global_namespace # root package yield self.global_namespace # type: ignore[misc] # root package
yield self.packages # non-internal packages yield self.packages # type: ignore[misc] # non-internal packages
yield { yield { # type: ignore[misc]
GLOBAL_PROJECT_NAME: self.global_project_namespace, # dbt GLOBAL_PROJECT_NAME: self.global_project_namespace, # dbt
} }
yield self.global_project_namespace # other internal project besides dbt yield self.global_project_namespace # other internal project besides dbt

View File

@@ -21,6 +21,7 @@ from dbt.context.macro_resolver import MacroResolver, TestMacroNamespace
from .macros import MacroNamespaceBuilder, MacroNamespace from .macros import MacroNamespaceBuilder, MacroNamespace
from .manifest import ManifestContext from .manifest import ManifestContext
from dbt.contracts.connection import AdapterResponse from dbt.contracts.connection import AdapterResponse
from dbt.contracts.files import SchemaSourceFile
from dbt.contracts.graph.manifest import ( from dbt.contracts.graph.manifest import (
Manifest, Disabled Manifest, Disabled
) )
@@ -1189,7 +1190,7 @@ class ProviderContext(ManifestContext):
source_file = self.manifest.files[self.model.file_id] source_file = self.manifest.files[self.model.file_id]
# Schema files should never get here # Schema files should never get here
if source_file.parse_file_type != 'schema': if source_file.parse_file_type != 'schema':
source_file.env_vars.append(var) source_file.env_vars.append(var) # type: ignore[union-attr]
return return_value return return_value
else: else:
msg = f"Env var required but not provided: '{var}'" msg = f"Env var required but not provided: '{var}'"
@@ -1230,7 +1231,8 @@ class ModelContext(ProviderContext):
if self.model.resource_type in [NodeType.Source, NodeType.Test]: if self.model.resource_type in [NodeType.Source, NodeType.Test]:
return [] return []
return [ return [
h.to_dict(omit_none=True) for h in self.model.config.pre_hook h.to_dict(omit_none=True)
for h in self.model.config.pre_hook # type: ignore[union-attr]
] ]
@contextproperty @contextproperty
@@ -1238,13 +1240,14 @@ class ModelContext(ProviderContext):
if self.model.resource_type in [NodeType.Source, NodeType.Test]: if self.model.resource_type in [NodeType.Source, NodeType.Test]:
return [] return []
return [ return [
h.to_dict(omit_none=True) for h in self.model.config.post_hook h.to_dict(omit_none=True)
for h in self.model.config.post_hook # type: ignore[union-attr]
] ]
@contextproperty @contextproperty
def sql(self) -> Optional[str]: def sql(self) -> Optional[str]:
if getattr(self.model, 'extra_ctes_injected', None): if getattr(self.model, 'extra_ctes_injected', None):
return self.model.compiled_sql return self.model.compiled_sql # type: ignore[union-attr]
return None return None
@contextproperty @contextproperty
@@ -1495,9 +1498,14 @@ class TestContext(ProviderContext):
if self.model: if self.model:
self.manifest.env_vars[var] = return_value self.manifest.env_vars[var] = return_value
# the "model" should only be test nodes, but just in case, check # the "model" should only be test nodes, but just in case, check
if self.model.resource_type == NodeType.Test and self.model.file_key_name: if (
source_file = self.manifest.files[self.model.file_id] self.model.resource_type ==
(yaml_key, name) = self.model.file_key_name.split('.') NodeType.Test and self.model.file_key_name # type: ignore[union-attr]
):
source_file: SchemaSourceFile = \
self.manifest.files[self.model.file_id] # type: ignore[assignment]
(yaml_key, name) = \
self.model.file_key_name.split('.') # type: ignore[union-attr]
source_file.add_env_var(var, yaml_key, name) source_file.add_env_var(var, yaml_key, name)
return return_value return return_value
else: else:

View File

@@ -80,7 +80,7 @@ class GitPackage(Package):
class RegistryPackage(Package): class RegistryPackage(Package):
package: str package: str
version: Union[RawVersion, List[RawVersion]] version: Union[RawVersion, List[RawVersion]]
install_prerelease: Optional[bool] = False install_prerelease: bool = False
def get_versions(self) -> List[str]: def get_versions(self) -> List[str]:
if isinstance(self.version, list): if isinstance(self.version, list):

View File

@@ -5,7 +5,7 @@ import re
from dataclasses import fields from dataclasses import fields
from enum import Enum from enum import Enum
from datetime import datetime from datetime import datetime
from dateutil.parser import parse from dateutil.parser import parse # type: ignore[import]
from hologram import JsonSchemaMixin, FieldEncoder, ValidationError from hologram import JsonSchemaMixin, FieldEncoder, ValidationError

View File

View File

@@ -1,6 +1,6 @@
from typing import Iterable, List from typing import Iterable, List
import jinja2 import jinja2 # type: ignore[import]
from dbt.exceptions import ParsingException from dbt.exceptions import ParsingException
from dbt.clients import jinja from dbt.clients import jinja

View File

@@ -1,6 +1,6 @@
from typing import Iterable, List from typing import Iterable, List
import jinja2 import jinja2 # type: ignore[import]
from dbt.clients import jinja from dbt.clients import jinja
from dbt.contracts.graph.unparsed import UnparsedMacro from dbt.contracts.graph.unparsed import UnparsedMacro

View File

@@ -5,7 +5,7 @@ import re
import shutil import shutil
from typing import Optional from typing import Optional
import yaml import yaml # type: ignore[import]
import click import click
import dbt.config import dbt.config

View File

@@ -15,10 +15,10 @@ from snowplow_tracker import SelfDescribingJson
from datetime import datetime from datetime import datetime
import logbook import logbook
import pytz import pytz # type: ignore[import]
import platform import platform
import uuid import uuid
import requests import requests # type: ignore[import]
import os import os
sp_logger.setLevel(100) sp_logger.setLevel(100)

View File

@@ -6,10 +6,10 @@ import decimal
import functools import functools
import hashlib import hashlib
import itertools import itertools
import jinja2 import jinja2 # type: ignore[import]
import json import json
import os import os
import requests import requests # type: ignore[import]
import time import time
from contextlib import contextmanager from contextlib import contextmanager

View File

@@ -5,7 +5,7 @@ import glob
import json import json
from typing import Iterator from typing import Iterator
import requests import requests # type: ignore[import]
import dbt.exceptions import dbt.exceptions
import dbt.semver import dbt.semver

View File

@@ -3,7 +3,7 @@ flake8
flaky flaky
freezegun==0.3.12 freezegun==0.3.12
ipdb ipdb
mypy==0.782 mypy==0.910
pip-tools pip-tools
pytest pytest
pytest-dotenv pytest-dotenv

View File

@@ -16,7 +16,7 @@ deps =
description = mypy static type checking description = mypy static type checking
basepython = python3.8 basepython = python3.8
skip_install = true skip_install = true
commands = mypy core/dbt commands = mypy --show-error-codes core/dbt
deps = deps =
-rdev-requirements.txt -rdev-requirements.txt
-reditable-requirements.txt -reditable-requirements.txt