Compare commits

...

5 Commits

Author SHA1 Message Date
Gerda Shank
03ceb4f372 Changie 2023-06-13 09:30:18 -04:00
Gerda Shank
44691db3f1 Remove new uses of replace/Replaceable 2023-06-13 09:27:43 -04:00
Gerda Shank
e033c54efa Merge branch 'main' into ct-2652-remove_replaceable 2023-06-13 09:17:30 -04:00
Gerda Shank
96af8bd32c Fix up tests and pre-commit. Broken test_retry.py 2023-06-06 14:43:14 -04:00
Gerda Shank
f322ae1cf0 Remove imports and use of Replaceable 2023-06-06 12:15:57 -04:00
19 changed files with 201 additions and 177 deletions

View File

@@ -0,0 +1,6 @@
kind: Under the Hood
body: Remove uses of Replaceable class
time: 2023-06-13T09:30:13.30422-04:00
custom:
Author: gshank
Issue: "7802"

View File

@@ -1,5 +1,5 @@
from collections.abc import Hashable from collections.abc import Hashable
from dataclasses import dataclass, field from dataclasses import dataclass, field, replace
from typing import Optional, TypeVar, Any, Type, Dict, Iterator, Tuple, Set from typing import Optional, TypeVar, Any, Type, Dict, Iterator, Tuple, Set
from dbt.contracts.graph.nodes import SourceDefinition, ManifestNode, ResultNode, ParsedNode from dbt.contracts.graph.nodes import SourceDefinition, ManifestNode, ResultNode, ParsedNode
@@ -109,7 +109,7 @@ class BaseRelation(FakeAPIObject, Hashable):
return exact_match return exact_match
def replace_path(self, **kwargs): def replace_path(self, **kwargs):
return self.replace(path=self.path.replace(**kwargs)) return replace(self, path=replace(self.path, **kwargs))
def quote( def quote(
self: Self, self: Self,
@@ -126,7 +126,7 @@ class BaseRelation(FakeAPIObject, Hashable):
) )
new_quote_policy = self.quote_policy.replace_dict(policy) new_quote_policy = self.quote_policy.replace_dict(policy)
return self.replace(quote_policy=new_quote_policy) return replace(self, quote_policy=new_quote_policy)
def include( def include(
self: Self, self: Self,
@@ -143,7 +143,7 @@ class BaseRelation(FakeAPIObject, Hashable):
) )
new_include_policy = self.include_policy.replace_dict(policy) new_include_policy = self.include_policy.replace_dict(policy)
return self.replace(include_policy=new_include_policy) return replace(self, include_policy=new_include_policy)
def information_schema(self, view_name=None) -> "InformationSchema": def information_schema(self, view_name=None) -> "InformationSchema":
# some of our data comes from jinja, where things can be `Undefined`. # some of our data comes from jinja, where things can be `Undefined`.
@@ -384,7 +384,8 @@ class InformationSchema(BaseRelation):
relation, relation,
information_schema_view: Optional[str], information_schema_view: Optional[str],
) -> Policy: ) -> Policy:
return relation.include_policy.replace( return replace(
relation.include_policy,
database=relation.database is not None, database=relation.database is not None,
schema=False, schema=False,
identifier=True, identifier=True,
@@ -396,7 +397,8 @@ class InformationSchema(BaseRelation):
relation, relation,
information_schema_view: Optional[str], information_schema_view: Optional[str],
) -> Policy: ) -> Policy:
return relation.quote_policy.replace( return replace(
relation.quote_policy,
identifier=False, identifier=False,
) )

View File

@@ -1,6 +1,7 @@
import threading import threading
from copy import deepcopy from copy import deepcopy
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple from typing import Any, Dict, Iterable, List, Optional, Set, Tuple
from dataclasses import replace
from dbt.adapters.reference_keys import ( from dbt.adapters.reference_keys import (
_make_ref_key, _make_ref_key,
@@ -296,11 +297,11 @@ class RelationsCache:
return return
if ref_key not in self.relations: if ref_key not in self.relations:
# Insert a dummy "external" relation. # Insert a dummy "external" relation.
referenced = referenced.replace(type=referenced.External) referenced = replace(referenced, type=referenced.External)
self.add(referenced) self.add(referenced)
if dep_key not in self.relations: if dep_key not in self.relations:
# Insert a dummy "external" relation. # Insert a dummy "external" relation.
dependent = dependent.replace(type=referenced.External) dependent = replace(dependent, type=referenced.External)
self.add(dependent) self.add(dependent)
fire_event( fire_event(
CacheAction( CacheAction(

View File

@@ -25,7 +25,6 @@ from dbt.dataclass_schema import (
ValidatedStringMixin, ValidatedStringMixin,
register_pattern, register_pattern,
) )
from dbt.contracts.util import Replaceable
class Identifier(ValidatedStringMixin): class Identifier(ValidatedStringMixin):
@@ -54,7 +53,7 @@ class ConnectionState(StrEnum):
@dataclass(init=False) @dataclass(init=False)
class Connection(ExtensibleDbtClassMixin, Replaceable): class Connection(ExtensibleDbtClassMixin):
type: Identifier type: Identifier
name: Optional[str] = None name: Optional[str] = None
state: ConnectionState = ConnectionState.INIT state: ConnectionState = ConnectionState.INIT
@@ -123,7 +122,7 @@ class LazyHandle:
# for why we have type: ignore. Maybe someday dataclasses + abstract classes # for why we have type: ignore. Maybe someday dataclasses + abstract classes
# will work. # will work.
@dataclass # type: ignore @dataclass # type: ignore
class Credentials(ExtensibleDbtClassMixin, Replaceable, metaclass=abc.ABCMeta): class Credentials(ExtensibleDbtClassMixin, metaclass=abc.ABCMeta):
database: str database: str
schema: str schema: str
_ALIASES: ClassVar[Dict[str, str]] = field(default={}, init=False) _ALIASES: ClassVar[Dict[str, str]] = field(default={}, init=False)

View File

@@ -1,5 +1,5 @@
import enum import enum
from dataclasses import dataclass, field from dataclasses import dataclass, field, replace
from itertools import chain, islice from itertools import chain, islice
from mashumaro.mixins.msgpack import DataClassMessagePackMixin from mashumaro.mixins.msgpack import DataClassMessagePackMixin
from multiprocessing.synchronize import Lock from multiprocessing.synchronize import Lock
@@ -1154,7 +1154,7 @@ class Manifest(MacroMethods, DataClassMessagePackMixin, dbtClassMixin):
) )
): ):
merged.add(unique_id) merged.add(unique_id)
self.nodes[unique_id] = node.replace(deferred=True) self.nodes[unique_id] = replace(node, deferred=True)
# Rebuild the flat_graph, which powers the 'graph' context variable, # Rebuild the flat_graph, which powers the 'graph' context variable,
# now that we've deferred some nodes # now that we've deferred some nodes
@@ -1179,7 +1179,7 @@ class Manifest(MacroMethods, DataClassMessagePackMixin, dbtClassMixin):
current = self.nodes.get(unique_id) current = self.nodes.get(unique_id)
if current and (node.resource_type in refables and not node.is_ephemeral): if current and (node.resource_type in refables and not node.is_ephemeral):
state_relation = RelationalNode(node.database, node.schema, node.alias) state_relation = RelationalNode(node.database, node.schema, node.alias)
self.nodes[unique_id] = current.replace(state_relation=state_relation) self.nodes[unique_id] = replace(current, state_relation=state_relation)
# Methods that were formerly in ParseResult # Methods that were formerly in ParseResult

View File

@@ -11,8 +11,8 @@ from dbt.dataclass_schema import (
) )
from dbt.contracts.graph.unparsed import AdditionalPropertiesAllowed, Docs from dbt.contracts.graph.unparsed import AdditionalPropertiesAllowed, Docs
from dbt.contracts.graph.utils import validate_color from dbt.contracts.graph.utils import validate_color
from dbt.contracts.util import Replaceable, list_str
from dbt.exceptions import DbtInternalError, CompilationError from dbt.exceptions import DbtInternalError, CompilationError
from dbt.contracts.util import list_str
from dbt import hooks from dbt import hooks
from dbt.node_types import NodeType from dbt.node_types import NodeType
@@ -202,12 +202,12 @@ class OnConfigurationChangeOption(StrEnum):
@dataclass @dataclass
class ContractConfig(dbtClassMixin, Replaceable): class ContractConfig(dbtClassMixin):
enforced: bool = False enforced: bool = False
@dataclass @dataclass
class Hook(dbtClassMixin, Replaceable): class Hook(dbtClassMixin):
sql: str sql: str
transaction: bool = True transaction: bool = True
index: Optional[int] = None index: Optional[int] = None
@@ -217,7 +217,7 @@ T = TypeVar("T", bound="BaseConfig")
@dataclass @dataclass
class BaseConfig(AdditionalPropertiesAllowed, Replaceable): class BaseConfig(AdditionalPropertiesAllowed):
# enable syntax like: config['key'] # enable syntax like: config['key']
def __getitem__(self, key): def __getitem__(self, key):

