Compare commits

...

1 Commits

Author SHA1 Message Date
Ian Knox
a5440f9dba all fixes save one that's actually a bug 2023-06-12 14:48:44 -05:00
13 changed files with 61 additions and 74 deletions

View File

@@ -46,7 +46,7 @@ from dbt.exceptions import (
from dbt.adapters.protocol import AdapterConfig, ConnectionManagerProtocol
from dbt.clients.agate_helper import empty_table, merge_tables, table_from_rows
from dbt.clients.jinja import MacroGenerator
from dbt.contracts.graph.manifest import Manifest, MacroManifest
from dbt.contracts.graph.manifest import AnyManifest, Manifest, MacroManifest
from dbt.contracts.graph.nodes import ResultNode
from dbt.events.functions import fire_event, warn_or_error
from dbt.events.types import (
@@ -349,10 +349,8 @@ class BaseAdapter(metaclass=AdapterMeta):
self.connections.set_query_header,
base_macros_only=base_macros_only,
)
# TODO CT-211
self._macro_manifest_lazy = manifest # type: ignore[assignment]
# TODO CT-211
return self._macro_manifest_lazy # type: ignore[return-value]
self._macro_manifest_lazy = manifest
return self._macro_manifest_lazy
def clear_macro_manifest(self):
if self._macro_manifest_lazy is not None:
@@ -983,7 +981,7 @@ class BaseAdapter(metaclass=AdapterMeta):
def execute_macro(
self,
macro_name: str,
manifest: Optional[Manifest] = None,
manifest: Optional[AnyManifest] = None,
project: Optional[str] = None,
context_override: Optional[Dict[str, Any]] = None,
kwargs: Dict[str, Any] = None,
@@ -992,7 +990,7 @@ class BaseAdapter(metaclass=AdapterMeta):
"""Look macro_name up in the manifest and execute its results.
:param macro_name: The name of the macro to execute.
:param manifest: The manifest to use for generating the base macro
:param provided_manifest: The manifest to use for generating the base macro
execution context. If none is provided, use the internal manifest.
:param project: The name of the project to search in, or None for the
first match.
@@ -1004,16 +1002,15 @@ class BaseAdapter(metaclass=AdapterMeta):
if kwargs is None:
kwargs = {}
if context_override is None:
context_override = {}
if manifest is None:
# TODO CT-211
manifest = self._macro_manifest # type: ignore[assignment]
# TODO CT-211
macro = manifest.find_macro_by_name( # type: ignore[union-attr]
macro_name, self.config.project_name, project
)
manifest = self._macro_manifest
macro = manifest.find_macro_by_name(macro_name, self.config.project_name, project)
if macro is None:
if project is None:
package_name = "any package"
@@ -1036,6 +1033,7 @@ class BaseAdapter(metaclass=AdapterMeta):
manifest=manifest, # type: ignore[arg-type]
package_name=project,
)
macro_context.update(context_override)
macro_function = MacroGenerator(macro, macro_context)

View File

@@ -67,11 +67,10 @@ AdapterConfig_T = TypeVar("AdapterConfig_T", bound=AdapterConfig)
ConnectionManager_T = TypeVar("ConnectionManager_T", bound=ConnectionManagerProtocol)
Relation_T = TypeVar("Relation_T", bound=RelationProtocol)
Column_T = TypeVar("Column_T", bound=ColumnProtocol)
Compiler_T = TypeVar("Compiler_T", bound=CompilerProtocol)
Compiler_T = TypeVar("Compiler_T", bound=CompilerProtocol, covariant=True)
# TODO CT-211
class AdapterProtocol( # type: ignore[misc]
class AdapterProtocol(
Protocol,
Generic[
AdapterConfig_T,

View File

@@ -99,22 +99,15 @@ class SQLConnectionManager(BaseConnectionManager):
)
@classmethod
def process_results(
cls, column_names: Iterable[str], rows: Iterable[Any]
) -> List[Dict[str, Any]]:
# TODO CT-211
unique_col_names = dict() # type: ignore[var-annotated]
# TODO CT-211
for idx in range(len(column_names)): # type: ignore[arg-type]
# TODO CT-211
col_name = column_names[idx] # type: ignore[index]
def process_results(cls, column_names: List[str], rows: Iterable[Any]) -> List[Dict[str, Any]]:
unique_col_names: Dict = dict()
for idx in range(len(column_names)):
col_name = column_names[idx]
if col_name in unique_col_names:
unique_col_names[col_name] += 1
# TODO CT-211
column_names[idx] = f"{col_name}_{unique_col_names[col_name]}" # type: ignore[index] # noqa
column_names[idx] = f"{col_name}_{unique_col_names[col_name]}"
else:
# TODO CT-211
unique_col_names[column_names[idx]] = 1 # type: ignore[index]
unique_col_names[column_names[idx]] = 1
return [dict(zip(column_names, row)) for row in rows]
@classmethod

View File

@@ -70,8 +70,7 @@ class SQLAdapter(BaseAdapter):
@classmethod
def convert_number_type(cls, agate_table: agate.Table, col_idx: int) -> str:
# TODO CT-211
decimals = agate_table.aggregate(agate.MaxPrecision(col_idx)) # type: ignore[attr-defined]
decimals = agate_table.aggregate(agate.MaxPrecision(col_idx))
return "float8" if decimals else "integer"
@classmethod

View File

@@ -46,8 +46,7 @@ def get_datetime_module_context() -> Dict[str, Any]:
def get_re_module_context() -> Dict[str, Any]:
# TODO CT-211
context_exports = re.__all__ # type: ignore[attr-defined]
context_exports = re.__all__
return {name: getattr(re, name) for name in context_exports}

View File

@@ -1,7 +1,7 @@
from abc import abstractmethod
from copy import deepcopy
from dataclasses import dataclass
from typing import List, Iterator, Dict, Any, TypeVar, Generic
from typing import List, Iterator, Dict, Any, TypeVar, Generic, Union
from dbt.config import RuntimeConfig, Project, IsFQNResource
from dbt.contracts.graph.model_config import BaseConfig, get_config_for, _listify
@@ -131,7 +131,7 @@ class BaseContextConfigGenerator(Generic[T]):
project_name: str,
base: bool,
patch_config_dict: Dict[str, Any] = None,
) -> BaseConfig:
) -> T:
own_config = self.get_node_project(project_name)
result = self.initial_result(resource_type=resource_type, base=base)
@@ -155,8 +155,7 @@ class BaseContextConfigGenerator(Generic[T]):
result = self._update_from_config(result, fqn_config)
# this is mostly impactful in the snapshot config case
# TODO CT-211
return result # type: ignore[return-value]
return result
@abstractmethod
def calculate_node_config_dict(
@@ -227,7 +226,6 @@ class UnrenderedConfigGenerator(BaseContextConfigGenerator[Dict[str, Any]]):
base: bool,
patch_config_dict: dict = None,
) -> Dict[str, Any]:
# TODO CT-211
return self.calculate_node_config(
config_call_dict=config_call_dict,
fqn=fqn,
@@ -235,7 +233,7 @@ class UnrenderedConfigGenerator(BaseContextConfigGenerator[Dict[str, Any]]):
project_name=project_name,
base=base,
patch_config_dict=patch_config_dict,
) # type: ignore[return-value]
)
def initial_result(self, resource_type: NodeType, base: bool) -> Dict[str, Any]:
return {}
@@ -321,11 +319,11 @@ class ContextConfig:
self, base: bool = False, *, rendered: bool = True, patch_config_dict: dict = None
) -> Dict[str, Any]:
if rendered:
# TODO CT-211
src = ContextConfigGenerator(self._active_project) # type: ignore[var-annotated]
src: Union[ContextConfigGenerator, UnrenderedConfigGenerator] = ContextConfigGenerator(
self._active_project
)
else:
# TODO CT-211
src = UnrenderedConfigGenerator(self._active_project) # type: ignore[assignment]
src = UnrenderedConfigGenerator(self._active_project)
return src.calculate_node_config_dict(
config_call_dict=self._config_call_dict,

View File

@@ -7,6 +7,7 @@ from dbt.exceptions import (
from dbt.config.runtime import RuntimeConfig
from dbt.contracts.graph.manifest import Manifest
from dbt.contracts.graph.nodes import Macro, ResultNode
from dbt.contracts.files import SourceFile
from dbt.context.base import contextmember
from dbt.context.configured import SchemaYamlContext
@@ -65,8 +66,8 @@ class DocsRuntimeContext(SchemaYamlContext):
file_id = target_doc.file_id
if file_id in self.manifest.files:
source_file = self.manifest.files[file_id]
# TODO CT-211
source_file.add_node(self.node.unique_id) # type: ignore[union-attr]
if type(source_file) == SourceFile:
source_file.add_node(self.node.unique_id)
else:
raise DocTargetNotFoundError(
node=self.node, target_doc_name=doc_name, target_doc_package=doc_package_name

View File

@@ -32,14 +32,14 @@ class MacroNamespace(Mapping):
self.packages: Dict[str, FlatNamespace] = packages
self.global_project_namespace: FlatNamespace = global_project_namespace
def _search_order(self) -> Iterable[Union[FullNamespace, FlatNamespace]]:
def _search_order(
self,
) -> Iterable[Union[FullNamespace, FlatNamespace, Dict[str, FlatNamespace]]]:
yield self.local_namespace # local package
yield self.global_namespace # root package
# TODO CT-211
yield self.packages # type: ignore[misc] # non-internal packages
yield self.packages
yield {
# TODO CT-211
GLOBAL_PROJECT_NAME: self.global_project_namespace, # type: ignore[misc] # dbt
GLOBAL_PROJECT_NAME: self.global_project_namespace,
}
yield self.global_project_namespace # other internal project besides dbt

View File

@@ -28,6 +28,7 @@ from dbt.context.macro_resolver import MacroResolver, TestMacroNamespace
from dbt.context.macros import MacroNamespaceBuilder, MacroNamespace
from dbt.context.manifest import ManifestContext
from dbt.contracts.connection import AdapterResponse
from dbt.contracts.files import SourceFile, SchemaSourceFile
from dbt.contracts.graph.manifest import Manifest, Disabled
from dbt.contracts.graph.nodes import (
Macro,
@@ -39,6 +40,7 @@ from dbt.contracts.graph.nodes import (
ManifestNode,
RefArgs,
AccessType,
GenericTestNode,
)
from dbt.contracts.graph.metrics import MetricReference, ResolvedMetricReference
from dbt.contracts.graph.unparsed import NodeVersion
@@ -1290,9 +1292,8 @@ class ProviderContext(ManifestContext):
if self.model.file_id in self.manifest.files:
source_file = self.manifest.files[self.model.file_id]
# Schema files should never get here
if source_file.parse_file_type != "schema":
# TODO CT-211
source_file.env_vars.append(var) # type: ignore[union-attr]
if source_file.parse_file_type != "schema" and type(source_file) == SourceFile:
source_file.env_vars.append(var)
return return_value
else:
raise EnvVarMissingError(var)
@@ -1353,36 +1354,28 @@ class ModelContext(ProviderContext):
def pre_hooks(self) -> List[Dict[str, Any]]:
if self.model.resource_type in [NodeType.Source, NodeType.Test]:
return []
# TODO CT-211
return [
h.to_dict(omit_none=True) for h in self.model.config.pre_hook # type: ignore[union-attr] # noqa
]
return [h.to_dict(omit_none=True) for h in self.model.config.pre_hook]
@contextproperty
def post_hooks(self) -> List[Dict[str, Any]]:
if self.model.resource_type in [NodeType.Source, NodeType.Test]:
return []
# TODO CT-211
return [
h.to_dict(omit_none=True) for h in self.model.config.post_hook # type: ignore[union-attr] # noqa
]
return [h.to_dict(omit_none=True) for h in self.model.config.post_hook]
@contextproperty
def sql(self) -> Optional[str]:
# only doing this in sql model for backward compatible
if (
getattr(self.model, "extra_ctes_injected", None)
and self.model.language == ModelLanguage.sql # type: ignore[union-attr]
and self.model.language == ModelLanguage.sql
):
# TODO CT-211
return self.model.compiled_code # type: ignore[union-attr]
return self.model.compiled_code
return None
@contextproperty
def compiled_code(self) -> Optional[str]:
if getattr(self.model, "extra_ctes_injected", None):
# TODO CT-211
return self.model.compiled_code # type: ignore[union-attr]
return self.model.compiled_code
return None
@contextproperty
@@ -1652,13 +1645,16 @@ class TestContext(ProviderContext):
return_value if var in os.environ else DEFAULT_ENV_PLACEHOLDER
)
# the "model" should only be test nodes, but just in case, check
# TODO CT-211
if self.model.resource_type == NodeType.Test and self.model.file_key_name: # type: ignore[union-attr] # noqa
if (
self.model.resource_type == NodeType.Test
and type(self.model) == GenericTestNode
and self.model.file_key_name
):
source_file = self.manifest.files[self.model.file_id]
# TODO CT-211
(yaml_key, name) = self.model.file_key_name.split(".") # type: ignore[union-attr] # noqa
# TODO CT-211
source_file.add_env_var(var, yaml_key, name) # type: ignore[union-attr]
(yaml_key, name) = self.model.file_key_name.split(".")
if type(source_file) == SchemaSourceFile:
source_file.add_env_var(var, yaml_key, name)
return return_value
else:
raise EnvVarMissingError(var)

View File

@@ -563,6 +563,8 @@ class TestConfig(NodeAndTestConfig):
fail_calc: str = "count(*)"
warn_if: str = "!= 0"
error_if: str = "!= 0"
pre_hook: List = []
post_hook: List = []
@classmethod
def same_contents(cls, unrendered: Dict[str, Any], other: Dict[str, Any]) -> bool:

View File

@@ -786,6 +786,7 @@ class SeedNode(ParsedNode): # No SQLDefaults!
root_path: Optional[str] = None
depends_on: MacroDependsOn = field(default_factory=MacroDependsOn)
state_relation: Optional[StateRelation] = None
compiled_code = None
def same_seeds(self, other: "SeedNode") -> bool:
# for seeds, we check the hashes. If the hashes are different types,

View File

@@ -1117,7 +1117,7 @@ class ManifestLoader:
root_config: RuntimeConfig,
macro_hook: Callable[[Manifest], Any],
base_macros_only=False,
) -> Manifest:
) -> MacroManifest:
with PARSING_STATE:
# base_only/base_macros_only: for testing only,
# allows loading macros without running 'dbt deps' first

View File

@@ -1,6 +1,6 @@
from collections.abc import Sequence
from typing import Any, Optional, Callable, Iterable, Dict, Union
from typing import Any, Optional, Callable, Iterable, Dict, Union, Tuple, OrderedDict
from . import data_types as data_types
from .data_types import (
@@ -52,6 +52,7 @@ class Table:
def columns(self): ...
@property
def rows(self): ...
def aggregate(self, aggregations: Any) -> OrderedDict: ...
def print_csv(self, **kwargs: Any) -> None: ...
def print_json(self, **kwargs: Any) -> None: ...
def where(self, test: Callable[[Row], bool]) -> "Table": ...