forked from repo-mirrors/dbt-core
Compare commits: main...iknox/CT-2 (1 commit)

| Author | SHA1 | Date |
|---|---|---|
| | a5440f9dba | |
@@ -46,7 +46,7 @@ from dbt.exceptions import (
 from dbt.adapters.protocol import AdapterConfig, ConnectionManagerProtocol
 from dbt.clients.agate_helper import empty_table, merge_tables, table_from_rows
 from dbt.clients.jinja import MacroGenerator
-from dbt.contracts.graph.manifest import Manifest, MacroManifest
+from dbt.contracts.graph.manifest import AnyManifest, Manifest, MacroManifest
 from dbt.contracts.graph.nodes import ResultNode
 from dbt.events.functions import fire_event, warn_or_error
 from dbt.events.types import (
@@ -349,10 +349,8 @@ class BaseAdapter(metaclass=AdapterMeta):
                 self.connections.set_query_header,
                 base_macros_only=base_macros_only,
             )
-            # TODO CT-211
-            self._macro_manifest_lazy = manifest  # type: ignore[assignment]
-        # TODO CT-211
-        return self._macro_manifest_lazy  # type: ignore[return-value]
+            self._macro_manifest_lazy = manifest
+        return self._macro_manifest_lazy

     def clear_macro_manifest(self):
         if self._macro_manifest_lazy is not None:
@@ -983,7 +981,7 @@ class BaseAdapter(metaclass=AdapterMeta):
     def execute_macro(
         self,
         macro_name: str,
-        manifest: Optional[Manifest] = None,
+        manifest: Optional[AnyManifest] = None,
         project: Optional[str] = None,
         context_override: Optional[Dict[str, Any]] = None,
         kwargs: Dict[str, Any] = None,
@@ -992,7 +990,7 @@ class BaseAdapter(metaclass=AdapterMeta):
         """Look macro_name up in the manifest and execute its results.

         :param macro_name: The name of the macro to execute.
-        :param manifest: The manifest to use for generating the base macro
+        :param provided_manifest: The manifest to use for generating the base macro
             execution context. If none is provided, use the internal manifest.
         :param project: The name of the project to search in, or None for the
             first match.
@@ -1004,16 +1002,15 @@ class BaseAdapter(metaclass=AdapterMeta):
         if kwargs is None:
             kwargs = {}

         if context_override is None:
             context_override = {}

         if manifest is None:
-            # TODO CT-211
-            manifest = self._macro_manifest  # type: ignore[assignment]
-        # TODO CT-211
-        macro = manifest.find_macro_by_name(  # type: ignore[union-attr]
-            macro_name, self.config.project_name, project
-        )
+            manifest = self._macro_manifest
+
+        macro = manifest.find_macro_by_name(macro_name, self.config.project_name, project)
+
         if macro is None:
             if project is None:
                 package_name = "any package"
@@ -1036,6 +1033,7 @@ class BaseAdapter(metaclass=AdapterMeta):
             manifest=manifest,  # type: ignore[arg-type]
             package_name=project,
         )
+
         macro_context.update(context_override)

         macro_function = MacroGenerator(macro, macro_context)
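The hunks above widen `execute_macro` to accept an `AnyManifest` instead of a concrete `Manifest`, which is what lets the `# type: ignore[assignment]` and `[union-attr]` comments around `self._macro_manifest` go away. A minimal sketch of that pattern, assuming `AnyManifest` is a union-style alias over `Manifest` and `MacroManifest` (the alias's actual definition is not shown in this diff, and the classes below are placeholders):

```python
from typing import Optional, Union


class Manifest:
    def find_macro_by_name(self, name: str) -> Optional[str]:
        return name  # placeholder lookup


class MacroManifest:
    def find_macro_by_name(self, name: str) -> Optional[str]:
        return name  # placeholder lookup


# Hypothetical alias: both manifest flavors expose the lookup the caller needs.
AnyManifest = Union[Manifest, MacroManifest]


def execute_macro(manifest: Optional[AnyManifest] = None) -> Optional[str]:
    if manifest is None:
        manifest = MacroManifest()  # e.g. a lazily built, macros-only manifest
    # Both union members define find_macro_by_name, so no ignore comment is needed.
    return manifest.find_macro_by_name("some_macro")
```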
@@ -67,11 +67,10 @@ AdapterConfig_T = TypeVar("AdapterConfig_T", bound=AdapterConfig)
 ConnectionManager_T = TypeVar("ConnectionManager_T", bound=ConnectionManagerProtocol)
 Relation_T = TypeVar("Relation_T", bound=RelationProtocol)
 Column_T = TypeVar("Column_T", bound=ColumnProtocol)
-Compiler_T = TypeVar("Compiler_T", bound=CompilerProtocol)
+Compiler_T = TypeVar("Compiler_T", bound=CompilerProtocol, covariant=True)


-# TODO CT-211
-class AdapterProtocol(  # type: ignore[misc]
+class AdapterProtocol(
     Protocol,
     Generic[
         AdapterConfig_T,
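For context on the `covariant=True` change above: a covariant type variable lets a protocol parametrized with a subtype be used where the base parametrization is expected, and declaring it covariant is likely what silences the `[misc]` error the old `# type: ignore` was covering (mypy complains when an invariant variable is used in a protocol where a covariant one is expected). A small illustrative sketch with made-up names, not taken from the dbt codebase:

```python
from typing import Generic, Protocol, TypeVar


class BaseCompiler:
    ...


class FancyCompiler(BaseCompiler):
    ...


Compiler_co = TypeVar("Compiler_co", bound=BaseCompiler, covariant=True)


class HasCompiler(Protocol, Generic[Compiler_co]):
    # A covariant variable may only appear in "output" positions such as return types.
    def get_compiler(self) -> Compiler_co: ...


class MyAdapter:
    def get_compiler(self) -> FancyCompiler:
        return FancyCompiler()


fancy: HasCompiler[FancyCompiler] = MyAdapter()
# Accepted only because Compiler_co is covariant:
# HasCompiler[FancyCompiler] is a subtype of HasCompiler[BaseCompiler].
base: HasCompiler[BaseCompiler] = fancy
```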
@@ -99,22 +99,15 @@ class SQLConnectionManager(BaseConnectionManager):
             )

     @classmethod
-    def process_results(
-        cls, column_names: Iterable[str], rows: Iterable[Any]
-    ) -> List[Dict[str, Any]]:
-        # TODO CT-211
-        unique_col_names = dict()  # type: ignore[var-annotated]
-        # TODO CT-211
-        for idx in range(len(column_names)):  # type: ignore[arg-type]
-            # TODO CT-211
-            col_name = column_names[idx]  # type: ignore[index]
+    def process_results(cls, column_names: List[str], rows: Iterable[Any]) -> List[Dict[str, Any]]:
+        unique_col_names: Dict = dict()
+        for idx in range(len(column_names)):
+            col_name = column_names[idx]
             if col_name in unique_col_names:
                 unique_col_names[col_name] += 1
-                # TODO CT-211
-                column_names[idx] = f"{col_name}_{unique_col_names[col_name]}"  # type: ignore[index]  # noqa
+                column_names[idx] = f"{col_name}_{unique_col_names[col_name]}"
             else:
-                # TODO CT-211
-                unique_col_names[column_names[idx]] = 1  # type: ignore[index]
+                unique_col_names[column_names[idx]] = 1
         return [dict(zip(column_names, row)) for row in rows]

     @classmethod
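The `process_results` change above works because `Iterable[str]` does not promise `len()`, indexing, or item assignment, which is exactly what the old `[arg-type]` and `[index]` ignores were papering over; `List[str]` guarantees all three. A quick standalone illustration (not dbt code):

```python
from typing import Dict, Iterable, List


def first_upper(names: Iterable[str]) -> str:
    # mypy would reject len(names) and names[0] here: an Iterable is neither
    # Sized nor indexable. Iteration is all it guarantees.
    for name in names:
        return name.upper()
    return ""


def dedupe_in_place(names: List[str]) -> List[str]:
    seen: Dict[str, int] = {}
    for idx in range(len(names)):       # len() is fine on a list
        name = names[idx]               # indexing is fine on a list
        if name in seen:
            seen[name] += 1
            names[idx] = f"{name}_{seen[name]}"  # item assignment is fine too
        else:
            seen[name] = 1
    return names
```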
@@ -70,8 +70,7 @@ class SQLAdapter(BaseAdapter):

     @classmethod
     def convert_number_type(cls, agate_table: agate.Table, col_idx: int) -> str:
-        # TODO CT-211
-        decimals = agate_table.aggregate(agate.MaxPrecision(col_idx))  # type: ignore[attr-defined]
+        decimals = agate_table.aggregate(agate.MaxPrecision(col_idx))
         return "float8" if decimals else "integer"

     @classmethod
@@ -46,8 +46,7 @@ def get_datetime_module_context() -> Dict[str, Any]:


 def get_re_module_context() -> Dict[str, Any]:
-    # TODO CT-211
-    context_exports = re.__all__  # type: ignore[attr-defined]
+    context_exports = re.__all__

     return {name: getattr(re, name) for name in context_exports}

@@ -1,7 +1,7 @@
 from abc import abstractmethod
 from copy import deepcopy
 from dataclasses import dataclass
-from typing import List, Iterator, Dict, Any, TypeVar, Generic
+from typing import List, Iterator, Dict, Any, TypeVar, Generic, Union

 from dbt.config import RuntimeConfig, Project, IsFQNResource
 from dbt.contracts.graph.model_config import BaseConfig, get_config_for, _listify
@@ -131,7 +131,7 @@ class BaseContextConfigGenerator(Generic[T]):
         project_name: str,
         base: bool,
         patch_config_dict: Dict[str, Any] = None,
-    ) -> BaseConfig:
+    ) -> T:
         own_config = self.get_node_project(project_name)

         result = self.initial_result(resource_type=resource_type, base=base)
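Returning `T` instead of a fixed `BaseConfig` above lets each concrete generator report its real result type, which is what makes the `# type: ignore[return-value]` in the next hunk unnecessary. A minimal sketch of the idea with hypothetical class names:

```python
from typing import Any, Dict, Generic, TypeVar

T = TypeVar("T")


class BaseGenerator(Generic[T]):
    def initial_result(self) -> T:
        raise NotImplementedError

    def calculate(self) -> T:
        # Annotating the return as T (not a fixed base class) means subclasses
        # that bind T to Dict[str, Any] need no cast or ignore comment.
        return self.initial_result()


class UnrenderedGenerator(BaseGenerator[Dict[str, Any]]):
    def initial_result(self) -> Dict[str, Any]:
        return {}


config: Dict[str, Any] = UnrenderedGenerator().calculate()  # precise type, no ignore
```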
@@ -155,8 +155,7 @@ class BaseContextConfigGenerator(Generic[T]):
             result = self._update_from_config(result, fqn_config)

         # this is mostly impactful in the snapshot config case
-        # TODO CT-211
-        return result  # type: ignore[return-value]
+        return result

     @abstractmethod
     def calculate_node_config_dict(
@@ -227,7 +226,6 @@ class UnrenderedConfigGenerator(BaseContextConfigGenerator[Dict[str, Any]]):
         base: bool,
         patch_config_dict: dict = None,
     ) -> Dict[str, Any]:
-        # TODO CT-211
         return self.calculate_node_config(
             config_call_dict=config_call_dict,
             fqn=fqn,
@@ -235,7 +233,7 @@ class UnrenderedConfigGenerator(BaseContextConfigGenerator[Dict[str, Any]]):
             project_name=project_name,
             base=base,
             patch_config_dict=patch_config_dict,
-        )  # type: ignore[return-value]
+        )

     def initial_result(self, resource_type: NodeType, base: bool) -> Dict[str, Any]:
         return {}
@@ -321,11 +319,11 @@ class ContextConfig:
         self, base: bool = False, *, rendered: bool = True, patch_config_dict: dict = None
     ) -> Dict[str, Any]:
         if rendered:
-            # TODO CT-211
-            src = ContextConfigGenerator(self._active_project)  # type: ignore[var-annotated]
+            src: Union[ContextConfigGenerator, UnrenderedConfigGenerator] = ContextConfigGenerator(
+                self._active_project
+            )
         else:
-            # TODO CT-211
-            src = UnrenderedConfigGenerator(self._active_project)  # type: ignore[assignment]
+            src = UnrenderedConfigGenerator(self._active_project)

         return src.calculate_node_config_dict(
             config_call_dict=self._config_call_dict,
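The `src` annotation above is the standard fix when a variable is bound to different classes in different branches: declare the union up front so neither branch needs an `[assignment]` ignore. A short sketch with hypothetical classes:

```python
from typing import Union


class RenderedGen:
    def calculate(self) -> dict:
        return {"rendered": True}


class UnrenderedGen:
    def calculate(self) -> dict:
        return {"rendered": False}


def pick(rendered: bool) -> dict:
    # Without the explicit Union annotation, mypy infers RenderedGen from the
    # first assignment and flags the second one with [assignment].
    src: Union[RenderedGen, UnrenderedGen] = RenderedGen()
    if not rendered:
        src = UnrenderedGen()
    return src.calculate()  # defined on both union members, so no ignore needed
```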
@@ -7,6 +7,7 @@ from dbt.exceptions import (
 from dbt.config.runtime import RuntimeConfig
 from dbt.contracts.graph.manifest import Manifest
 from dbt.contracts.graph.nodes import Macro, ResultNode
+from dbt.contracts.files import SourceFile

 from dbt.context.base import contextmember
 from dbt.context.configured import SchemaYamlContext
@@ -65,8 +66,8 @@ class DocsRuntimeContext(SchemaYamlContext):
             file_id = target_doc.file_id
             if file_id in self.manifest.files:
                 source_file = self.manifest.files[file_id]
-                # TODO CT-211
-                source_file.add_node(self.node.unique_id)  # type: ignore[union-attr]
+                if type(source_file) == SourceFile:
+                    source_file.add_node(self.node.unique_id)
         else:
             raise DocTargetNotFoundError(
                 node=self.node, target_doc_name=doc_name, target_doc_package=doc_package_name
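The guard above replaces a blanket `# type: ignore[union-attr]` with a runtime check before calling a method that only `SourceFile`, not every file type in `manifest.files`, defines. The diff uses an exact-class `type(...) ==` comparison; the more common narrowing idiom, which mypy understands, is `isinstance`. A generic sketch with placeholder classes:

```python
from typing import Dict, List, Union


class SchemaSourceFile:
    pass


class SourceFile:
    def __init__(self) -> None:
        self.nodes: List[str] = []

    def add_node(self, unique_id: str) -> None:
        self.nodes.append(unique_id)


AnyFile = Union[SourceFile, SchemaSourceFile]


def record_node(files: Dict[str, AnyFile], file_id: str, unique_id: str) -> None:
    source_file = files[file_id]
    # isinstance() narrows AnyFile to SourceFile, so add_node() type-checks
    # without a union-attr ignore; schema files are simply skipped.
    if isinstance(source_file, SourceFile):
        source_file.add_node(unique_id)
```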
@@ -32,14 +32,14 @@ class MacroNamespace(Mapping):
         self.packages: Dict[str, FlatNamespace] = packages
         self.global_project_namespace: FlatNamespace = global_project_namespace

-    def _search_order(self) -> Iterable[Union[FullNamespace, FlatNamespace]]:
+    def _search_order(
+        self,
+    ) -> Iterable[Union[FullNamespace, FlatNamespace, Dict[str, FlatNamespace]]]:
         yield self.local_namespace  # local package
         yield self.global_namespace  # root package
-        # TODO CT-211
-        yield self.packages  # type: ignore[misc]  # non-internal packages
+        yield self.packages
         yield {
-            # TODO CT-211
-            GLOBAL_PROJECT_NAME: self.global_project_namespace,  # type: ignore[misc]  # dbt
+            GLOBAL_PROJECT_NAME: self.global_project_namespace,
         }
         yield self.global_project_namespace  # other internal project besides dbt

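Widening the declared element type of `_search_order` above is what makes the plain-dict `yield` legal: a generator's annotation has to cover every yielded shape. A small sketch with hypothetical type aliases:

```python
from typing import Dict, Iterator, Union

FlatNamespace = Dict[str, str]
FullNamespace = Dict[str, FlatNamespace]


def search_order(
    local_ns: FlatNamespace,
    packages: FullNamespace,
) -> Iterator[Union[FlatNamespace, FullNamespace]]:
    # Each yielded value must match a member of the declared Union; if the
    # annotation only said Iterator[FlatNamespace], yielding `packages`
    # would need a type: ignore.
    yield local_ns
    yield packages
    yield {"dbt": local_ns}  # also FullNamespace-shaped
```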
@@ -28,6 +28,7 @@ from dbt.context.macro_resolver import MacroResolver, TestMacroNamespace
 from dbt.context.macros import MacroNamespaceBuilder, MacroNamespace
 from dbt.context.manifest import ManifestContext
 from dbt.contracts.connection import AdapterResponse
+from dbt.contracts.files import SourceFile, SchemaSourceFile
 from dbt.contracts.graph.manifest import Manifest, Disabled
 from dbt.contracts.graph.nodes import (
     Macro,
@@ -39,6 +40,7 @@ from dbt.contracts.graph.nodes import (
     ManifestNode,
     RefArgs,
     AccessType,
+    GenericTestNode,
 )
 from dbt.contracts.graph.metrics import MetricReference, ResolvedMetricReference
 from dbt.contracts.graph.unparsed import NodeVersion
@@ -1290,9 +1292,8 @@ class ProviderContext(ManifestContext):
             if self.model.file_id in self.manifest.files:
                 source_file = self.manifest.files[self.model.file_id]
                 # Schema files should never get here
-                if source_file.parse_file_type != "schema":
-                    # TODO CT-211
-                    source_file.env_vars.append(var)  # type: ignore[union-attr]
+                if source_file.parse_file_type != "schema" and type(source_file) == SourceFile:
+                    source_file.env_vars.append(var)
             return return_value
         else:
             raise EnvVarMissingError(var)
@@ -1353,36 +1354,28 @@ class ModelContext(ProviderContext):
     def pre_hooks(self) -> List[Dict[str, Any]]:
         if self.model.resource_type in [NodeType.Source, NodeType.Test]:
             return []
-        # TODO CT-211
-        return [
-            h.to_dict(omit_none=True) for h in self.model.config.pre_hook  # type: ignore[union-attr]  # noqa
-        ]
+        return [h.to_dict(omit_none=True) for h in self.model.config.pre_hook]

     @contextproperty
     def post_hooks(self) -> List[Dict[str, Any]]:
         if self.model.resource_type in [NodeType.Source, NodeType.Test]:
             return []
-        # TODO CT-211
-        return [
-            h.to_dict(omit_none=True) for h in self.model.config.post_hook  # type: ignore[union-attr]  # noqa
-        ]
+        return [h.to_dict(omit_none=True) for h in self.model.config.post_hook]

     @contextproperty
     def sql(self) -> Optional[str]:
         # only doing this in sql model for backward compatible
         if (
             getattr(self.model, "extra_ctes_injected", None)
-            and self.model.language == ModelLanguage.sql  # type: ignore[union-attr]
+            and self.model.language == ModelLanguage.sql
         ):
-            # TODO CT-211
-            return self.model.compiled_code  # type: ignore[union-attr]
+            return self.model.compiled_code
         return None

     @contextproperty
     def compiled_code(self) -> Optional[str]:
         if getattr(self.model, "extra_ctes_injected", None):
-            # TODO CT-211
-            return self.model.compiled_code  # type: ignore[union-attr]
+            return self.model.compiled_code
         return None

     @contextproperty
@@ -1652,13 +1645,16 @@ class TestContext(ProviderContext):
                 return_value if var in os.environ else DEFAULT_ENV_PLACEHOLDER
             )
             # the "model" should only be test nodes, but just in case, check
-            # TODO CT-211
-            if self.model.resource_type == NodeType.Test and self.model.file_key_name:  # type: ignore[union-attr]  # noqa
+            if (
+                self.model.resource_type == NodeType.Test
+                and type(self.model) == GenericTestNode
+                and self.model.file_key_name
+            ):
                 source_file = self.manifest.files[self.model.file_id]
-                # TODO CT-211
-                (yaml_key, name) = self.model.file_key_name.split(".")  # type: ignore[union-attr]  # noqa
-                # TODO CT-211
-                source_file.add_env_var(var, yaml_key, name)  # type: ignore[union-attr]
+
+                (yaml_key, name) = self.model.file_key_name.split(".")
+                if type(source_file) == SchemaSourceFile:
+                    source_file.add_env_var(var, yaml_key, name)
             return return_value
         else:
             raise EnvVarMissingError(var)
@@ -563,6 +563,8 @@ class TestConfig(NodeAndTestConfig):
     fail_calc: str = "count(*)"
     warn_if: str = "!= 0"
     error_if: str = "!= 0"
+    pre_hook: List = []
+    post_hook: List = []

     @classmethod
     def same_contents(cls, unrendered: Dict[str, Any], other: Dict[str, Any]) -> bool:
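The two hook defaults above are shown with bare `[]` literals; on a plain `@dataclass`, mutable defaults are normally declared through `field(default_factory=...)` (dbt's config base classes may handle defaults differently). A generic sketch of that conventional pattern, not taken from the diff:

```python
from dataclasses import dataclass, field
from typing import Any, Dict, List


@dataclass
class HookedConfig:
    fail_calc: str = "count(*)"
    warn_if: str = "!= 0"
    error_if: str = "!= 0"
    # default_factory builds a fresh list per instance; a bare `[]` default on a
    # standard dataclass field raises ValueError at class definition time.
    pre_hook: List[Dict[str, Any]] = field(default_factory=list)
    post_hook: List[Dict[str, Any]] = field(default_factory=list)
```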
@@ -786,6 +786,7 @@ class SeedNode(ParsedNode):  # No SQLDefaults!
     root_path: Optional[str] = None
     depends_on: MacroDependsOn = field(default_factory=MacroDependsOn)
     state_relation: Optional[StateRelation] = None
+    compiled_code = None

     def same_seeds(self, other: "SeedNode") -> bool:
         # for seeds, we check the hashes. If the hashes are different types,
@@ -1117,7 +1117,7 @@ class ManifestLoader:
         root_config: RuntimeConfig,
         macro_hook: Callable[[Manifest], Any],
         base_macros_only=False,
-    ) -> Manifest:
+    ) -> MacroManifest:
         with PARSING_STATE:
             # base_only/base_macros_only: for testing only,
             # allows loading macros without running 'dbt deps' first
@@ -1,6 +1,6 @@
 from collections.abc import Sequence

-from typing import Any, Optional, Callable, Iterable, Dict, Union
+from typing import Any, Optional, Callable, Iterable, Dict, Union, Tuple, OrderedDict

 from . import data_types as data_types
 from .data_types import (
@@ -52,6 +52,7 @@ class Table:
     def columns(self): ...
     @property
     def rows(self): ...
+    def aggregate(self, aggregations: Any) -> OrderedDict: ...
     def print_csv(self, **kwargs: Any) -> None: ...
     def print_json(self, **kwargs: Any) -> None: ...
     def where(self, test: Callable[[Row], bool]) -> "Table": ...
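The stub additions above pair with the earlier `convert_number_type` hunk: once the vendored `agate` stub declares `Table.aggregate`, call sites no longer need `# type: ignore[attr-defined]`. A minimal sketch of the same idea for a hypothetical library and stub file:

```python
# Hypothetical stub file: stubs/somelib/__init__.pyi
# Declaring the method here is what removes attr-defined errors at call sites,
# e.g. `table.aggregate(...)` in adapter code no longer needs a type: ignore.
from typing import Any, Callable, OrderedDict


class Row:
    ...


class Table:
    def aggregate(self, aggregations: Any) -> OrderedDict: ...
    def where(self, test: Callable[[Row], bool]) -> "Table": ...
```

For mypy to pick the stub up, the stub directory has to be on its search path (for example via `mypy_path`), which is how vendored third-party stubs are typically wired in.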