View File

@@ -31,7 +31,7 @@ from dbt.contracts.graph.unparsed import (
UnparsedSourceTableDefinition, UnparsedSourceTableDefinition,
UnparsedColumn, UnparsedColumn,
) )
from dbt.contracts.util import Replaceable, AdditionalPropertiesMixin from dbt.contracts.util import AdditionalPropertiesMixin
from dbt.events.functions import warn_or_error from dbt.events.functions import warn_or_error
from dbt.exceptions import ParsingError, ContractBreakingChangeError from dbt.exceptions import ParsingError, ContractBreakingChangeError
from dbt.events.types import ( from dbt.events.types import (
@@ -84,7 +84,7 @@ from .model_config import (
@dataclass @dataclass
class BaseNode(dbtClassMixin, Replaceable): class BaseNode(dbtClassMixin):
"""All nodes or node-like objects in this file should have this as a base class""" """All nodes or node-like objects in this file should have this as a base class"""
name: str name: str
@@ -201,7 +201,7 @@ class ModelLevelConstraint(ColumnLevelConstraint):
@dataclass @dataclass
class ColumnInfo(AdditionalPropertiesMixin, ExtensibleDbtClassMixin, Replaceable): class ColumnInfo(AdditionalPropertiesMixin, ExtensibleDbtClassMixin):
"""Used in all ManifestNodes and SourceDefinition""" """Used in all ManifestNodes and SourceDefinition"""
name: str name: str
@@ -215,14 +215,14 @@ class ColumnInfo(AdditionalPropertiesMixin, ExtensibleDbtClassMixin, Replaceable
@dataclass @dataclass
class Contract(dbtClassMixin, Replaceable): class Contract(dbtClassMixin):
enforced: bool = False enforced: bool = False
checksum: Optional[str] = None checksum: Optional[str] = None
# Metrics, exposures, # Metrics, exposures,
@dataclass @dataclass
class HasRelationMetadata(dbtClassMixin, Replaceable): class HasRelationMetadata(dbtClassMixin):
database: Optional[str] database: Optional[str]
schema: str schema: str
@@ -238,7 +238,7 @@ class HasRelationMetadata(dbtClassMixin, Replaceable):
@dataclass @dataclass
class MacroDependsOn(dbtClassMixin, Replaceable): class MacroDependsOn(dbtClassMixin):
"""Used only in the Macro class""" """Used only in the Macro class"""
macros: List[str] = field(default_factory=list) macros: List[str] = field(default_factory=list)
@@ -284,7 +284,7 @@ class StateRelation(dbtClassMixin):
@dataclass @dataclass
class ParsedNodeMandatory(GraphNode, HasRelationMetadata, Replaceable): class ParsedNodeMandatory(GraphNode, HasRelationMetadata):
alias: str alias: str
checksum: FileHash checksum: FileHash
config: NodeConfig = field(default_factory=NodeConfig) config: NodeConfig = field(default_factory=NodeConfig)
@@ -485,7 +485,7 @@ class ParsedNode(NodeInfoMixin, ParsedNodeMandatory, SerializableType):
@dataclass @dataclass
class InjectedCTE(dbtClassMixin, Replaceable): class InjectedCTE(dbtClassMixin):
"""Used in CompiledNodes as part of ephemeral model processing""" """Used in CompiledNodes as part of ephemeral model processing"""
id: str id: str
@@ -555,7 +555,7 @@ class CompiledNode(ParsedNode):
@dataclass @dataclass
class FileSlice(dbtClassMixin, Replaceable): class FileSlice(dbtClassMixin):
"""Provides file slice level context about what something was created from. """Provides file slice level context about what something was created from.
Implementation of the dbt-semantic-interfaces `FileSlice` protocol Implementation of the dbt-semantic-interfaces `FileSlice` protocol
@@ -568,7 +568,7 @@ class FileSlice(dbtClassMixin, Replaceable):
@dataclass @dataclass
class SourceFileMetadata(dbtClassMixin, Replaceable): class SourceFileMetadata(dbtClassMixin):
"""Provides file context about what something was created from. """Provides file context about what something was created from.
Implementation of the dbt-semantic-interfaces `Metadata` protocol Implementation of the dbt-semantic-interfaces `Metadata` protocol
@@ -925,7 +925,7 @@ class SingularTestNode(TestShouldStoreFailures, CompiledNode):
@dataclass @dataclass
class TestMetadata(dbtClassMixin, Replaceable): class TestMetadata(dbtClassMixin):
name: str name: str
# kwargs are the args that are left in the test builder after # kwargs are the args that are left in the test builder after
# removing configs. They are set from the test builder when # removing configs. They are set from the test builder when
@@ -1366,7 +1366,7 @@ class MetricTypeParams(dbtClassMixin):
@dataclass @dataclass
class MetricReference(dbtClassMixin, Replaceable): class MetricReference(dbtClassMixin):
sql: Optional[Union[str, int]] = None sql: Optional[Union[str, int]] = None
unique_id: Optional[str] = None unique_id: Optional[str] = None
@@ -1507,7 +1507,7 @@ class SemanticModel(GraphNode):
@dataclass @dataclass
class ParsedPatch(HasYamlMetadata, Replaceable): class ParsedPatch(HasYamlMetadata):
name: str name: str
description: str description: str
meta: Dict[str, Any] meta: Dict[str, Any]

View File

@@ -6,7 +6,6 @@ from dbt.node_types import NodeType
from dbt.contracts.util import ( from dbt.contracts.util import (
AdditionalPropertiesMixin, AdditionalPropertiesMixin,
Mergeable, Mergeable,
Replaceable,
) )
# trigger the PathEncoder # trigger the PathEncoder
@@ -22,7 +21,7 @@ from typing import Optional, List, Union, Dict, Any, Sequence
@dataclass @dataclass
class UnparsedBaseNode(dbtClassMixin, Replaceable): class UnparsedBaseNode(dbtClassMixin):
package_name: str package_name: str
path: str path: str
original_file_path: str original_file_path: str
@@ -82,13 +81,13 @@ class UnparsedRunHook(UnparsedNode):
@dataclass @dataclass
class Docs(dbtClassMixin, Replaceable): class Docs(dbtClassMixin):
show: bool = True show: bool = True
node_color: Optional[str] = None node_color: Optional[str] = None
@dataclass @dataclass
class HasColumnProps(AdditionalPropertiesMixin, ExtensibleDbtClassMixin, Replaceable): class HasColumnProps(AdditionalPropertiesMixin, ExtensibleDbtClassMixin):
name: str name: str
description: str = "" description: str = ""
meta: Dict[str, Any] = field(default_factory=dict) meta: Dict[str, Any] = field(default_factory=dict)
@@ -113,12 +112,12 @@ class UnparsedColumn(HasColumnAndTestProps):
@dataclass @dataclass
class HasColumnDocs(dbtClassMixin, Replaceable): class HasColumnDocs(dbtClassMixin):
columns: Sequence[HasColumnProps] = field(default_factory=list) columns: Sequence[HasColumnProps] = field(default_factory=list)
@dataclass @dataclass
class HasColumnTests(dbtClassMixin, Replaceable): class HasColumnTests(dbtClassMixin):
columns: Sequence[UnparsedColumn] = field(default_factory=list) columns: Sequence[UnparsedColumn] = field(default_factory=list)
@@ -324,7 +323,7 @@ class AdditionalPropertiesAllowed(AdditionalPropertiesMixin, ExtensibleDbtClassM
@dataclass @dataclass
class ExternalPartition(AdditionalPropertiesAllowed, Replaceable): class ExternalPartition(AdditionalPropertiesAllowed):
name: str = "" name: str = ""
description: str = "" description: str = ""
data_type: str = "" data_type: str = ""
@@ -373,7 +372,7 @@ class UnparsedSourceTableDefinition(HasColumnTests, HasColumnAndTestProps):
@dataclass @dataclass
class UnparsedSourceDefinition(dbtClassMixin, Replaceable): class UnparsedSourceDefinition(dbtClassMixin):
name: str name: str
description: str = "" description: str = ""
meta: Dict[str, Any] = field(default_factory=dict) meta: Dict[str, Any] = field(default_factory=dict)
@@ -428,7 +427,7 @@ class SourceTablePatch(dbtClassMixin):
@dataclass @dataclass
class SourcePatch(dbtClassMixin, Replaceable): class SourcePatch(dbtClassMixin):
name: str = field( name: str = field(
metadata=dict(description="The name of the source to override"), metadata=dict(description="The name of the source to override"),
) )
@@ -471,7 +470,7 @@ class SourcePatch(dbtClassMixin, Replaceable):
@dataclass @dataclass
class UnparsedDocumentation(dbtClassMixin, Replaceable): class UnparsedDocumentation(dbtClassMixin):
package_name: str package_name: str
path: str path: str
original_file_path: str original_file_path: str
@@ -534,13 +533,13 @@ class MaturityType(StrEnum):
@dataclass @dataclass
class Owner(AdditionalPropertiesAllowed, Replaceable): class Owner(AdditionalPropertiesAllowed):
email: Optional[str] = None email: Optional[str] = None
name: Optional[str] = None name: Optional[str] = None
@dataclass @dataclass
class UnparsedExposure(dbtClassMixin, Replaceable): class UnparsedExposure(dbtClassMixin):
name: str name: str
type: ExposureType type: ExposureType
owner: Owner owner: Owner
@@ -566,7 +565,7 @@ class UnparsedExposure(dbtClassMixin, Replaceable):
@dataclass @dataclass
class MetricFilter(dbtClassMixin, Replaceable): class MetricFilter(dbtClassMixin):
field: str field: str
operator: str operator: str
# TODO : Can we make this Any? # TODO : Can we make this Any?
@@ -656,7 +655,7 @@ class UnparsedMetric(dbtClassMixin):
@dataclass @dataclass
class UnparsedGroup(dbtClassMixin, Replaceable): class UnparsedGroup(dbtClassMixin):
name: str name: str
owner: Owner owner: Owner

View File

@@ -1,4 +1,4 @@
from dbt.contracts.util import Replaceable, Mergeable, list_str, Identifier from dbt.contracts.util import Mergeable, list_str, Identifier
from dbt.contracts.connection import QueryComment, UserConfigContract from dbt.contracts.connection import QueryComment, UserConfigContract
from dbt.helper_types import NoValue from dbt.helper_types import NoValue
from dbt.dataclass_schema import ( from dbt.dataclass_schema import (
@@ -42,7 +42,7 @@ class Quoting(dbtClassMixin, Mergeable):
@dataclass @dataclass
class Package(Replaceable, HyphenatedDbtClassMixin): class Package(HyphenatedDbtClassMixin):
pass pass
@@ -92,7 +92,7 @@ PackageSpec = Union[LocalPackage, TarballPackage, GitPackage, RegistryPackage]
@dataclass @dataclass
class PackageConfig(dbtClassMixin, Replaceable): class PackageConfig(dbtClassMixin):
packages: List[PackageSpec] packages: List[PackageSpec]
@classmethod @classmethod
@@ -124,7 +124,7 @@ class ProjectPackageMetadata:
@dataclass @dataclass
class Downloads(ExtensibleDbtClassMixin, Replaceable): class Downloads(ExtensibleDbtClassMixin):
tarball: str tarball: str
@@ -182,7 +182,7 @@ BANNED_PROJECT_NAMES = {
@dataclass @dataclass
class Project(HyphenatedDbtClassMixin, Replaceable): class Project(HyphenatedDbtClassMixin):
name: Identifier name: Identifier
config_version: Optional[int] = 2 config_version: Optional[int] = 2
version: Optional[Union[SemverString, float]] = None version: Optional[Union[SemverString, float]] = None
@@ -242,7 +242,7 @@ class Project(HyphenatedDbtClassMixin, Replaceable):
@dataclass @dataclass
class UserConfig(ExtensibleDbtClassMixin, Replaceable, UserConfigContract): class UserConfig(ExtensibleDbtClassMixin, UserConfigContract):
cache_selected_only: Optional[bool] = None cache_selected_only: Optional[bool] = None
debug: Optional[bool] = None debug: Optional[bool] = None
fail_fast: Optional[bool] = None fail_fast: Optional[bool] = None
@@ -266,7 +266,7 @@ class UserConfig(ExtensibleDbtClassMixin, Replaceable, UserConfigContract):
@dataclass @dataclass
class ProfileConfig(HyphenatedDbtClassMixin, Replaceable): class ProfileConfig(HyphenatedDbtClassMixin):
profile_name: str = field(metadata={"preserve_underscore": True}) profile_name: str = field(metadata={"preserve_underscore": True})
target_name: str = field(metadata={"preserve_underscore": True}) target_name: str = field(metadata={"preserve_underscore": True})
user_config: UserConfig = field(metadata={"preserve_underscore": True}) user_config: UserConfig = field(metadata={"preserve_underscore": True})
@@ -276,7 +276,7 @@ class ProfileConfig(HyphenatedDbtClassMixin, Replaceable):
@dataclass @dataclass
class ConfiguredQuoting(Quoting, Replaceable): class ConfiguredQuoting(Quoting):
identifier: bool = True identifier: bool = True
schema: bool = True schema: bool = True
database: Optional[bool] = None database: Optional[bool] = None

View File

@@ -1,5 +1,5 @@
from collections.abc import Mapping from collections.abc import Mapping
from dataclasses import dataclass from dataclasses import dataclass, replace
from typing import ( from typing import (
Optional, Optional,
Dict, Dict,
@@ -8,7 +8,6 @@ from typing_extensions import Protocol
from dbt.dataclass_schema import dbtClassMixin, StrEnum from dbt.dataclass_schema import dbtClassMixin, StrEnum
from dbt.contracts.util import Replaceable
from dbt.exceptions import CompilationError, DataclassNotDictError from dbt.exceptions import CompilationError, DataclassNotDictError
from dbt.utils import deep_merge from dbt.utils import deep_merge
@@ -31,7 +30,7 @@ class HasQuoting(Protocol):
quoting: Dict[str, bool] quoting: Dict[str, bool]
class FakeAPIObject(dbtClassMixin, Replaceable, Mapping): class FakeAPIObject(dbtClassMixin, Mapping):
# override the mapping truthiness, len is always >1 # override the mapping truthiness, len is always >1
def __bool__(self): def __bool__(self):
return True return True
@@ -76,7 +75,7 @@ class Policy(FakeAPIObject):
kwargs: Dict[str, bool] = {} kwargs: Dict[str, bool] = {}
for k, v in dct.items(): for k, v in dct.items():
kwargs[str(k)] = v kwargs[str(k)] = v
return self.replace(**kwargs) return replace(self, **kwargs)
@dataclass @dataclass
@@ -116,4 +115,4 @@ class Path(FakeAPIObject):
kwargs: Dict[str, str] = {} kwargs: Dict[str, str] = {}
for k, v in dct.items(): for k, v in dct.items():
kwargs[str(k)] = v kwargs[str(k)] = v
return self.replace(**kwargs) return replace(self, **kwargs)

View File

@@ -4,7 +4,6 @@ from dbt.contracts.util import (
BaseArtifactMetadata, BaseArtifactMetadata,
ArtifactMixin, ArtifactMixin,
VersionedSchema, VersionedSchema,
Replaceable,
schema_version, schema_version,
) )
from dbt.exceptions import DbtInternalError from dbt.exceptions import DbtInternalError
@@ -425,7 +424,7 @@ class TableMetadata(dbtClassMixin):
@dataclass @dataclass
class CatalogTable(dbtClassMixin, Replaceable): class CatalogTable(dbtClassMixin):
metadata: TableMetadata metadata: TableMetadata
columns: ColumnMap columns: ColumnMap
stats: StatsDict stats: StatsDict

View File

@@ -44,7 +44,7 @@ class Replaceable:
return dataclasses.replace(self, **kwargs) return dataclasses.replace(self, **kwargs)
class Mergeable(Replaceable): class Mergeable:
def merged(self, *args): def merged(self, *args):
"""Perform a shallow merge, where the last non-None write wins. This is """Perform a shallow merge, where the last non-None write wins. This is
intended to merge dataclasses that are a collection of optional values. intended to merge dataclasses that are a collection of optional values.
@@ -57,7 +57,7 @@ class Mergeable(Replaceable):
if value is not None: if value is not None:
replacements[field.name] = value replacements[field.name] = value
return self.replace(**replacements) return dataclasses.replace(self, **replacements)
class Writable: class Writable:

View File

@@ -1,6 +1,7 @@
import itertools import itertools
from pathlib import Path from pathlib import Path
from typing import Iterable, Dict, Optional, Set, Any, List from typing import Iterable, Dict, Optional, Set, Any, List
from dataclasses import replace
from dbt.adapters.factory import get_adapter from dbt.adapters.factory import get_adapter
from dbt.config import RuntimeConfig from dbt.config import RuntimeConfig
from dbt.context.context_config import ( from dbt.context.context_config import (
@@ -117,7 +118,7 @@ class SourcePatcher:
source = UnparsedSourceDefinition.from_dict(source_dct) source = UnparsedSourceDefinition.from_dict(source_dct)
table = UnparsedSourceTableDefinition.from_dict(table_dct) table = UnparsedSourceTableDefinition.from_dict(table_dct)
return unpatched.replace(source=source, table=table, patch_path=patch_path) return replace(unpatched, source=source, table=table, patch_path=patch_path)
# This converts an UnpatchedSourceDefinition to a SourceDefinition # This converts an UnpatchedSourceDefinition to a SourceDefinition
def parse_source(self, target: UnpatchedSourceDefinition) -> SourceDefinition: def parse_source(self, target: UnpatchedSourceDefinition) -> SourceDefinition:

View File

@@ -3,6 +3,7 @@ import shutil
from datetime import datetime from datetime import datetime
from typing import Dict, List, Any, Optional, Tuple, Set from typing import Dict, List, Any, Optional, Tuple, Set
import agate import agate
from dataclasses import replace
from dbt.dataclass_schema import ValidationError from dbt.dataclass_schema import ValidationError
@@ -115,7 +116,7 @@ class Catalog(Dict[CatalogKey, CatalogTable]):
key = table.key() key = table.key()
if key in node_map: if key in node_map:
unique_id = node_map[key] unique_id = node_map[key]
nodes[unique_id] = table.replace(unique_id=unique_id) nodes[unique_id] = replace(table, unique_id=unique_id)
unique_ids = source_map.get(table.key(), set()) unique_ids = source_map.get(table.key(), set())
for unique_id in unique_ids: for unique_id in unique_ids:
@@ -126,7 +127,7 @@ class Catalog(Dict[CatalogKey, CatalogTable]):
table.to_dict(omit_none=True), table.to_dict(omit_none=True),
) )
else: else:
sources[unique_id] = table.replace(unique_id=unique_id) sources[unique_id] = replace(table, unique_id=unique_id)
return nodes, sources return nodes, sources

View File

@@ -1,5 +1,6 @@
import pickle import pickle
import pytest import pytest
from dataclasses import replace
from dbt.contracts.files import FileHash from dbt.contracts.files import FileHash
from dbt.contracts.graph.nodes import ( from dbt.contracts.graph.nodes import (
@@ -263,44 +264,46 @@ def test_invalid_bad_type_model(minimal_uncompiled_dict):
unchanged_compiled_models = [ unchanged_compiled_models = [
lambda u: (u, u.replace(description="a description")), lambda u: (u, replace(u, description="a description")),
lambda u: (u, u.replace(tags=["mytag"])), lambda u: (u, replace(u, tags=["mytag"])),
lambda u: (u, u.replace(meta={"cool_key": "cool value"})), lambda u: (u, replace(u, meta={"cool_key": "cool value"})),
# changing the final alias/schema/datbase isn't a change - could just be target changing! # changing the final alias/schema/datbase isn't a change - could just be target changing!
lambda u: (u, u.replace(database="nope")), lambda u: (u, replace(u, database="nope")),
lambda u: (u, u.replace(schema="nope")), lambda u: (u, replace(u, schema="nope")),
lambda u: (u, u.replace(alias="nope")), lambda u: (u, replace(u, alias="nope")),
# None -> False is a config change even though it's pretty much the same # None -> False is a config change even though it's pretty much the same
lambda u: ( lambda u: (
u.replace(config=u.config.replace(persist_docs={"relation": False})), replace(u, config=replace(u.config, persist_docs={"relation": False})),
u.replace(config=u.config.replace(persist_docs={"relation": False})), replace(u, config=replace(u.config, persist_docs={"relation": False})),
), ),
lambda u: ( lambda u: (
u.replace(config=u.config.replace(persist_docs={"columns": False})), replace(u, config=replace(u.config, persist_docs={"columns": False})),
u.replace(config=u.config.replace(persist_docs={"columns": False})), replace(u, config=replace(u.config, persist_docs={"columns": False})),
), ),
# True -> True # True -> True
lambda u: ( lambda u: (
u.replace(config=u.config.replace(persist_docs={"relation": True})), replace(u, config=replace(u.config, persist_docs={"relation": True})),
u.replace(config=u.config.replace(persist_docs={"relation": True})), replace(u, config=replace(u.config, persist_docs={"relation": True})),
), ),
lambda u: ( lambda u: (
u.replace(config=u.config.replace(persist_docs={"columns": True})), replace(u, config=replace(u.config, persist_docs={"columns": True})),
u.replace(config=u.config.replace(persist_docs={"columns": True})), replace(u, config=replace(u.config, persist_docs={"columns": True})),
), ),
# only columns docs enabled, but description changed # only columns docs enabled, but description changed
lambda u: ( lambda u: (
u.replace(config=u.config.replace(persist_docs={"columns": True})), replace(u, config=replace(u.config, persist_docs={"columns": True})),
u.replace( replace(
config=u.config.replace(persist_docs={"columns": True}), u,
config=replace(u.config, persist_docs={"columns": True}),
description="a model description", description="a model description",
), ),
), ),
# only relation docs eanbled, but columns changed # only relation docs eanbled, but columns changed
lambda u: ( lambda u: (
u.replace(config=u.config.replace(persist_docs={"relation": True})), replace(u, config=replace(u.config, persist_docs={"relation": True})),
u.replace( replace(
config=u.config.replace(persist_docs={"relation": True}), u,
config=replace(u.config, persist_docs={"relation": True}),
columns={"a": ColumnInfo(name="a", description="a column description")}, columns={"a": ColumnInfo(name="a", description="a column description")},
), ),
), ),
@@ -309,10 +312,11 @@ unchanged_compiled_models = [
changed_compiled_models = [ changed_compiled_models = [
lambda u: (u, None), lambda u: (u, None),
lambda u: (u, u.replace(raw_code="select * from wherever")), lambda u: (u, replace(u, raw_code="select * from wherever")),
lambda u: ( lambda u: (
u, u,
u.replace( replace(
u,
fqn=["test", "models", "subdir", "foo"], fqn=["test", "models", "subdir", "foo"],
original_file_path="models/subdir/foo.sql", original_file_path="models/subdir/foo.sql",
path="/root/models/subdir/foo.sql", path="/root/models/subdir/foo.sql",
@@ -616,17 +620,17 @@ def test_invalid_resource_type_schema_test(minimal_schema_test_dict):
unchanged_schema_tests = [ unchanged_schema_tests = [
# for tests, raw_code isn't a change (because it's always the same for a given test macro) # for tests, raw_code isn't a change (because it's always the same for a given test macro)
lambda u: u.replace(raw_code="select * from wherever"), lambda u: replace(u, raw_code="select * from wherever"),
lambda u: u.replace(description="a description"), lambda u: replace(u, description="a description"),
lambda u: u.replace(tags=["mytag"]), lambda u: replace(u, tags=["mytag"]),
lambda u: u.replace(meta={"cool_key": "cool value"}), lambda u: replace(u, meta={"cool_key": "cool value"}),
# these values don't even mean anything on schema tests! # these values don't even mean anything on schema tests!
lambda u: replace_config(u, alias="nope"), lambda u: replace_config(u, alias="nope"),
lambda u: replace_config(u, database="nope"), lambda u: replace_config(u, database="nope"),
lambda u: replace_config(u, schema="nope"), lambda u: replace_config(u, schema="nope"),
lambda u: u.replace(database="other_db"), lambda u: replace(u, database="other_db"),
lambda u: u.replace(schema="other_schema"), lambda u: replace(u, schema="other_schema"),
lambda u: u.replace(alias="foo"), lambda u: replace(u, alias="foo"),
lambda u: replace_config(u, full_refresh=True), lambda u: replace_config(u, full_refresh=True),
lambda u: replace_config(u, post_hook=["select 1 as id"]), lambda u: replace_config(u, post_hook=["select 1 as id"]),
lambda u: replace_config(u, pre_hook=["select 1 as id"]), lambda u: replace_config(u, pre_hook=["select 1 as id"]),
@@ -636,7 +640,8 @@ unchanged_schema_tests = [
changed_schema_tests = [ changed_schema_tests = [
lambda u: None, lambda u: None,
lambda u: u.replace( lambda u: replace(
u,
fqn=["test", "models", "subdir", "foo"], fqn=["test", "models", "subdir", "foo"],
original_file_path="models/subdir/foo.sql", original_file_path="models/subdir/foo.sql",
path="/root/models/subdir/foo.sql", path="/root/models/subdir/foo.sql",
@@ -666,8 +671,8 @@ def test_compare_to_compiled(basic_uncompiled_schema_test_node, basic_compiled_s
uncompiled = basic_uncompiled_schema_test_node uncompiled = basic_uncompiled_schema_test_node
compiled = basic_compiled_schema_test_node compiled = basic_compiled_schema_test_node
assert not uncompiled.same_contents(compiled, "postgres") assert not uncompiled.same_contents(compiled, "postgres")
fixed_config = compiled.config.replace(severity=uncompiled.config.severity) fixed_config = replace(compiled.config, severity=uncompiled.config.severity)
fixed_compiled = compiled.replace( fixed_compiled = replace(
config=fixed_config, unrendered_config=uncompiled.unrendered_config compiled, config=fixed_config, unrendered_config=uncompiled.unrendered_config
) )
assert uncompiled.same_contents(fixed_compiled, "postgres") assert uncompiled.same_contents(fixed_compiled, "postgres")

View File

@@ -1,5 +1,6 @@
import pickle import pickle
import pytest import pytest
from dataclasses import replace
from dbt.node_types import NodeType, AccessType from dbt.node_types import NodeType, AccessType
from dbt.contracts.files import FileHash from dbt.contracts.files import FileHash
@@ -423,8 +424,8 @@ def test_invalid_bad_materialized(base_parsed_model_dict):
unchanged_nodes = [ unchanged_nodes = [
lambda u: (u, u.replace(tags=["mytag"])), lambda u: (u, replace(u, tags=["mytag"])),
lambda u: (u, u.replace(meta={"something": 1000})), lambda u: (u, replace(u, meta={"something": 1000})),
# True -> True # True -> True
lambda u: ( lambda u: (
replace_config(u, persist_docs={"relation": True}), replace_config(u, persist_docs={"relation": True}),
@@ -437,28 +438,30 @@ unchanged_nodes = [
# only columns docs enabled, but description changed # only columns docs enabled, but description changed
lambda u: ( lambda u: (
replace_config(u, persist_docs={"columns": True}), replace_config(u, persist_docs={"columns": True}),
replace_config(u, persist_docs={"columns": True}).replace( replace(
description="a model description" replace_config(u, persist_docs={"columns": True}), description="a model description"
), ),
), ),
# only relation docs eanbled, but columns changed # only relation docs eanbled, but columns changed
lambda u: ( lambda u: (
replace_config(u, persist_docs={"relation": True}), replace_config(u, persist_docs={"relation": True}),
replace_config(u, persist_docs={"relation": True}).replace( replace(
columns={"a": ColumnInfo(name="a", description="a column description")} replace_config(u, persist_docs={"relation": True}),
columns={"a": ColumnInfo(name="a", description="a column description")},
), ),
), ),
# not tracked, we track config.alias/config.schema/config.database # not tracked, we track config.alias/config.schema/config.database
lambda u: (u, u.replace(alias="other")), lambda u: (u, replace(u, alias="other")),
lambda u: (u, u.replace(schema="other")), lambda u: (u, replace(u, schema="other")),
lambda u: (u, u.replace(database="other")), lambda u: (u, replace(u, database="other")),
] ]
changed_nodes = [ changed_nodes = [
lambda u: ( lambda u: (
u, u,
u.replace( replace(
u,
fqn=["test", "models", "subdir", "foo"], fqn=["test", "models", "subdir", "foo"],
original_file_path="models/subdir/foo.sql", original_file_path="models/subdir/foo.sql",
path="/root/models/subdir/foo.sql", path="/root/models/subdir/foo.sql",
@@ -470,15 +473,16 @@ changed_nodes = [
# persist docs was true for the relation and we changed the model description # persist docs was true for the relation and we changed the model description
lambda u: ( lambda u: (
replace_config(u, persist_docs={"relation": True}), replace_config(u, persist_docs={"relation": True}),
replace_config(u, persist_docs={"relation": True}).replace( replace(
description="a model description" replace_config(u, persist_docs={"relation": True}), description="a model description"
), ),
), ),
# persist docs was true for columns and we changed the model description # persist docs was true for columns and we changed the model description
lambda u: ( lambda u: (
replace_config(u, persist_docs={"columns": True}), replace_config(u, persist_docs={"columns": True}),
replace_config(u, persist_docs={"columns": True}).replace( replace(
columns={"a": ColumnInfo(name="a", description="a column description")} replace_config(u, persist_docs={"columns": True}),
columns={"a": ColumnInfo(name="a", description="a column description")},
), ),
), ),
# not tracked, we track config.alias/config.schema/config.database # not tracked, we track config.alias/config.schema/config.database
@@ -696,8 +700,8 @@ def test_seed_complex(complex_parsed_seed_dict, complex_parsed_seed_object):
unchanged_seeds = [ unchanged_seeds = [
lambda u: (u, u.replace(tags=["mytag"])), lambda u: (u, replace(u, tags=["mytag"])),
lambda u: (u, u.replace(meta={"something": 1000})), lambda u: (u, replace(u, meta={"something": 1000})),
# True -> True # True -> True
lambda u: ( lambda u: (
replace_config(u, persist_docs={"relation": True}), replace_config(u, persist_docs={"relation": True}),
@@ -710,27 +714,29 @@ unchanged_seeds = [
# only columns docs enabled, but description changed # only columns docs enabled, but description changed
lambda u: ( lambda u: (
replace_config(u, persist_docs={"columns": True}), replace_config(u, persist_docs={"columns": True}),
replace_config(u, persist_docs={"columns": True}).replace( replace(
description="a model description" replace_config(u, persist_docs={"columns": True}), description="a model description"
), ),
), ),
# only relation docs eanbled, but columns changed # only relation docs eanbled, but columns changed
lambda u: ( lambda u: (
replace_config(u, persist_docs={"relation": True}), replace_config(u, persist_docs={"relation": True}),
replace_config(u, persist_docs={"relation": True}).replace( replace(
columns={"a": ColumnInfo(name="a", description="a column description")} replace_config(u, persist_docs={"relation": True}),
columns={"a": ColumnInfo(name="a", description="a column description")},
), ),
), ),
lambda u: (u, u.replace(alias="other")), lambda u: (u, replace(u, alias="other")),
lambda u: (u, u.replace(schema="other")), lambda u: (u, replace(u, schema="other")),
lambda u: (u, u.replace(database="other")), lambda u: (u, replace(u, database="other")),
] ]
changed_seeds = [ changed_seeds = [
lambda u: ( lambda u: (
u, u,
u.replace( replace(
u,
fqn=["test", "models", "subdir", "foo"], fqn=["test", "models", "subdir", "foo"],
original_file_path="models/subdir/foo.sql", original_file_path="models/subdir/foo.sql",
path="/root/models/subdir/foo.sql", path="/root/models/subdir/foo.sql",
@@ -742,15 +748,16 @@ changed_seeds = [
# persist docs was true for the relation and we changed the model description # persist docs was true for the relation and we changed the model description
lambda u: ( lambda u: (
replace_config(u, persist_docs={"relation": True}), replace_config(u, persist_docs={"relation": True}),
replace_config(u, persist_docs={"relation": True}).replace( replace(
description="a model description" replace_config(u, persist_docs={"relation": True}), description="a model description"
), ),
), ),
# persist docs was true for columns and we changed the model description # persist docs was true for columns and we changed the model description
lambda u: ( lambda u: (
replace_config(u, persist_docs={"columns": True}), replace_config(u, persist_docs={"columns": True}),
replace_config(u, persist_docs={"columns": True}).replace( replace(
columns={"a": ColumnInfo(name="a", description="a column description")} replace_config(u, persist_docs={"columns": True}),
columns={"a": ColumnInfo(name="a", description="a column description")},
), ),
), ),
lambda u: (u, replace_config(u, alias="other")), lambda u: (u, replace_config(u, alias="other")),
@@ -2104,27 +2111,30 @@ def test_source_no_freshness(complex_parsed_source_definition_object):
unchanged_source_definitions = [ unchanged_source_definitions = [
lambda u: (u, u.replace(tags=["mytag"])), lambda u: (u, replace(u, tags=["mytag"])),
lambda u: (u, u.replace(meta={"a": 1000})), lambda u: (u, replace(u, meta={"a": 1000})),
] ]
changed_source_definitions = [ changed_source_definitions = [
lambda u: ( lambda u: (
u, u,
u.replace( replace(
u,
freshness=FreshnessThreshold(warn_after=Time(period=TimePeriod.hour, count=1)), freshness=FreshnessThreshold(warn_after=Time(period=TimePeriod.hour, count=1)),
loaded_at_field="loaded_at", loaded_at_field="loaded_at",
), ),
), ),
lambda u: (u, u.replace(loaded_at_field="loaded_at")), lambda u: (u, replace(u, loaded_at_field="loaded_at")),
lambda u: ( lambda u: (
u, u,
u.replace(freshness=FreshnessThreshold(error_after=Time(period=TimePeriod.hour, count=1))), replace(
u, freshness=FreshnessThreshold(error_after=Time(period=TimePeriod.hour, count=1))
),
), ),
lambda u: (u, u.replace(quoting=Quoting(identifier=True))), lambda u: (u, replace(u, quoting=Quoting(identifier=True))),
lambda u: (u, u.replace(database="other_database")), lambda u: (u, replace(u, database="other_database")),
lambda u: (u, u.replace(schema="other_schema")), lambda u: (u, replace(u, schema="other_schema")),
lambda u: (u, u.replace(identifier="identifier")), lambda u: (u, replace(u, identifier="identifier")),
] ]
@@ -2291,13 +2301,13 @@ unchanged_parsed_exposures = [
changed_parsed_exposures = [ changed_parsed_exposures = [
lambda u: (u, u.replace(fqn=u.fqn[:-1] + ["something", u.fqn[-1]])), lambda u: (u, replace(u, fqn=u.fqn[:-1] + ["something", u.fqn[-1]])),
lambda u: (u, u.replace(type=ExposureType.ML)), lambda u: (u, replace(u, type=ExposureType.ML)),
lambda u: (u, u.replace(owner=u.owner.replace(name="My Name"))), lambda u: (u, replace(u, owner=replace(u.owner, name="My Name"))),
lambda u: (u, u.replace(maturity=MaturityType.Medium)), lambda u: (u, replace(u, maturity=MaturityType.Medium)),
lambda u: (u, u.replace(url="https://example.com/dashboard/1")), lambda u: (u, replace(u, url="https://example.com/dashboard/1")),
lambda u: (u, u.replace(description="My description")), lambda u: (u, replace(u, description="My description")),
lambda u: (u, u.replace(depends_on=DependsOn(nodes=["model.test.blah"]))), lambda u: (u, replace(u, depends_on=DependsOn(nodes=["model.test.blah"]))),
] ]

View File

@@ -2,6 +2,7 @@ import copy
import pytest import pytest
from unittest import mock from unittest import mock
from dataclasses import replace
from pathlib import Path from pathlib import Path
@@ -902,9 +903,7 @@ def test_select_group(manifest, view_model):
manifest.groups[group.unique_id] = group manifest.groups[group.unique_id] = group
change_node( change_node(
manifest, manifest,
view_model.replace( replace(view_model, config={"materialized": "view", "group": group_name}),
config={"materialized": "view", "group": group_name},
),
) )
methods = MethodManager(manifest, None) methods = MethodManager(manifest, None)
method = methods.get_method("group", []) method = methods.get_method("group", [])
@@ -918,9 +917,7 @@ def test_select_group(manifest, view_model):
def test_select_access(manifest, view_model): def test_select_access(manifest, view_model):
change_node( change_node(
manifest, manifest,
view_model.replace( replace(view_model, access="public"),
access="public",
),
) )
methods = MethodManager(manifest, None) methods = MethodManager(manifest, None)
method = methods.get_method("access", []) method = methods.get_method("access", [])
@@ -1274,7 +1271,7 @@ def test_select_state_added_model(manifest, previous_state):
def test_select_state_changed_model_sql(manifest, previous_state, view_model): def test_select_state_changed_model_sql(manifest, previous_state, view_model):
change_node(manifest, view_model.replace(raw_code="select 1 as id")) change_node(manifest, replace(view_model, raw_code="select 1 as id"))
method = statemethod(manifest, previous_state) method = statemethod(manifest, previous_state)
# both of these # both of these
@@ -1291,7 +1288,7 @@ def test_select_state_changed_model_sql(manifest, previous_state, view_model):
def test_select_state_changed_model_fqn(manifest, previous_state, view_model): def test_select_state_changed_model_fqn(manifest, previous_state, view_model):
change_node( change_node(
manifest, view_model.replace(fqn=view_model.fqn[:-1] + ["nested"] + view_model.fqn[-1:]) manifest, replace(view_model, fqn=view_model.fqn[:-1] + ["nested"] + view_model.fqn[-1:])
) )
method = statemethod(manifest, previous_state) method = statemethod(manifest, previous_state)
assert search_manifest_using_method(manifest, method, "modified") == {"view_model"} assert search_manifest_using_method(manifest, method, "modified") == {"view_model"}
@@ -1306,7 +1303,7 @@ def test_select_state_added_seed(manifest, previous_state):
def test_select_state_changed_seed_checksum_sha_to_sha(manifest, previous_state, seed): def test_select_state_changed_seed_checksum_sha_to_sha(manifest, previous_state, seed):
change_node(manifest, seed.replace(checksum=FileHash.from_contents("changed"))) change_node(manifest, replace(seed, checksum=FileHash.from_contents("changed")))
method = statemethod(manifest, previous_state) method = statemethod(manifest, previous_state)
assert search_manifest_using_method(manifest, method, "modified") == {"seed"} assert search_manifest_using_method(manifest, method, "modified") == {"seed"}
assert not search_manifest_using_method(manifest, method, "new") assert not search_manifest_using_method(manifest, method, "new")
@@ -1315,10 +1312,10 @@ def test_select_state_changed_seed_checksum_sha_to_sha(manifest, previous_state,
def test_select_state_changed_seed_checksum_path_to_path(manifest, previous_state, seed): def test_select_state_changed_seed_checksum_path_to_path(manifest, previous_state, seed):
change_node( change_node(
previous_state.manifest, previous_state.manifest,
seed.replace(checksum=FileHash(name="path", checksum=seed.original_file_path)), replace(seed, checksum=FileHash(name="path", checksum=seed.original_file_path)),
) )
change_node( change_node(
manifest, seed.replace(checksum=FileHash(name="path", checksum=seed.original_file_path)) manifest, replace(seed, checksum=FileHash(name="path", checksum=seed.original_file_path))
) )
method = statemethod(manifest, previous_state) method = statemethod(manifest, previous_state)
with mock.patch("dbt.contracts.graph.nodes.warn_or_error") as warn_or_error_patch: with mock.patch("dbt.contracts.graph.nodes.warn_or_error") as warn_or_error_patch:
@@ -1335,7 +1332,7 @@ def test_select_state_changed_seed_checksum_path_to_path(manifest, previous_stat
def test_select_state_changed_seed_checksum_sha_to_path(manifest, previous_state, seed): def test_select_state_changed_seed_checksum_sha_to_path(manifest, previous_state, seed):
change_node( change_node(
manifest, seed.replace(checksum=FileHash(name="path", checksum=seed.original_file_path)) manifest, replace(seed, checksum=FileHash(name="path", checksum=seed.original_file_path))
) )
method = statemethod(manifest, previous_state) method = statemethod(manifest, previous_state)
with mock.patch("dbt.contracts.graph.nodes.warn_or_error") as warn_or_error_patch: with mock.patch("dbt.contracts.graph.nodes.warn_or_error") as warn_or_error_patch:
@@ -1353,7 +1350,7 @@ def test_select_state_changed_seed_checksum_sha_to_path(manifest, previous_state
def test_select_state_changed_seed_checksum_path_to_sha(manifest, previous_state, seed): def test_select_state_changed_seed_checksum_path_to_sha(manifest, previous_state, seed):
change_node( change_node(
previous_state.manifest, previous_state.manifest,
seed.replace(checksum=FileHash(name="path", checksum=seed.original_file_path)), replace(seed, checksum=FileHash(name="path", checksum=seed.original_file_path)),
) )
method = statemethod(manifest, previous_state) method = statemethod(manifest, previous_state)
with mock.patch("dbt.contracts.graph.nodes.warn_or_error") as warn_or_error_patch: with mock.patch("dbt.contracts.graph.nodes.warn_or_error") as warn_or_error_patch:
@@ -1365,7 +1362,7 @@ def test_select_state_changed_seed_checksum_path_to_sha(manifest, previous_state
def test_select_state_changed_seed_fqn(manifest, previous_state, seed): def test_select_state_changed_seed_fqn(manifest, previous_state, seed):
change_node(manifest, seed.replace(fqn=seed.fqn[:-1] + ["nested"] + seed.fqn[-1:])) change_node(manifest, replace(seed, fqn=seed.fqn[:-1] + ["nested"] + seed.fqn[-1:]))
method = statemethod(manifest, previous_state) method = statemethod(manifest, previous_state)
assert search_manifest_using_method(manifest, method, "modified") == {"seed"} assert search_manifest_using_method(manifest, method, "modified") == {"seed"}
assert not search_manifest_using_method(manifest, method, "new") assert not search_manifest_using_method(manifest, method, "new")
@@ -1384,7 +1381,7 @@ def test_select_state_changed_seed_relation_documented(manifest, previous_state,
def test_select_state_changed_seed_relation_documented_nodocs(manifest, previous_state, seed): def test_select_state_changed_seed_relation_documented_nodocs(manifest, previous_state, seed):
seed_doc_relation = replace_config(seed, persist_docs={"relation": True}) seed_doc_relation = replace_config(seed, persist_docs={"relation": True})
seed_doc_relation_documented = seed_doc_relation.replace(description="a description") seed_doc_relation_documented = replace(seed_doc_relation, description="a description")
change_node(previous_state.manifest, seed_doc_relation) change_node(previous_state.manifest, seed_doc_relation)
change_node(manifest, seed_doc_relation_documented) change_node(manifest, seed_doc_relation_documented)
method = statemethod(manifest, previous_state) method = statemethod(manifest, previous_state)
@@ -1398,7 +1395,7 @@ def test_select_state_changed_seed_relation_documented_nodocs(manifest, previous
def test_select_state_changed_seed_relation_documented_withdocs(manifest, previous_state, seed): def test_select_state_changed_seed_relation_documented_withdocs(manifest, previous_state, seed):
seed_doc_relation = replace_config(seed, persist_docs={"relation": True}) seed_doc_relation = replace_config(seed, persist_docs={"relation": True})
seed_doc_relation_documented = seed_doc_relation.replace(description="a description") seed_doc_relation_documented = replace(seed_doc_relation, description="a description")
change_node(previous_state.manifest, seed_doc_relation_documented) change_node(previous_state.manifest, seed_doc_relation_documented)
change_node(manifest, seed_doc_relation) change_node(manifest, seed_doc_relation)
method = statemethod(manifest, previous_state) method = statemethod(manifest, previous_state)
@@ -1422,7 +1419,8 @@ def test_select_state_changed_seed_columns_documented(manifest, previous_state,
def test_select_state_changed_seed_columns_documented_nodocs(manifest, previous_state, seed): def test_select_state_changed_seed_columns_documented_nodocs(manifest, previous_state, seed):
seed_doc_columns = replace_config(seed, persist_docs={"columns": True}) seed_doc_columns = replace_config(seed, persist_docs={"columns": True})
seed_doc_columns_documented_columns = seed_doc_columns.replace( seed_doc_columns_documented_columns = replace(
seed_doc_columns,
columns={"a": ColumnInfo(name="a", description="a description")}, columns={"a": ColumnInfo(name="a", description="a description")},
) )
@@ -1440,7 +1438,8 @@ def test_select_state_changed_seed_columns_documented_nodocs(manifest, previous_
def test_select_state_changed_seed_columns_documented_withdocs(manifest, previous_state, seed): def test_select_state_changed_seed_columns_documented_withdocs(manifest, previous_state, seed):
seed_doc_columns = replace_config(seed, persist_docs={"columns": True}) seed_doc_columns = replace_config(seed, persist_docs={"columns": True})
seed_doc_columns_documented_columns = seed_doc_columns.replace( seed_doc_columns_documented_columns = replace(
seed_doc_columns,
columns={"a": ColumnInfo(name="a", description="a description")}, columns={"a": ColumnInfo(name="a", description="a description")},
) )
@@ -1459,8 +1458,8 @@ def test_select_state_changed_seed_columns_documented_withdocs(manifest, previou
def test_select_state_changed_test_macro_sql( def test_select_state_changed_test_macro_sql(
manifest, previous_state, macro_default_test_not_null manifest, previous_state, macro_default_test_not_null
): ):
manifest.macros[macro_default_test_not_null.unique_id] = macro_default_test_not_null.replace( manifest.macros[macro_default_test_not_null.unique_id] = replace(
macro_sql="lalala" macro_default_test_not_null, macro_sql="lalala"
) )
method = statemethod(manifest, previous_state) method = statemethod(manifest, previous_state)
assert search_manifest_using_method(manifest, method, "modified") == { assert search_manifest_using_method(manifest, method, "modified") == {
@@ -1475,7 +1474,7 @@ def test_select_state_changed_test_macro_sql(
def test_select_state_changed_test_macros(manifest, previous_state): def test_select_state_changed_test_macros(manifest, previous_state):
changed_macro = make_macro("dbt", "changed_macro", "blablabla") changed_macro = make_macro("dbt", "changed_macro", "blablabla")
add_macro(manifest, changed_macro) add_macro(manifest, changed_macro)
add_macro(previous_state.manifest, changed_macro.replace(macro_sql="something different")) add_macro(previous_state.manifest, replace(changed_macro, macro_sql="something different"))
unchanged_macro = make_macro("dbt", "unchanged_macro", "blablabla") unchanged_macro = make_macro("dbt", "unchanged_macro", "blablabla")
add_macro(manifest, unchanged_macro) add_macro(manifest, unchanged_macro)
@@ -1512,7 +1511,7 @@ def test_select_state_changed_test_macros(manifest, previous_state):
def test_select_state_changed_test_macros_with_upstream_change(manifest, previous_state): def test_select_state_changed_test_macros_with_upstream_change(manifest, previous_state):
changed_macro = make_macro("dbt", "changed_macro", "blablabla") changed_macro = make_macro("dbt", "changed_macro", "blablabla")
add_macro(manifest, changed_macro) add_macro(manifest, changed_macro)
add_macro(previous_state.manifest, changed_macro.replace(macro_sql="something different")) add_macro(previous_state.manifest, replace(changed_macro, macro_sql="something different"))
unchanged_macro1 = make_macro("dbt", "unchanged_macro", "blablabla") unchanged_macro1 = make_macro("dbt", "unchanged_macro", "blablabla")
add_macro(manifest, unchanged_macro1) add_macro(manifest, unchanged_macro1)

View File

@@ -2,6 +2,7 @@ import agate
import decimal import decimal
import unittest import unittest
from unittest import mock from unittest import mock
from dataclasses import replace
from dbt.task.debug import DebugTask from dbt.task.debug import DebugTask
@@ -130,7 +131,7 @@ class TestPostgresAdapter(unittest.TestCase):
@mock.patch("dbt.adapters.postgres.connections.psycopg2") @mock.patch("dbt.adapters.postgres.connections.psycopg2")
def test_changed_connect_timeout(self, psycopg2): def test_changed_connect_timeout(self, psycopg2):
self.config.credentials = self.config.credentials.replace(connect_timeout=30) self.config.credentials = replace(self.config.credentials, connect_timeout=30)
connection = self.adapter.acquire_connection("dummy") connection = self.adapter.acquire_connection("dummy")
psycopg2.connect.assert_not_called() psycopg2.connect.assert_not_called()
@@ -163,7 +164,7 @@ class TestPostgresAdapter(unittest.TestCase):
@mock.patch("dbt.adapters.postgres.connections.psycopg2") @mock.patch("dbt.adapters.postgres.connections.psycopg2")
def test_changed_keepalive(self, psycopg2): def test_changed_keepalive(self, psycopg2):
self.config.credentials = self.config.credentials.replace(keepalives_idle=256) self.config.credentials = replace(self.config.credentials, keepalives_idle=256)
connection = self.adapter.acquire_connection("dummy") connection = self.adapter.acquire_connection("dummy")
psycopg2.connect.assert_not_called() psycopg2.connect.assert_not_called()
@@ -197,7 +198,7 @@ class TestPostgresAdapter(unittest.TestCase):
@mock.patch("dbt.adapters.postgres.connections.psycopg2") @mock.patch("dbt.adapters.postgres.connections.psycopg2")
def test_changed_application_name(self, psycopg2): def test_changed_application_name(self, psycopg2):
self.config.credentials = self.config.credentials.replace(application_name="myapp") self.config.credentials = replace(self.config.credentials, application_name="myapp")
connection = self.adapter.acquire_connection("dummy") connection = self.adapter.acquire_connection("dummy")
psycopg2.connect.assert_not_called() psycopg2.connect.assert_not_called()
@@ -214,7 +215,7 @@ class TestPostgresAdapter(unittest.TestCase):
@mock.patch("dbt.adapters.postgres.connections.psycopg2") @mock.patch("dbt.adapters.postgres.connections.psycopg2")
def test_role(self, psycopg2): def test_role(self, psycopg2):
self.config.credentials = self.config.credentials.replace(role="somerole") self.config.credentials = replace(self.config.credentials, role="somerole")
connection = self.adapter.acquire_connection("dummy") connection = self.adapter.acquire_connection("dummy")
cursor = connection.handle.cursor() cursor = connection.handle.cursor()
@@ -223,7 +224,7 @@ class TestPostgresAdapter(unittest.TestCase):
@mock.patch("dbt.adapters.postgres.connections.psycopg2") @mock.patch("dbt.adapters.postgres.connections.psycopg2")
def test_search_path(self, psycopg2): def test_search_path(self, psycopg2):
self.config.credentials = self.config.credentials.replace(search_path="test") self.config.credentials = replace(self.config.credentials, search_path="test")
connection = self.adapter.acquire_connection("dummy") connection = self.adapter.acquire_connection("dummy")
psycopg2.connect.assert_not_called() psycopg2.connect.assert_not_called()
@@ -241,7 +242,7 @@ class TestPostgresAdapter(unittest.TestCase):
@mock.patch("dbt.adapters.postgres.connections.psycopg2") @mock.patch("dbt.adapters.postgres.connections.psycopg2")
def test_sslmode(self, psycopg2): def test_sslmode(self, psycopg2):
self.config.credentials = self.config.credentials.replace(sslmode="require") self.config.credentials = replace(self.config.credentials, sslmode="require")
connection = self.adapter.acquire_connection("dummy") connection = self.adapter.acquire_connection("dummy")
psycopg2.connect.assert_not_called() psycopg2.connect.assert_not_called()
@@ -259,10 +260,10 @@ class TestPostgresAdapter(unittest.TestCase):
@mock.patch("dbt.adapters.postgres.connections.psycopg2") @mock.patch("dbt.adapters.postgres.connections.psycopg2")
def test_ssl_parameters(self, psycopg2): def test_ssl_parameters(self, psycopg2):
self.config.credentials = self.config.credentials.replace(sslmode="verify-ca") self.config.credentials = replace(self.config.credentials, sslmode="verify-ca")
self.config.credentials = self.config.credentials.replace(sslcert="service.crt") self.config.credentials = replace(self.config.credentials, sslcert="service.crt")
self.config.credentials = self.config.credentials.replace(sslkey="service.key") self.config.credentials = replace(self.config.credentials, sslkey="service.key")
self.config.credentials = self.config.credentials.replace(sslrootcert="ca.crt") self.config.credentials = replace(self.config.credentials, sslrootcert="ca.crt")
connection = self.adapter.acquire_connection("dummy") connection = self.adapter.acquire_connection("dummy")
psycopg2.connect.assert_not_called() psycopg2.connect.assert_not_called()
@@ -283,7 +284,7 @@ class TestPostgresAdapter(unittest.TestCase):
@mock.patch("dbt.adapters.postgres.connections.psycopg2") @mock.patch("dbt.adapters.postgres.connections.psycopg2")
def test_schema_with_space(self, psycopg2): def test_schema_with_space(self, psycopg2):
self.config.credentials = self.config.credentials.replace(search_path="test test") self.config.credentials = replace(self.config.credentials, search_path="test test")
connection = self.adapter.acquire_connection("dummy") connection = self.adapter.acquire_connection("dummy")
psycopg2.connect.assert_not_called() psycopg2.connect.assert_not_called()
@@ -301,7 +302,7 @@ class TestPostgresAdapter(unittest.TestCase):
@mock.patch("dbt.adapters.postgres.connections.psycopg2") @mock.patch("dbt.adapters.postgres.connections.psycopg2")
def test_set_zero_keepalive(self, psycopg2): def test_set_zero_keepalive(self, psycopg2):
self.config.credentials = self.config.credentials.replace(keepalives_idle=0) self.config.credentials = replace(self.config.credentials, keepalives_idle=0)
connection = self.adapter.acquire_connection("dummy") connection = self.adapter.acquire_connection("dummy")
psycopg2.connect.assert_not_called() psycopg2.connect.assert_not_called()

View File

@@ -7,6 +7,7 @@ import string
import os import os
from unittest import mock from unittest import mock
from unittest import TestCase from unittest import TestCase
from dataclasses import replace
import agate import agate
import pytest import pytest
@@ -379,7 +380,8 @@ def dict_replace(dct, **kwargs):
def replace_config(n, **kwargs): def replace_config(n, **kwargs):
return n.replace( return replace(
n,
config=n.config.replace(**kwargs), config=n.config.replace(**kwargs),
unrendered_config=dict_replace(n.unrendered_config, **kwargs), unrendered_config=dict_replace(n.unrendered_config, **kwargs),
) )