Compare commits

...

42 Commits

Author SHA1 Message Date
Kyle Wigley
0259c50b49 update manifest 2021-02-09 17:43:13 -05:00
Kyle Wigley
7be2992b1b move to core 2021-02-09 17:09:02 -05:00
Kyle Wigley
e8f057d785 hacking extensions in monorepo 2021-02-09 16:45:14 -05:00
Gerda Shank
6c6649f912 Performance fixes, including supporting libyaml, caching
mapped_fields in the classes for 'from_dict', removing deepcopy
on fqn_search, separating validation from 'from_dict',
and special handling for dbt internal not_null and unique tests.
Use TestMacroNamespace instead of original in order to limit
the number of macros in the context.  Integrate mashumaro into
dbt to improve performance of 'from_dict' and 'to_dict'
2021-02-05 15:23:55 -05:00
Kyle Wigley
2b48152da6 Merge branch 'dev/0.19.1' into dev/margaret-mead 2021-01-27 17:16:13 -05:00
Jeremy Cohen
46d36cd412 Merge pull request #3028 from NiallRees/lowercase_cte_names
Make generated CTE test names lowercase to match style guide
2021-01-25 14:39:26 +01:00
NiallRees
a170764fc5 Add to contributors 2021-01-25 11:16:00 +00:00
NiallRees
f72873a1ce Update CHANGELOG.md
Co-authored-by: Jeremy Cohen <jtcohen6@gmail.com>
2021-01-25 11:13:32 +00:00
NiallRees
82496c30b1 Changelog 2021-01-24 16:35:40 +00:00
NiallRees
cb3c007acd Make generated CTE test names lowercase to match style guide 2021-01-24 16:19:20 +00:00
Jeremy Cohen
cb460a797c Merge pull request #3018 from lynxcare/fix-issue-debug-exit-code
dbt debug should return 1 when one of the tests fail
2021-01-21 16:36:03 +01:00
Sam Debruyn
df24c7d2f8 Merge branch 'dev/margaret-mead' into fix-issue-debug-exit-code 2021-01-21 15:39:18 +01:00
Sam Debruyn
133c15c0e2 move in changelog to v0.20 2021-01-21 15:38:31 +01:00
Sam Debruyn
ec0af7c97b remove exitcodes and sys.exit 2021-01-21 10:36:05 +01:00
Jeremy Cohen
a34a877737 Merge pull request #2974 from rvacaru/fix-bug-2731
Fix bug #2731 on stripping query comments for snowflake
2021-01-21 09:54:22 +01:00
Sam Debruyn
f018794465 fix flake test - formatting 2021-01-20 21:09:58 +01:00
Sam Debruyn
d45f5e9791 add missing conditions 2021-01-20 18:15:32 +01:00
Razvan Vacaru
04bd0d834c added extra unit test 2021-01-20 18:06:17 +01:00
Sam Debruyn
ed4f0c4713 formatting 2021-01-20 18:04:21 +01:00
Sam Debruyn
c747068d4a use sys.exit 2021-01-20 16:51:06 +01:00
Sam Debruyn
e91988f679 use ExitCodes enum for exit code 2021-01-20 16:09:41 +01:00
Sam Debruyn
3ed1fce3fb update changelog 2021-01-20 16:06:24 +01:00
Sam Debruyn
e3ea0b511a dbt debug should return 1 when one of the tests fail 2021-01-20 16:00:58 +01:00
Razvan Vacaru
c411c663de moved unit tests and updated changelog.md 2021-01-19 19:04:58 +01:00
Razvan Vacaru
1c6f66fc14 Merge branch 'dev/margaret-mead' of https://github.com/fishtown-analytics/dbt into fix-bug-2731 2021-01-19 19:01:01 +01:00
Jeremy Cohen
1f927a374c Merge pull request #2928 from yu-iskw/issue-1843
Support require_partition_filter and partition_expiration_days in BQ
2021-01-19 12:11:39 +01:00
Jeremy Cohen
07c4225aa8 Merge branch 'dev/margaret-mead' into issue-1843 2021-01-19 11:24:59 +01:00
Razvan Vacaru
16b098ea42 updated CHANGELOG.md 2021-01-04 17:43:03 +01:00
Razvan Vacaru
b31c4d407a Fix #2731 stripping snowflake comments in multiline queries 2021-01-04 17:41:00 +01:00
Yu ISHIKAWA
330065f5e0 Add a condition for require_partition_filter 2020-12-18 11:14:03 +09:00
Yu ISHIKAWA
944db82553 Remove unnecessary code for print debug 2020-12-18 11:14:03 +09:00
Yu ISHIKAWA
c257361f05 Fix syntax 2020-12-18 11:14:03 +09:00
Yu ISHIKAWA
ffdbfb018a Implement tests in test_bigquery_changing_partitions.py 2020-12-18 11:14:01 +09:00
Yu ISHIKAWA
cfa2bd6b08 Remove tests fromm test_bigquery_adapter_specific.py 2020-12-18 11:13:16 +09:00
Yu ISHIKAWA
51e90c3ce0 Format 2020-12-18 11:13:16 +09:00
Yu ISHIKAWA
d69149f43e Update 2020-12-18 11:13:15 +09:00
Yu ISHIKAWA
f261663f3d Add debug code 2020-12-18 11:13:15 +09:00
Yu ISHIKAWA
e5948dd1d3 Update 2020-12-18 11:13:15 +09:00
Yu ISHIKAWA
5f13aab7d8 Print debug 2020-12-18 11:13:15 +09:00
Yu ISHIKAWA
292d489592 Format code 2020-12-18 11:13:15 +09:00
Yu ISHIKAWA
0a01f20e35 Update CHANGELOG.md 2020-12-18 11:13:11 +09:00
Yu ISHIKAWA
2bd08d5c4c Support require_partition_filter and partition_expiration_days in BQ 2020-12-18 11:12:47 +09:00
127 changed files with 2343 additions and 1026 deletions

View File

@@ -1,7 +1,25 @@
## dbt 0.20.0 (Release TBD)
### Fixes
- Fix exit code from dbt debug not returning a failure when one of the tests fail ([#3017](https://github.com/fishtown-analytics/dbt/issues/3017))
- Auto-generated CTEs in tests and ephemeral models have lowercase names to comply with dbt coding conventions ([#3027](https://github.com/fishtown-analytics/dbt/issues/3027), [#3028](https://github.com/fishtown-analytics/dbt/issues/3028))
### Features
- Add optional configs for `require_partition_filter` and `partition_expiration_days` in BigQuery ([#1843](https://github.com/fishtown-analytics/dbt/issues/1843), [#2928](https://github.com/fishtown-analytics/dbt/pull/2928))
- Fix EOL SQL comments preventing execution of the entire line ([#2731](https://github.com/fishtown-analytics/dbt/issues/2731), [#2974](https://github.com/fishtown-analytics/dbt/pull/2974))
Contributors:
- [@yu-iskw](https://github.com/yu-iskw) ([#2928](https://github.com/fishtown-analytics/dbt/pull/2928))
- [@sdebruyn](https://github.com/sdebruyn) / [@lynxcare](https://github.com/lynxcare) ([#3018](https://github.com/fishtown-analytics/dbt/pull/3018))
- [@rvacaru](https://github.com/rvacaru) ([#2974](https://github.com/fishtown-analytics/dbt/pull/2974))
- [@NiallRees](https://github.com/NiallRees) ([#3028](https://github.com/fishtown-analytics/dbt/pull/3028))
## dbt 0.19.1 (Release TBD)
### Under the hood
- Bump werkzeug upper bound dependency to `<v2.0` ([#3011](https://github.com/fishtown-analytics/dbt/pull/3011))
- Performance fixes for many different things ([#2862](https://github.com/fishtown-analytics/dbt/issues/2862), [#3034](https://github.com/fishtown-analytics/dbt/pull/3034))
Contributors:
- [@Bl3f](https://github.com/Bl3f) ([#3011](https://github.com/fishtown-analytics/dbt/pull/3011))
@@ -48,6 +66,7 @@ Contributors:
- Normalize cli-style-strings in manifest selectors dictionary ([#2879](https://github.com/fishtown-analytics/dbt/issues/2879), [#2895](https://github.com/fishtown-analytics/dbt/pull/2895))
- Hourly, monthly and yearly partitions available in BigQuery ([#2476](https://github.com/fishtown-analytics/dbt/issues/2476), [#2903](https://github.com/fishtown-analytics/dbt/pull/2903))
- Allow BigQuery to default to the environment's default project ([#2828](https://github.com/fishtown-analytics/dbt/pull/2828), [#2908](https://github.com/fishtown-analytics/dbt/pull/2908))
- Rationalize run result status reporting and clean up artifact schema ([#2493](https://github.com/fishtown-analytics/dbt/issues/2493), [#2943](https://github.com/fishtown-analytics/dbt/pull/2943))
### Fixes
- Respect `--project-dir` in `dbt clean` command ([#2840](https://github.com/fishtown-analytics/dbt/issues/2840), [#2841](https://github.com/fishtown-analytics/dbt/pull/2841))
@@ -939,7 +958,6 @@ Thanks for your contributions to dbt!
- [@bastienboutonnet](https://github.com/bastienboutonnet) ([#1591](https://github.com/fishtown-analytics/dbt/pull/1591), [#1689](https://github.com/fishtown-analytics/dbt/pull/1689))
## dbt 0.14.0 - Wilt Chamberlain (July 10, 2019)
### Overview
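
The `dbt debug` fix noted above (#3017/#3018) makes the command exit non-zero when any check fails. A minimal sketch of that behavior, where run_checks is a hypothetical stand-in for dbt's profile/project/connection checks:

import sys

def run_checks() -> list:
    # hypothetical stand-in for dbt debug's profile, project, and connection checks
    return [("profile: OK", True), ("connection: FAIL", False)]

def main() -> None:
    results = run_checks()
    for message, passed in results:
        print(message)
    # exit with code 1 if any check failed, as described in the fix above
    if not all(passed for _, passed in results):
        sys.exit(1)

if __name__ == "__main__":
    main()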

core/Cargo.lock (new generated file, 241 lines)
View File

@@ -0,0 +1,241 @@
# This file is automatically @generated by Cargo.
# It is not intended for manual editing.
[[package]]
name = "cfg-if"
version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
[[package]]
name = "ctor"
version = "0.1.19"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e8f45d9ad417bcef4817d614a501ab55cdd96a6fdb24f49aab89a54acfd66b19"
dependencies = [
"quote",
"syn",
]
[[package]]
name = "extensions-tracking"
version = "0.1.0"
dependencies = [
"pyo3",
]
[[package]]
name = "ghost"
version = "0.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1a5bcf1bbeab73aa4cf2fde60a846858dc036163c7c33bec309f8d17de785479"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "indoc"
version = "1.0.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e5a75aeaaef0ce18b58056d306c27b07436fbb34b8816c53094b76dd81803136"
dependencies = [
"unindent",
]
[[package]]
name = "instant"
version = "0.1.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "61124eeebbd69b8190558df225adf7e4caafce0d743919e5d6b19652314ec5ec"
dependencies = [
"cfg-if",
]
[[package]]
name = "inventory"
version = "0.1.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0f0f7efb804ec95e33db9ad49e4252f049e37e8b0a4652e3cd61f7999f2eff7f"
dependencies = [
"ctor",
"ghost",
"inventory-impl",
]
[[package]]
name = "inventory-impl"
version = "0.1.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "75c094e94816723ab936484666968f5b58060492e880f3c8d00489a1e244fa51"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "libc"
version = "0.2.86"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b7282d924be3275cec7f6756ff4121987bc6481325397dde6ba3e7802b1a8b1c"
[[package]]
name = "lock_api"
version = "0.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dd96ffd135b2fd7b973ac026d28085defbe8983df057ced3eb4f2130b0831312"
dependencies = [
"scopeguard",
]
[[package]]
name = "parking_lot"
version = "0.11.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6d7744ac029df22dca6284efe4e898991d28e3085c706c972bcd7da4a27a15eb"
dependencies = [
"instant",
"lock_api",
"parking_lot_core",
]
[[package]]
name = "parking_lot_core"
version = "0.8.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9ccb628cad4f84851442432c60ad8e1f607e29752d0bf072cbd0baf28aa34272"
dependencies = [
"cfg-if",
"instant",
"libc",
"redox_syscall",
"smallvec",
"winapi",
]
[[package]]
name = "paste"
version = "1.0.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c5d65c4d95931acda4498f675e332fcbdc9a06705cd07086c510e9b6009cd1c1"
[[package]]
name = "proc-macro2"
version = "1.0.24"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1e0704ee1a7e00d7bb417d0770ea303c1bccbabf0ef1667dae92b5967f5f8a71"
dependencies = [
"unicode-xid",
]
[[package]]
name = "pyo3"
version = "0.13.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "00ca634cf3acd58a599b535ed6cb188223298977d471d146121792bfa23b754c"
dependencies = [
"cfg-if",
"ctor",
"indoc",
"inventory",
"libc",
"parking_lot",
"paste",
"pyo3-macros",
"unindent",
]
[[package]]
name = "pyo3-macros"
version = "0.13.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "483ac516dbda6789a5b4be0271e7a31b9ad4ec8c0a5955050e8076f72bdbef8f"
dependencies = [
"pyo3-macros-backend",
"quote",
"syn",
]
[[package]]
name = "pyo3-macros-backend"
version = "0.13.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "15230cabcda008f03565ed8bac40f094cbb5ee1b46e6551f1ec3a0e922cf7df9"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "quote"
version = "1.0.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "991431c3519a3f36861882da93630ce66b52918dcf1b8e2fd66b397fc96f28df"
dependencies = [
"proc-macro2",
]
[[package]]
name = "redox_syscall"
version = "0.1.57"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "41cc0f7e4d5d4544e8861606a285bb08d3e70712ccc7d2b84d7c0ccfaf4b05ce"
[[package]]
name = "scopeguard"
version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d29ab0c6d3fc0ee92fe66e2d99f700eab17a8d57d1c1d3b748380fb20baa78cd"
[[package]]
name = "smallvec"
version = "1.6.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fe0f37c9e8f3c5a4a66ad655a93c74daac4ad00c441533bf5c6e7990bb42604e"
[[package]]
name = "syn"
version = "1.0.60"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c700597eca8a5a762beb35753ef6b94df201c81cca676604f547495a0d7f0081"
dependencies = [
"proc-macro2",
"quote",
"unicode-xid",
]
[[package]]
name = "unicode-xid"
version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f7fe0bb3479651439c9112f72b6c505038574c9fbb575ed1bf3b797fa39dd564"
[[package]]
name = "unindent"
version = "0.1.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f14ee04d9415b52b3aeab06258a3f07093182b88ba0f9b8d203f211a7a7d41c7"
[[package]]
name = "winapi"
version = "0.3.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5c839a674fcd7a98952e593242ea400abe93992746761e38641405d28b00f419"
dependencies = [
"winapi-i686-pc-windows-gnu",
"winapi-x86_64-pc-windows-gnu",
]
[[package]]
name = "winapi-i686-pc-windows-gnu"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6"
[[package]]
name = "winapi-x86_64-pc-windows-gnu"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f"

core/Cargo.toml (new file, 2 lines)
View File

@@ -0,0 +1,2 @@
[workspace]
members = [ "extensions/tracking",]

View File

@@ -1 +1,3 @@
recursive-include dbt/include *.py *.sql *.yml *.html *.md
recursive-include extensions *
include Cargo.toml

View File

@@ -1,14 +1,12 @@
from dataclasses import dataclass
import re
from hologram import JsonSchemaMixin
from dbt.exceptions import RuntimeException
from typing import Dict, ClassVar, Any, Optional
from dbt.exceptions import RuntimeException
@dataclass
class Column(JsonSchemaMixin):
class Column:
TYPE_LABELS: ClassVar[Dict[str, str]] = {
'STRING': 'TEXT',
'TIMESTAMP': 'TIMESTAMP',

View File

@@ -28,7 +28,7 @@ from dbt.clients.jinja import MacroGenerator
from dbt.contracts.graph.compiled import (
CompileResultNode, CompiledSeedNode
)
from dbt.contracts.graph.manifest import Manifest
from dbt.contracts.graph.manifest import Manifest, MacroManifest
from dbt.contracts.graph.parsed import ParsedSeedNode
from dbt.exceptions import warn_or_error
from dbt.node_types import NodeType
@@ -160,7 +160,7 @@ class BaseAdapter(metaclass=AdapterMeta):
self.config = config
self.cache = RelationsCache()
self.connections = self.ConnectionManager(config)
self._macro_manifest_lazy: Optional[Manifest] = None
self._macro_manifest_lazy: Optional[MacroManifest] = None
###
# Methods that pass through to the connection manager
@@ -259,18 +259,18 @@ class BaseAdapter(metaclass=AdapterMeta):
return cls.ConnectionManager.TYPE
@property
def _macro_manifest(self) -> Manifest:
def _macro_manifest(self) -> MacroManifest:
if self._macro_manifest_lazy is None:
return self.load_macro_manifest()
return self._macro_manifest_lazy
def check_macro_manifest(self) -> Optional[Manifest]:
def check_macro_manifest(self) -> Optional[MacroManifest]:
"""Return the internal manifest (used for executing macros) if it's
been initialized, otherwise return None.
"""
return self._macro_manifest_lazy
def load_macro_manifest(self) -> Manifest:
def load_macro_manifest(self) -> MacroManifest:
if self._macro_manifest_lazy is None:
# avoid a circular import
from dbt.parser.manifest import load_macro_manifest

View File

@@ -21,8 +21,8 @@ Self = TypeVar('Self', bound='BaseRelation')
@dataclass(frozen=True, eq=False, repr=False)
class BaseRelation(FakeAPIObject, Hashable):
type: Optional[RelationType]
path: Path
type: Optional[RelationType] = None
quote_character: str = '"'
include_policy: Policy = Policy()
quote_policy: Policy = Policy()
@@ -203,7 +203,7 @@ class BaseRelation(FakeAPIObject, Hashable):
@staticmethod
def add_ephemeral_prefix(name: str):
return f'__dbt__CTE__{name}'
return f'__dbt__cte__{name}'
@classmethod
def create_ephemeral_from_node(

View File

@@ -231,6 +231,7 @@ class BaseMacroGenerator:
template = self.get_template()
# make the module. previously we set both vars and local, but that's
# redundant: They both end up in the same place
# make_module is in jinja2.environment. It returns a TemplateModule
module = template.make_module(vars=self.context, shared=False)
macro = module.__dict__[get_dbt_macro_name(name)]
module.__dict__.update(self.context)
@@ -244,6 +245,7 @@ class BaseMacroGenerator:
raise_compiler_error(str(e))
def call_macro(self, *args, **kwargs):
# called from __call__ methods
if self.context is None:
raise InternalException(
'Context is still None in call_macro!'
@@ -306,8 +308,10 @@ class MacroGenerator(BaseMacroGenerator):
e.stack.append(self.macro)
raise e
# This adds the macro's unique id to the node's 'depends_on'
@contextmanager
def track_call(self):
# This is only called from __call__
if self.stack is None or self.node is None:
yield
else:
@@ -322,6 +326,7 @@ class MacroGenerator(BaseMacroGenerator):
finally:
self.stack.pop(unique_id)
# this makes MacroGenerator objects callable like functions
def __call__(self, *args, **kwargs):
with self.track_call():
return self.call_macro(*args, **kwargs)
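
The comments added in this hunk note that MacroGenerator objects are callable like functions and that each call is tracked on a stack. A standalone sketch of that shape (the names here are illustrative, not dbt's API):

from contextlib import contextmanager

class CallableWithTracking:
    """Callable wrapper that records which names are currently executing."""
    def __init__(self, name, func, stack):
        self.name = name
        self.func = func
        self.stack = stack  # shared list playing the role of dbt's MacroStack

    @contextmanager
    def track_call(self):
        # push on entry, always pop on exit, even if the call raises
        self.stack.append(self.name)
        try:
            yield
        finally:
            self.stack.pop()

    def __call__(self, *args, **kwargs):
        with self.track_call():
            return self.func(*args, **kwargs)

stack = []
greet = CallableWithTracking("greet", lambda n: f"hello {n}", stack)
print(greet("dbt"), stack)  # -> hello dbt []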

View File

@@ -438,7 +438,9 @@ def run_cmd(
return out, err
def download(url: str, path: str, timeout: Union[float, tuple] = None) -> None:
def download(
url: str, path: str, timeout: Optional[Union[float, tuple]] = None
) -> None:
path = convert_path(path)
connection_timeout = timeout or float(os.getenv('DBT_HTTP_TIMEOUT', 10))
response = requests.get(url, timeout=connection_timeout)
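
The new signature above makes the timeout optional, falling back to the DBT_HTTP_TIMEOUT environment variable. A simplified sketch of that fallback (error handling and the rest of the real function body omitted):

import os
from typing import Optional, Union

import requests

def download(url: str, path: str,
             timeout: Optional[Union[float, tuple]] = None) -> None:
    # an explicit timeout wins; otherwise read DBT_HTTP_TIMEOUT, defaulting to 10s
    connection_timeout = timeout or float(os.getenv('DBT_HTTP_TIMEOUT', 10))
    response = requests.get(url, timeout=connection_timeout)
    with open(path, 'wb') as fh:
        fh.write(response.content)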

View File

@@ -1,16 +1,19 @@
from typing import Any
import dbt.exceptions
import yaml
import yaml.scanner
# the C version is faster, but it doesn't always exist
YamlLoader: Any
try:
from yaml import CSafeLoader as YamlLoader
from yaml import (
CLoader as Loader,
CSafeLoader as SafeLoader,
CDumper as Dumper
)
except ImportError:
from yaml import SafeLoader as YamlLoader
from yaml import ( # type: ignore # noqa: F401
Loader, SafeLoader, Dumper
)
YAML_ERROR_MESSAGE = """
@@ -54,7 +57,7 @@ def contextualized_yaml_error(raw_contents, error):
def safe_load(contents):
return yaml.load(contents, Loader=YamlLoader)
return yaml.load(contents, Loader=SafeLoader)
def load_yaml_text(contents):
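
The import block above prefers the C-accelerated libyaml classes and falls back to the pure-Python ones when libyaml is not installed, which is part of the performance work in this branch. A self-contained sketch of the same pattern:

from typing import Any

import yaml

SafeLoader: Any
try:
    # libyaml-backed loader, noticeably faster when available
    from yaml import CSafeLoader as SafeLoader
except ImportError:
    # pure-Python fallback
    from yaml import SafeLoader

def safe_load(contents: str) -> Any:
    return yaml.load(contents, Loader=SafeLoader)

print(safe_load("a: 1"))  # -> {'a': 1}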

View File

@@ -191,11 +191,11 @@ class Compiler:
[
InjectedCTE(
id="cte_id_1",
sql="__dbt__CTE__ephemeral as (select * from table)",
sql="__dbt__cte__ephemeral as (select * from table)",
),
InjectedCTE(
id="cte_id_2",
sql="__dbt__CTE__events as (select id, type from events)",
sql="__dbt__cte__events as (select id, type from events)",
),
]
@@ -206,8 +206,8 @@ class Compiler:
This will spit out:
"with __dbt__CTE__ephemeral as (select * from table),
__dbt__CTE__events as (select id, type from events),
"with __dbt__cte__ephemeral as (select * from table),
__dbt__cte__events as (select id, type from events),
with internal_cte as (select * from sessions)
select * from internal_cte"
@@ -246,7 +246,7 @@ class Compiler:
return str(parsed)
def _get_dbt_test_name(self) -> str:
return 'dbt__CTE__INTERNAL_test'
return 'dbt__cte__internal_test'
# This method is called by the 'compile_node' method. Starting
# from the node that it is passed in, it will recursively call
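
The docstring earlier in this hunk shows injected CTEs being prepended to the model SQL with the now-lowercased __dbt__cte__ prefix. A minimal sketch of that assembly, simplified relative to the real compiler (which also handles models that already start with a WITH clause):

def add_ephemeral_prefix(name: str) -> str:
    # lowercase prefix, per the style-guide change in this diff
    return f"__dbt__cte__{name}"

def prepend_ctes(model_sql: str, injected_ctes: list) -> str:
    """Build a single WITH clause from (name, sql) pairs, then append the model SQL."""
    if not injected_ctes:
        return model_sql
    cte_sql = ",\n".join(
        f"{add_ephemeral_prefix(name)} as ({sql})" for name, sql in injected_ctes
    )
    return f"with {cte_sql}\n{model_sql}"

print(prepend_ctes(
    "select * from __dbt__cte__ephemeral",
    [("ephemeral", "select * from some_table")],
))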

View File

@@ -2,7 +2,7 @@ from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple
import os
from hologram import ValidationError
from dbt.dataclass_schema import ValidationError
from dbt.clients.system import load_file_contents
from dbt.clients.yaml_helper import load_yaml_text
@@ -75,6 +75,7 @@ def read_user_config(directory: str) -> UserConfig:
if profile:
user_cfg = coerce_dict_str(profile.get('config', {}))
if user_cfg is not None:
UserConfig.validate(user_cfg)
return UserConfig.from_dict(user_cfg)
except (RuntimeException, ValidationError):
pass
@@ -137,10 +138,10 @@ class Profile(HasCredentials):
def validate(self):
try:
if self.credentials:
self.credentials.to_dict(validate=True)
ProfileConfig.from_dict(
self.to_profile_info(serialize_credentials=True)
)
dct = self.credentials.to_dict()
self.credentials.validate(dct)
dct = self.to_profile_info(serialize_credentials=True)
ProfileConfig.validate(dct)
except ValidationError as exc:
raise DbtProfileError(validator_error_message(exc)) from exc
@@ -160,7 +161,9 @@ class Profile(HasCredentials):
typename = profile.pop('type')
try:
cls = load_plugin(typename)
credentials = cls.from_dict(profile)
data = cls.translate_aliases(profile)
cls.validate(data)
credentials = cls.from_dict(data)
except (RuntimeException, ValidationError) as e:
msg = str(e) if isinstance(e, RuntimeException) else e.message
raise DbtProfileError(
@@ -233,6 +236,7 @@ class Profile(HasCredentials):
"""
if user_cfg is None:
user_cfg = {}
UserConfig.validate(user_cfg)
config = UserConfig.from_dict(user_cfg)
profile = cls(
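
Several hunks in this file follow the same new pattern: validate the raw dict first, then construct the dataclass from it, instead of folding validation into from_dict. A schematic sketch of that split, using a hand-rolled validate in place of dbt's dataclass_schema helpers (class and field names are illustrative):

from dataclasses import dataclass, fields
from typing import Any, Dict

class ValidationError(Exception):
    pass

@dataclass
class UserConfigSketch:
    send_anonymous_usage_stats: bool = True
    partial_parse: bool = False

    @classmethod
    def validate(cls, data: Dict[str, Any]) -> None:
        # check the raw dict before building the object
        known = {f.name for f in fields(cls)}
        unknown = set(data) - known
        if unknown:
            raise ValidationError(f"unexpected keys: {sorted(unknown)}")

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "UserConfigSketch":
        # construction assumes the dict has already been validated
        return cls(**data)

raw = {"partial_parse": True}
UserConfigSketch.validate(raw)
print(UserConfigSketch.from_dict(raw))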

View File

@@ -26,15 +26,12 @@ from dbt.version import get_installed_version
from dbt.utils import MultiDict
from dbt.node_types import NodeType
from dbt.config.selectors import SelectorDict
from dbt.contracts.project import (
Project as ProjectContract,
SemverString,
)
from dbt.contracts.project import PackageConfig
from hologram import ValidationError
from dbt.dataclass_schema import ValidationError
from .renderer import DbtProjectYamlRenderer
from .selectors import (
selector_config_from_data,
@@ -101,6 +98,7 @@ def package_config_from_data(packages_data: Dict[str, Any]):
packages_data = {'packages': []}
try:
PackageConfig.validate(packages_data)
packages = PackageConfig.from_dict(packages_data)
except ValidationError as e:
raise DbtProjectError(
@@ -306,7 +304,10 @@ class PartialProject(RenderComponents):
)
try:
cfg = ProjectContract.from_dict(rendered.project_dict)
ProjectContract.validate(rendered.project_dict)
cfg = ProjectContract.from_dict(
rendered.project_dict
)
except ValidationError as e:
raise DbtProjectError(validator_error_message(e)) from e
# name/version are required in the Project definition, so we can assume
@@ -586,7 +587,7 @@ class Project:
def validate(self):
try:
ProjectContract.from_dict(self.to_project_config())
ProjectContract.validate(self.to_project_config())
except ValidationError as e:
raise DbtProjectError(validator_error_message(e)) from e

View File

@@ -33,7 +33,7 @@ from dbt.exceptions import (
raise_compiler_error
)
from hologram import ValidationError
from dbt.dataclass_schema import ValidationError
def _project_quoting_dict(
@@ -174,7 +174,7 @@ class RuntimeConfig(Project, Profile, AdapterRequiredConfig):
:raises DbtProjectError: If the configuration fails validation.
"""
try:
Configuration.from_dict(self.serialize())
Configuration.validate(self.serialize())
except ValidationError as e:
raise DbtProjectError(validator_error_message(e)) from e
@@ -391,7 +391,7 @@ class UnsetConfig(UserConfig):
f"'UnsetConfig' object has no attribute {name}"
)
def to_dict(self):
def __post_serialize__(self, dct, options=None):
return {}

View File

@@ -1,8 +1,9 @@
from pathlib import Path
from typing import Dict, Any
import yaml
from hologram import ValidationError
from dbt.clients.yaml_helper import ( # noqa: F401
yaml, Loader, Dumper, load_yaml_text
)
from dbt.dataclass_schema import ValidationError
from .renderer import SelectorRenderer
@@ -11,7 +12,6 @@ from dbt.clients.system import (
path_exists,
resolve_path_from_base,
)
from dbt.clients.yaml_helper import load_yaml_text
from dbt.contracts.selection import SelectorFile
from dbt.exceptions import DbtSelectorsError, RuntimeException
from dbt.graph import parse_from_selectors_definition, SelectionSpec
@@ -30,9 +30,11 @@ Validator Error:
class SelectorConfig(Dict[str, SelectionSpec]):
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> 'SelectorConfig':
def selectors_from_dict(cls, data: Dict[str, Any]) -> 'SelectorConfig':
try:
SelectorFile.validate(data)
selector_file = SelectorFile.from_dict(data)
selectors = parse_from_selectors_definition(selector_file)
except ValidationError as exc:
@@ -66,7 +68,7 @@ class SelectorConfig(Dict[str, SelectionSpec]):
f'Could not render selector data: {exc}',
result_type='invalid_selector',
) from exc
return cls.from_dict(rendered)
return cls.selectors_from_dict(rendered)
@classmethod
def from_path(
@@ -107,7 +109,7 @@ def selector_config_from_data(
selectors_data = {'selectors': []}
try:
selectors = SelectorConfig.from_dict(selectors_data)
selectors = SelectorConfig.selectors_from_dict(selectors_data)
except ValidationError as e:
raise DbtSelectorsError(
MALFORMED_SELECTOR_ERROR.format(error=str(e.message)),

View File

@@ -7,13 +7,14 @@ from typing import (
from dbt import flags
from dbt import tracking
from dbt.clients.jinja import undefined_error, get_rendered
from dbt.clients import yaml_helper
from dbt.clients.yaml_helper import ( # noqa: F401
yaml, safe_load, SafeLoader, Loader, Dumper
)
from dbt.contracts.graph.compiled import CompiledResource
from dbt.exceptions import raise_compiler_error, MacroReturn
from dbt.logger import GLOBAL_LOGGER as logger
from dbt.version import __version__ as dbt_version
import yaml
# These modules are added to the context. Consider alternative
# approaches which will extend well to potentially many modules
import pytz
@@ -172,6 +173,7 @@ class BaseContext(metaclass=ContextMeta):
builtins[key] = value
return builtins
# no dbtClassMixin so this is not an actual override
def to_dict(self):
self._ctx['context'] = self._ctx
builtins = self.generate_builtins()
@@ -394,7 +396,7 @@ class BaseContext(metaclass=ContextMeta):
-- ["good"]
"""
try:
return yaml_helper.safe_load(value)
return safe_load(value)
except (AttributeError, ValueError, yaml.YAMLError):
return default

View File

@@ -165,7 +165,7 @@ class ContextConfigGenerator(BaseContextConfigGenerator[C]):
# Calculate the defaults. We don't want to validate the defaults,
# because it might be invalid in the case of required config members
# (such as on snapshots!)
result = config_cls.from_dict({}, validate=False)
result = config_cls.from_dict({})
return result
def _update_from_config(

View File

@@ -0,0 +1,153 @@
from typing import (
Dict, MutableMapping, Optional
)
from dbt.contracts.graph.parsed import ParsedMacro
from dbt.exceptions import raise_duplicate_macro_name, raise_compiler_error
from dbt.include.global_project import PROJECT_NAME as GLOBAL_PROJECT_NAME
from dbt.clients.jinja import MacroGenerator
MacroNamespace = Dict[str, ParsedMacro]
# This class builds the MacroResolver by adding macros
# to various categories for finding macros in the right order,
# so that higher precedence macros are found first.
# This functionality is also provided by the MacroNamespace,
# but the intention is to eventually replace that class.
# This enables us to get the macro unique_id without
# processing every macro in the project.
class MacroResolver:
def __init__(
self,
macros: MutableMapping[str, ParsedMacro],
root_project_name: str,
internal_package_names,
) -> None:
self.root_project_name = root_project_name
self.macros = macros
# internal package names come from get_adapter_package_names
self.internal_package_names = internal_package_names
# To be filled in from macros.
self.internal_packages: Dict[str, MacroNamespace] = {}
self.packages: Dict[str, MacroNamespace] = {}
self.root_package_macros: MacroNamespace = {}
# add the macros to internal_packages, packages, and root packages
self.add_macros()
self._build_internal_packages_namespace()
self._build_macros_by_name()
def _build_internal_packages_namespace(self):
# Iterate in reverse-order and overwrite: the packages that are first
# in the list are the ones we want to "win".
self.internal_packages_namespace: MacroNamespace = {}
for pkg in reversed(self.internal_package_names):
if pkg in self.internal_packages:
# Turn the internal packages into a flat namespace
self.internal_packages_namespace.update(
self.internal_packages[pkg])
def _build_macros_by_name(self):
macros_by_name = {}
# search root package macros
for macro in self.root_package_macros.values():
macros_by_name[macro.name] = macro
# search miscellaneous non-internal packages
for fnamespace in self.packages.values():
for macro in fnamespace.values():
macros_by_name[macro.name] = macro
# search all internal packages
for macro in self.internal_packages_namespace.values():
macros_by_name[macro.name] = macro
self.macros_by_name = macros_by_name
def _add_macro_to(
self,
package_namespaces: Dict[str, MacroNamespace],
macro: ParsedMacro,
):
if macro.package_name in package_namespaces:
namespace = package_namespaces[macro.package_name]
else:
namespace = {}
package_namespaces[macro.package_name] = namespace
if macro.name in namespace:
raise_duplicate_macro_name(
macro, macro, macro.package_name
)
package_namespaces[macro.package_name][macro.name] = macro
def add_macro(self, macro: ParsedMacro):
macro_name: str = macro.name
# internal macros (from plugins) will be processed separately from
# project macros, so store them in a different place
if macro.package_name in self.internal_package_names:
self._add_macro_to(self.internal_packages, macro)
else:
# if it's not an internal package
self._add_macro_to(self.packages, macro)
# add to root_package_macros if it's in the root package
if macro.package_name == self.root_project_name:
self.root_package_macros[macro_name] = macro
def add_macros(self):
for macro in self.macros.values():
self.add_macro(macro)
def get_macro_id(self, local_package, macro_name):
local_package_macros = {}
if (local_package not in self.internal_package_names and
local_package in self.packages):
local_package_macros = self.packages[local_package]
# First: search the local packages for this macro
if macro_name in local_package_macros:
return local_package_macros[macro_name].unique_id
if macro_name in self.macros_by_name:
return self.macros_by_name[macro_name].unique_id
return None
# Currently this is just used by test processing in the schema
# parser (in connection with the MacroResolver). Future work
# will extend the use of these classes to other parsing areas.
# One of the features of this class compared to the MacroNamespace
# is that you can limit the number of macros provided to the
# context dictionary in the 'to_dict' manifest method.
class TestMacroNamespace:
def __init__(
self, macro_resolver, ctx, node, thread_ctx, depends_on_macros
):
self.macro_resolver = macro_resolver
self.ctx = ctx
self.node = node
self.thread_ctx = thread_ctx
local_namespace = {}
if depends_on_macros:
for macro_unique_id in depends_on_macros:
macro = self.manifest.macros[macro_unique_id]
local_namespace[macro.name] = MacroGenerator(
macro, self.ctx, self.node, self.thread_ctx,
)
self.local_namespace = local_namespace
def get_from_package(
self, package_name: Optional[str], name: str
) -> Optional[MacroGenerator]:
macro = None
if package_name is None:
macro = self.macro_resolver.macros_by_name.get(name)
elif package_name == GLOBAL_PROJECT_NAME:
macro = self.macro_resolver.internal_packages_namespace.get(name)
elif package_name in self.macro_resolver.packages:
macro = self.macro_resolver.packages[package_name].get(name)
else:
raise_compiler_error(
f"Could not find package '{package_name}'"
)
macro_func = MacroGenerator(
macro, self.ctx, self.node, self.thread_ctx
)
return macro_func
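
get_macro_id above checks the calling package's own namespace before falling back to the flattened by-name map. A toy version of that lookup with plain dicts (package and macro names are hypothetical):

from typing import Dict, Optional

def get_macro_id(
    packages: Dict[str, Dict[str, str]],   # package -> macro name -> unique_id
    macros_by_name: Dict[str, str],        # flattened fallback map
    local_package: str,
    macro_name: str,
) -> Optional[str]:
    # 1. prefer a macro defined in the caller's own package
    local = packages.get(local_package, {})
    if macro_name in local:
        return local[macro_name]
    # 2. otherwise fall back to the flattened namespace
    return macros_by_name.get(macro_name)

packages = {"my_project": {"is_recent": "macro.my_project.is_recent"}}
by_name = {"is_recent": "macro.dbt_utils.is_recent"}
print(get_macro_id(packages, by_name, "my_project", "is_recent"))
# -> macro.my_project.is_recent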

View File

@@ -15,6 +15,10 @@ NamespaceMember = Union[FlatNamespace, MacroGenerator]
FullNamespace = Dict[str, NamespaceMember]
# The point of this class is to collect the various macros
# and provide the ability to flatten them into the ManifestContexts
# that are created for jinja, so that macro calls can be resolved.
# Creates special iterators and _keys methods to flatten the lists.
class MacroNamespace(Mapping):
def __init__(
self,
@@ -37,12 +41,16 @@ class MacroNamespace(Mapping):
}
yield self.global_project_namespace
# provides special keys method for MacroNamespace iterator
# returns keys from local_namespace, global_namespace, packages,
# global_project_namespace
def _keys(self) -> Set[str]:
keys: Set[str] = set()
for search in self._search_order():
keys.update(search)
return keys
# special iterator using special keys
def __iter__(self) -> Iterator[str]:
for key in self._keys():
yield key
@@ -72,6 +80,10 @@ class MacroNamespace(Mapping):
)
# This class builds the MacroNamespace by adding macros to
# internal_packages or packages, and locals/globals.
# Call 'build_namespace' to return a MacroNamespace.
# This is used by ManifestContext (and subclasses)
class MacroNamespaceBuilder:
def __init__(
self,
@@ -83,10 +95,15 @@ class MacroNamespaceBuilder:
) -> None:
self.root_package = root_package
self.search_package = search_package
# internal package names come from get_adapter_package_names
self.internal_package_names = set(internal_packages)
self.internal_package_names_order = internal_packages
# macro_func is added here if in root package
self.globals: FlatNamespace = {}
# macro_func is added here if it's the package for this node
self.locals: FlatNamespace = {}
# Create a dictionary of [package name][macro name] =
# MacroGenerator object which acts like a function
self.internal_packages: Dict[str, FlatNamespace] = {}
self.packages: Dict[str, FlatNamespace] = {}
self.thread_ctx = thread_ctx
@@ -94,25 +111,28 @@ class MacroNamespaceBuilder:
def _add_macro_to(
self,
heirarchy: Dict[str, FlatNamespace],
hierarchy: Dict[str, FlatNamespace],
macro: ParsedMacro,
macro_func: MacroGenerator,
):
if macro.package_name in heirarchy:
namespace = heirarchy[macro.package_name]
if macro.package_name in hierarchy:
namespace = hierarchy[macro.package_name]
else:
namespace = {}
heirarchy[macro.package_name] = namespace
hierarchy[macro.package_name] = namespace
if macro.name in namespace:
raise_duplicate_macro_name(
macro_func.macro, macro, macro.package_name
)
heirarchy[macro.package_name][macro.name] = macro_func
hierarchy[macro.package_name][macro.name] = macro_func
def add_macro(self, macro: ParsedMacro, ctx: Dict[str, Any]):
macro_name: str = macro.name
# MacroGenerator is in clients/jinja.py
# a MacroGenerator object is a callable object that will
# execute the MacroGenerator.__call__ function
macro_func: MacroGenerator = MacroGenerator(
macro, ctx, self.node, self.thread_ctx
)
@@ -122,10 +142,12 @@ class MacroNamespaceBuilder:
if macro.package_name in self.internal_package_names:
self._add_macro_to(self.internal_packages, macro, macro_func)
else:
# if it's not an internal package
self._add_macro_to(self.packages, macro, macro_func)
# add to locals if it's the package this node is in
if macro.package_name == self.search_package:
self.locals[macro_name] = macro_func
# add to globals if it's in the root package
elif macro.package_name == self.root_package:
self.globals[macro_name] = macro_func
@@ -143,6 +165,7 @@ class MacroNamespaceBuilder:
global_project_namespace: FlatNamespace = {}
for pkg in reversed(self.internal_package_names_order):
if pkg in self.internal_packages:
# add the macros pointed to by this package name
global_project_namespace.update(self.internal_packages[pkg])
return MacroNamespace(
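
Both namespace builders flatten the internal package namespaces by iterating the package list in reverse, so packages earlier in the list overwrite later ones and "win". A tiny demonstration of that trick (the package names below are only illustrative):

def flatten_with_priority(order, namespaces):
    """Earlier packages in `order` win, so iterate in reverse and overwrite."""
    flat = {}
    for pkg in reversed(order):
        flat.update(namespaces.get(pkg, {}))
    return flat

order = ["adapter_package", "core_package"]  # highest priority first
namespaces = {
    "core_package": {"current_timestamp": "core implementation"},
    "adapter_package": {"current_timestamp": "adapter implementation"},
}
print(flatten_with_priority(order, namespaces)["current_timestamp"])
# -> adapter implementation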

View File

@@ -2,7 +2,8 @@ from typing import List
from dbt.clients.jinja import MacroStack
from dbt.contracts.connection import AdapterRequiredConfig
from dbt.contracts.graph.manifest import Manifest
from dbt.contracts.graph.manifest import Manifest, AnyManifest
from dbt.context.macro_resolver import TestMacroNamespace
from .configured import ConfiguredContext
@@ -19,17 +20,25 @@ class ManifestContext(ConfiguredContext):
def __init__(
self,
config: AdapterRequiredConfig,
manifest: Manifest,
manifest: AnyManifest,
search_package: str,
) -> None:
super().__init__(config)
self.manifest = manifest
# this is the package of the node for which this context was built
self.search_package = search_package
self.macro_stack = MacroStack()
# This namespace is used by the BaseDatabaseWrapper in jinja rendering.
# The namespace is passed to it when it's constructed. It expects
# to be able to do: namespace.get_from_package(..)
self.namespace = self._build_namespace()
def _build_namespace(self):
# this takes all the macros in the manifest and adds them
# to the MacroNamespaceBuilder stored in self.namespace
builder = self._get_namespace_builder()
self.namespace = builder.build_namespace(
self.manifest.macros.values(),
self._ctx,
return builder.build_namespace(
self.manifest.macros.values(), self._ctx
)
def _get_namespace_builder(self) -> MacroNamespaceBuilder:
@@ -46,9 +55,15 @@ class ManifestContext(ConfiguredContext):
None,
)
# This does not use the Mashumaro code
def to_dict(self):
dct = super().to_dict()
dct.update(self.namespace)
# This moves all of the macros in the 'namespace' into top level
# keys in the manifest dictionary
if isinstance(self.namespace, TestMacroNamespace):
dct.update(self.namespace.local_namespace)
else:
dct.update(self.namespace)
return dct
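
to_dict above merges the macro namespace into the top-level keys of the context, so templates can call macros by bare name. Schematically (values are placeholders):

base_context = {"var": {"start_date": "2021-01-01"}}
macro_namespace = {"is_incremental": lambda: False}

context = dict(base_context)
context.update(macro_namespace)      # macros become top-level context entries
print(context["is_incremental"]())   # -> False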

View File

@@ -10,15 +10,18 @@ from dbt import deprecations
from dbt.adapters.base.column import Column
from dbt.adapters.factory import get_adapter, get_adapter_package_names
from dbt.clients import agate_helper
from dbt.clients.jinja import get_rendered, MacroGenerator
from dbt.clients.jinja import get_rendered, MacroGenerator, MacroStack
from dbt.config import RuntimeConfig, Project
from .base import contextmember, contextproperty, Var
from .configured import FQNLookup
from .context_config import ContextConfig
from dbt.context.macro_resolver import MacroResolver, TestMacroNamespace
from .macros import MacroNamespaceBuilder, MacroNamespace
from .manifest import ManifestContext
from dbt.contracts.graph.manifest import Manifest, Disabled
from dbt.contracts.connection import AdapterResponse
from dbt.contracts.graph.manifest import (
Manifest, AnyManifest, Disabled, MacroManifest
)
from dbt.contracts.graph.compiled import (
CompiledResource,
CompiledSeedNode,
@@ -141,6 +144,7 @@ class BaseDatabaseWrapper:
for prefix in self._get_adapter_macro_prefixes():
search_name = f'{prefix}__{macro_name}'
try:
# this uses the namespace from the context
macro = self._namespace.get_from_package(
package_name, search_name
)
@@ -638,10 +642,13 @@ class ProviderContext(ManifestContext):
self.context_config: Optional[ContextConfig] = context_config
self.provider: Provider = provider
self.adapter = get_adapter(self.config)
# The macro namespace is used in creating the DatabaseWrapper
self.db_wrapper = self.provider.DatabaseWrapper(
self.adapter, self.namespace
)
# This overrides the method in ManifestContext, and provides
# a model, which the ManifestContext builder does not
def _get_namespace_builder(self):
internal_packages = get_adapter_package_names(
self.config.credentials.type
@@ -1203,7 +1210,7 @@ class MacroContext(ProviderContext):
self,
model: ParsedMacro,
config: RuntimeConfig,
manifest: Manifest,
manifest: AnyManifest,
provider: Provider,
search_package: Optional[str],
) -> None:
@@ -1289,34 +1296,28 @@ class ModelContext(ProviderContext):
return self.db_wrapper.Relation.create_from(self.config, self.model)
# This is called by '_context_for', used in 'render_with_context'
def generate_parser_model(
model: ManifestNode,
config: RuntimeConfig,
manifest: Manifest,
manifest: MacroManifest,
context_config: ContextConfig,
) -> Dict[str, Any]:
# The __init__ method of ModelContext also initializes
# a ManifestContext object which creates a MacroNamespaceBuilder
# which adds every macro in the Manifest.
ctx = ModelContext(
model, config, manifest, ParseProvider(), context_config
)
return ctx.to_dict()
def generate_parser_macro(
macro: ParsedMacro,
config: RuntimeConfig,
manifest: Manifest,
package_name: Optional[str],
) -> Dict[str, Any]:
ctx = MacroContext(
macro, config, manifest, ParseProvider(), package_name
)
# The 'to_dict' method in ManifestContext moves all of the macro names
# in the macro 'namespace' up to top level keys
return ctx.to_dict()
def generate_generate_component_name_macro(
macro: ParsedMacro,
config: RuntimeConfig,
manifest: Manifest,
manifest: MacroManifest,
) -> Dict[str, Any]:
ctx = MacroContext(
macro, config, manifest, GenerateNameProvider(), None
@@ -1369,7 +1370,7 @@ class ExposureSourceResolver(BaseResolver):
def generate_parse_exposure(
exposure: ParsedExposure,
config: RuntimeConfig,
manifest: Manifest,
manifest: MacroManifest,
package_name: str,
) -> Dict[str, Any]:
project = config.load_dependencies()[package_name]
@@ -1387,3 +1388,57 @@ def generate_parse_exposure(
manifest,
)
}
# This class is currently used by the schema parser in order
# to limit the number of macros in the context by using
# the TestMacroNamespace
class TestContext(ProviderContext):
def __init__(
self,
model,
config: RuntimeConfig,
manifest: Manifest,
provider: Provider,
context_config: Optional[ContextConfig],
macro_resolver: MacroResolver,
) -> None:
# this must be before super init so that macro_resolver exists for
# build_namespace
self.macro_resolver = macro_resolver
self.thread_ctx = MacroStack()
super().__init__(model, config, manifest, provider, context_config)
self._build_test_namespace()
def _build_namespace(self):
return {}
# this overrides _build_namespace in ManifestContext which provides a
# complete namespace of all macros to only specify macros in the depends_on
# This only provides a namespace with macros in the test node
# 'depends_on.macros' by using the TestMacroNamespace
def _build_test_namespace(self):
depends_on_macros = []
if self.model.depends_on and self.model.depends_on.macros:
depends_on_macros = self.model.depends_on.macros
macro_namespace = TestMacroNamespace(
self.macro_resolver, self.ctx, self.node, self.thread_ctx,
depends_on_macros
)
self._namespace = macro_namespace
def generate_test_context(
model: ManifestNode,
config: RuntimeConfig,
manifest: Manifest,
context_config: ContextConfig,
macro_resolver: MacroResolver
) -> Dict[str, Any]:
ctx = TestContext(
model, config, manifest, ParseProvider(), context_config,
macro_resolver
)
# The 'to_dict' method in ManifestContext moves all of the macro names
# in the macro 'namespace' up to top level keys
return ctx.to_dict()

View File

@@ -2,28 +2,29 @@ import abc
import itertools
from dataclasses import dataclass, field
from typing import (
Any, ClassVar, Dict, Tuple, Iterable, Optional, NewType, List, Callable,
Any, ClassVar, Dict, Tuple, Iterable, Optional, List, Callable,
)
from typing_extensions import Protocol
from hologram import JsonSchemaMixin
from hologram.helpers import (
StrEnum, register_pattern, ExtensibleJsonSchemaMixin
)
from dbt.contracts.util import Replaceable
from dbt.exceptions import InternalException
from dbt.utils import translate_aliases
from dbt.logger import GLOBAL_LOGGER as logger
from typing_extensions import Protocol
from dbt.dataclass_schema import (
dbtClassMixin, StrEnum, ExtensibleDbtClassMixin,
ValidatedStringMixin, register_pattern
)
from dbt.contracts.util import Replaceable
Identifier = NewType('Identifier', str)
class Identifier(ValidatedStringMixin):
ValidationRegex = r'^[A-Za-z_][A-Za-z0-9_]+$'
# we need register_pattern for jsonschema validation
register_pattern(Identifier, r'^[A-Za-z_][A-Za-z0-9_]+$')
@dataclass
class AdapterResponse(JsonSchemaMixin):
class AdapterResponse(dbtClassMixin):
_message: str
code: Optional[str] = None
rows_affected: Optional[int] = None
@@ -40,20 +41,19 @@ class ConnectionState(StrEnum):
@dataclass(init=False)
class Connection(ExtensibleJsonSchemaMixin, Replaceable):
class Connection(ExtensibleDbtClassMixin, Replaceable):
type: Identifier
name: Optional[str]
name: Optional[str] = None
state: ConnectionState = ConnectionState.INIT
transaction_open: bool = False
# prevent serialization
_handle: Optional[Any] = None
_credentials: JsonSchemaMixin = field(init=False)
_credentials: Optional[Any] = None
def __init__(
self,
type: Identifier,
name: Optional[str],
credentials: JsonSchemaMixin,
credentials: dbtClassMixin,
state: ConnectionState = ConnectionState.INIT,
transaction_open: bool = False,
handle: Optional[Any] = None,
@@ -113,7 +113,7 @@ class LazyHandle:
# will work.
@dataclass # type: ignore
class Credentials(
ExtensibleJsonSchemaMixin,
ExtensibleDbtClassMixin,
Replaceable,
metaclass=abc.ABCMeta
):
@@ -132,7 +132,7 @@ class Credentials(
) -> Iterable[Tuple[str, Any]]:
"""Return an ordered iterator of key/value pairs for pretty-printing.
"""
as_dict = self.to_dict(omit_none=False, with_aliases=with_aliases)
as_dict = self.to_dict(options={'keep_none': True})
connection_keys = set(self._connection_keys())
aliases: List[str] = []
if with_aliases:
@@ -148,9 +148,10 @@ class Credentials(
raise NotImplementedError
@classmethod
def from_dict(cls, data):
def __pre_deserialize__(cls, data, options=None):
data = super().__pre_deserialize__(data, options=options)
data = cls.translate_aliases(data)
return super().from_dict(data)
return data
@classmethod
def translate_aliases(
@@ -158,31 +159,26 @@ class Credentials(
) -> Dict[str, Any]:
return translate_aliases(kwargs, cls._ALIASES, recurse)
def to_dict(self, omit_none=True, validate=False, *, with_aliases=False):
serialized = super().to_dict(omit_none=omit_none, validate=validate)
if with_aliases:
serialized.update({
new_name: serialized[canonical_name]
def __post_serialize__(self, dct, options=None):
# no super() -- do we need it?
if self._ALIASES:
dct.update({
new_name: dct[canonical_name]
for new_name, canonical_name in self._ALIASES.items()
if canonical_name in serialized
if canonical_name in dct
})
return serialized
return dct
class UserConfigContract(Protocol):
send_anonymous_usage_stats: bool
use_colors: Optional[bool]
partial_parse: Optional[bool]
printer_width: Optional[int]
use_colors: Optional[bool] = None
partial_parse: Optional[bool] = None
printer_width: Optional[int] = None
def set_values(self, cookie_dir: str) -> None:
...
def to_dict(
self, omit_none: bool = True, validate: bool = False
) -> Dict[str, Any]:
...
class HasCredentials(Protocol):
credentials: Credentials
@@ -216,7 +212,7 @@ DEFAULT_QUERY_COMMENT = '''
@dataclass
class QueryComment(JsonSchemaMixin):
class QueryComment(dbtClassMixin):
comment: str = DEFAULT_QUERY_COMMENT
append: bool = False
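
Credentials above now translates aliased keys on the way in (__pre_deserialize__) and re-adds the alias keys on the way out (__post_serialize__). A standalone sketch of that round trip without the mashumaro machinery; the alias mapping and field names are made up for illustration:

from dataclasses import dataclass, asdict
from typing import Any, Dict

_ALIASES = {"pass": "password"}  # alias name -> canonical field name

@dataclass
class CredentialsSketch:
    user: str
    password: str

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "CredentialsSketch":
        # translate aliases before constructing (the __pre_deserialize__ step)
        translated = {_ALIASES.get(k, k): v for k, v in data.items()}
        return cls(**translated)

    def to_dict(self) -> Dict[str, Any]:
        # add alias keys alongside canonical ones (the __post_serialize__ step)
        dct = asdict(self)
        dct.update({
            alias: dct[canonical]
            for alias, canonical in _ALIASES.items()
            if canonical in dct
        })
        return dct

creds = CredentialsSketch.from_dict({"user": "dbt", "pass": "secret"})
print(creds.to_dict())
# -> {'user': 'dbt', 'password': 'secret', 'pass': 'secret'}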

View File

@@ -3,7 +3,7 @@ import os
from dataclasses import dataclass, field
from typing import List, Optional, Union
from hologram import JsonSchemaMixin
from dbt.dataclass_schema import dbtClassMixin
from dbt.exceptions import InternalException
@@ -15,7 +15,7 @@ MAXIMUM_SEED_SIZE_NAME = '1MB'
@dataclass
class FilePath(JsonSchemaMixin):
class FilePath(dbtClassMixin):
searched_path: str
relative_path: str
project_root: str
@@ -51,7 +51,7 @@ class FilePath(JsonSchemaMixin):
@dataclass
class FileHash(JsonSchemaMixin):
class FileHash(dbtClassMixin):
name: str # the hash type name
checksum: str # the hashlib.hash_type().hexdigest() of the file contents
@@ -91,7 +91,7 @@ class FileHash(JsonSchemaMixin):
@dataclass
class RemoteFile(JsonSchemaMixin):
class RemoteFile(dbtClassMixin):
@property
def searched_path(self) -> str:
return 'from remote system'
@@ -110,7 +110,7 @@ class RemoteFile(JsonSchemaMixin):
@dataclass
class SourceFile(JsonSchemaMixin):
class SourceFile(dbtClassMixin):
"""Define a source file in dbt"""
path: Union[FilePath, RemoteFile] # the path information
checksum: FileHash

View File

@@ -19,19 +19,19 @@ from dbt.contracts.graph.parsed import (
from dbt.node_types import NodeType
from dbt.contracts.util import Replaceable
from hologram import JsonSchemaMixin
from dbt.dataclass_schema import dbtClassMixin
from dataclasses import dataclass, field
from typing import Optional, List, Union, Dict, Type
@dataclass
class InjectedCTE(JsonSchemaMixin, Replaceable):
class InjectedCTE(dbtClassMixin, Replaceable):
id: str
sql: str
@dataclass
class CompiledNodeMixin(JsonSchemaMixin):
class CompiledNodeMixin(dbtClassMixin):
# this is a special mixin class to provide a required argument. If a node
# is missing a `compiled` flag entirely, it must not be a CompiledNode.
compiled: bool
@@ -178,8 +178,7 @@ def parsed_instance_for(compiled: CompiledNode) -> ParsedResource:
raise ValueError('invalid resource_type: {}'
.format(compiled.resource_type))
# validate=False to allow extra keys from compiling
return cls.from_dict(compiled.to_dict(), validate=False)
return cls.from_dict(compiled.to_dict())
NonSourceCompiledNode = Union[

View File

@@ -428,158 +428,13 @@ def _update_into(dest: MutableMapping[str, T], new_item: T):
dest[unique_id] = new_item
@dataclass
class Manifest:
"""The manifest for the full graph, after parsing and during compilation.
"""
# These attributes are both positional and by keyword. If an attribute
# is added it must all be added in the __reduce_ex__ method in the
# args tuple in the right position.
nodes: MutableMapping[str, ManifestNode]
sources: MutableMapping[str, ParsedSourceDefinition]
macros: MutableMapping[str, ParsedMacro]
docs: MutableMapping[str, ParsedDocumentation]
exposures: MutableMapping[str, ParsedExposure]
selectors: MutableMapping[str, Any]
disabled: List[CompileResultNode]
files: MutableMapping[str, SourceFile]
metadata: ManifestMetadata = field(default_factory=ManifestMetadata)
flat_graph: Dict[str, Any] = field(default_factory=dict)
_docs_cache: Optional[DocCache] = None
_sources_cache: Optional[SourceCache] = None
_refs_cache: Optional[RefableCache] = None
_lock: Lock = field(default_factory=flags.MP_CONTEXT.Lock)
@classmethod
def from_macros(
cls,
macros: Optional[MutableMapping[str, ParsedMacro]] = None,
files: Optional[MutableMapping[str, SourceFile]] = None,
) -> 'Manifest':
if macros is None:
macros = {}
if files is None:
files = {}
return cls(
nodes={},
sources={},
macros=macros,
docs={},
exposures={},
selectors={},
disabled=[],
files=files,
)
def sync_update_node(
self, new_node: NonSourceCompiledNode
) -> NonSourceCompiledNode:
"""update the node with a lock. The only time we should want to lock is
when compiling an ephemeral ancestor of a node at runtime, because
multiple threads could be just-in-time compiling the same ephemeral
dependency, and we want them to have a consistent view of the manifest.
If the existing node is not compiled, update it with the new node and
return that. If the existing node is compiled, do not update the
manifest and return the existing node.
"""
with self._lock:
existing = self.nodes[new_node.unique_id]
if getattr(existing, 'compiled', False):
# already compiled -> must be a NonSourceCompiledNode
return cast(NonSourceCompiledNode, existing)
_update_into(self.nodes, new_node)
return new_node
def update_exposure(self, new_exposure: ParsedExposure):
_update_into(self.exposures, new_exposure)
def update_node(self, new_node: ManifestNode):
_update_into(self.nodes, new_node)
def update_source(self, new_source: ParsedSourceDefinition):
_update_into(self.sources, new_source)
def build_flat_graph(self):
"""This attribute is used in context.common by each node, so we want to
only build it once and avoid any concurrency issues around it.
Make sure you don't call this until you're done with building your
manifest!
"""
self.flat_graph = {
'nodes': {
k: v.to_dict(omit_none=False) for k, v in self.nodes.items()
},
'sources': {
k: v.to_dict(omit_none=False) for k, v in self.sources.items()
}
}
def find_disabled_by_name(
self, name: str, package: Optional[str] = None
) -> Optional[ManifestNode]:
searcher: NameSearcher = NameSearcher(
name, package, NodeType.refable()
)
result = searcher.search(self.disabled)
return result
def find_disabled_source_by_name(
self, source_name: str, table_name: str, package: Optional[str] = None
) -> Optional[ParsedSourceDefinition]:
search_name = f'{source_name}.{table_name}'
searcher: NameSearcher = NameSearcher(
search_name, package, [NodeType.Source]
)
result = searcher.search(self.disabled)
if result is not None:
assert isinstance(result, ParsedSourceDefinition)
return result
def _find_macros_by_name(
self,
name: str,
root_project_name: str,
filter: Optional[Callable[[MacroCandidate], bool]] = None
) -> CandidateList:
"""Find macros by their name.
"""
# avoid an import cycle
from dbt.adapters.factory import get_adapter_package_names
candidates: CandidateList = CandidateList()
packages = set(get_adapter_package_names(self.metadata.adapter_type))
for unique_id, macro in self.macros.items():
if macro.name != name:
continue
candidate = MacroCandidate(
locality=_get_locality(macro, root_project_name, packages),
macro=macro,
)
if filter is None or filter(candidate):
candidates.append(candidate)
return candidates
def _materialization_candidates_for(
self, project_name: str,
materialization_name: str,
adapter_type: Optional[str],
) -> CandidateList:
if adapter_type is None:
specificity = Specificity.Default
else:
specificity = Specificity.Adapter
full_name = dbt.utils.get_materialization_macro_name(
materialization_name=materialization_name,
adapter_type=adapter_type,
with_prefix=False,
)
return CandidateList(
MaterializationCandidate.from_macro(m, specificity)
for m in self._find_macros_by_name(full_name, project_name)
)
# This contains macro methods that are in both the Manifest
# and the MacroManifest
class MacroMethods:
# Just to make mypy happy. There must be a better way.
def __init__(self):
self.macros = []
self.metadata = {}
def find_macro_by_name(
self, name: str, root_project_name: str, package: Optional[str]
@@ -625,6 +480,141 @@ class Manifest:
)
return candidates.last()
def _find_macros_by_name(
self,
name: str,
root_project_name: str,
filter: Optional[Callable[[MacroCandidate], bool]] = None
) -> CandidateList:
"""Find macros by their name.
"""
# avoid an import cycle
from dbt.adapters.factory import get_adapter_package_names
candidates: CandidateList = CandidateList()
packages = set(get_adapter_package_names(self.metadata.adapter_type))
for unique_id, macro in self.macros.items():
if macro.name != name:
continue
candidate = MacroCandidate(
locality=_get_locality(macro, root_project_name, packages),
macro=macro,
)
if filter is None or filter(candidate):
candidates.append(candidate)
return candidates
@dataclass
class Manifest(MacroMethods):
"""The manifest for the full graph, after parsing and during compilation.
"""
# These attributes are both positional and by keyword. If an attribute
# is added it must all be added in the __reduce_ex__ method in the
# args tuple in the right position.
nodes: MutableMapping[str, ManifestNode]
sources: MutableMapping[str, ParsedSourceDefinition]
macros: MutableMapping[str, ParsedMacro]
docs: MutableMapping[str, ParsedDocumentation]
exposures: MutableMapping[str, ParsedExposure]
selectors: MutableMapping[str, Any]
disabled: List[CompileResultNode]
files: MutableMapping[str, SourceFile]
metadata: ManifestMetadata = field(default_factory=ManifestMetadata)
flat_graph: Dict[str, Any] = field(default_factory=dict)
_docs_cache: Optional[DocCache] = None
_sources_cache: Optional[SourceCache] = None
_refs_cache: Optional[RefableCache] = None
_lock: Lock = field(default_factory=flags.MP_CONTEXT.Lock)
def sync_update_node(
self, new_node: NonSourceCompiledNode
) -> NonSourceCompiledNode:
"""update the node with a lock. The only time we should want to lock is
when compiling an ephemeral ancestor of a node at runtime, because
multiple threads could be just-in-time compiling the same ephemeral
dependency, and we want them to have a consistent view of the manifest.
If the existing node is not compiled, update it with the new node and
return that. If the existing node is compiled, do not update the
manifest and return the existing node.
"""
with self._lock:
existing = self.nodes[new_node.unique_id]
if getattr(existing, 'compiled', False):
# already compiled -> must be a NonSourceCompiledNode
return cast(NonSourceCompiledNode, existing)
_update_into(self.nodes, new_node)
return new_node
def update_exposure(self, new_exposure: ParsedExposure):
_update_into(self.exposures, new_exposure)
def update_node(self, new_node: ManifestNode):
_update_into(self.nodes, new_node)
def update_source(self, new_source: ParsedSourceDefinition):
_update_into(self.sources, new_source)
def build_flat_graph(self):
"""This attribute is used in context.common by each node, so we want to
only build it once and avoid any concurrency issues around it.
Make sure you don't call this until you're done with building your
manifest!
"""
self.flat_graph = {
'nodes': {
k: v.to_dict(options={'keep_none': True})
for k, v in self.nodes.items()
},
'sources': {
k: v.to_dict(options={'keep_none': True})
for k, v in self.sources.items()
}
}
def find_disabled_by_name(
self, name: str, package: Optional[str] = None
) -> Optional[ManifestNode]:
searcher: NameSearcher = NameSearcher(
name, package, NodeType.refable()
)
result = searcher.search(self.disabled)
return result
def find_disabled_source_by_name(
self, source_name: str, table_name: str, package: Optional[str] = None
) -> Optional[ParsedSourceDefinition]:
search_name = f'{source_name}.{table_name}'
searcher: NameSearcher = NameSearcher(
search_name, package, [NodeType.Source]
)
result = searcher.search(self.disabled)
if result is not None:
assert isinstance(result, ParsedSourceDefinition)
return result
def _materialization_candidates_for(
self, project_name: str,
materialization_name: str,
adapter_type: Optional[str],
) -> CandidateList:
if adapter_type is None:
specificity = Specificity.Default
else:
specificity = Specificity.Adapter
full_name = dbt.utils.get_materialization_macro_name(
materialization_name=materialization_name,
adapter_type=adapter_type,
with_prefix=False,
)
return CandidateList(
MaterializationCandidate.from_macro(m, specificity)
for m in self._find_macros_by_name(full_name, project_name)
)
def find_materialization_macro_by_name(
self, project_name: str, materialization_name: str, adapter_type: str
) -> Optional[ParsedMacro]:
@@ -763,10 +753,10 @@ class Manifest:
parent_map=backward_edges,
)
def to_dict(self, omit_none=True, validate=False):
return self.writable_manifest().to_dict(
omit_none=omit_none, validate=validate
)
# When 'to_dict' is called on the Manifest, it substitutes a
# WritableManifest
def __pre_serialize__(self, options=None):
return self.writable_manifest()
def write(self, path):
self.writable_manifest().write(path)
@@ -944,6 +934,19 @@ class Manifest:
return self.__class__, args
class MacroManifest(MacroMethods):
def __init__(self, macros, files):
self.macros = macros
self.files = files
self.metadata = ManifestMetadata()
# This is returned by the 'graph' context property
# in the ProviderContext class.
self.flat_graph = {}
AnyManifest = Union[Manifest, MacroManifest]
@dataclass
@schema_version('manifest', 1)
class WritableManifest(ArtifactMixin):

View File

@@ -2,19 +2,12 @@ from dataclasses import field, Field, dataclass
from enum import Enum
from itertools import chain
from typing import (
Any, List, Optional, Dict, MutableMapping, Union, Type, NewType, Tuple,
TypeVar, Callable, cast, Hashable
Any, List, Optional, Dict, MutableMapping, Union, Type,
TypeVar, Callable,
)
from dbt.dataclass_schema import (
dbtClassMixin, ValidationError, register_pattern,
)
# TODO: patch+upgrade hologram to avoid this jsonschema import
import jsonschema # type: ignore
# This is protected, but we really do want to reuse this logic, and the cache!
# It would be nice to move the custom error picking stuff into hologram!
from hologram import _validate_schema
from hologram import JsonSchemaMixin, ValidationError
from hologram.helpers import StrEnum, register_pattern
from dbt.contracts.graph.unparsed import AdditionalPropertiesAllowed
from dbt.exceptions import CompilationException, InternalException
from dbt.contracts.util import Replaceable, list_str
@@ -170,22 +163,15 @@ def insensitive_patterns(*patterns: str):
return '^({})$'.format('|'.join(lowercased))
Severity = NewType('Severity', str)
class Severity(str):
pass
register_pattern(Severity, insensitive_patterns('warn', 'error'))
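A rough standalone illustration of the case-insensitive severity pattern registered above; the per-character expansion below is an assumption about what insensitive_patterns builds from its arguments, not a copy of it.
import re

def insensitive(*words):
    # expand each word into per-character [Xx] classes, anchor, and OR together
    alternatives = (''.join(f'[{c.upper()}{c.lower()}]' for c in w) for w in words)
    return '^({})$'.format('|'.join(alternatives))

severity_re = re.compile(insensitive('warn', 'error'))
assert severity_re.match('WARN') and severity_re.match('Error')
assert severity_re.match('fatal') is None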
class SnapshotStrategy(StrEnum):
Timestamp = 'timestamp'
Check = 'check'
class All(StrEnum):
All = 'all'
@dataclass
class Hook(JsonSchemaMixin, Replaceable):
class Hook(dbtClassMixin, Replaceable):
sql: str
transaction: bool = True
index: Optional[int] = None
@@ -313,29 +299,6 @@ class BaseConfig(
)
return result
def to_dict(
self,
omit_none: bool = True,
validate: bool = False,
*,
omit_hidden: bool = True,
) -> Dict[str, Any]:
result = super().to_dict(omit_none=omit_none, validate=validate)
if omit_hidden and not omit_none:
for fld, target_field in self._get_fields():
if target_field not in result:
continue
# if the field is not None, preserve it regardless of the
# setting. This is in line with existing behavior, but isn't
# an endorsement of it!
if result[target_field] is not None:
continue
if not ShowBehavior.should_show(fld):
del result[target_field]
return result
def update_from(
self: T, data: Dict[str, Any], adapter_type: str, validate: bool = True
) -> T:
@@ -344,7 +307,7 @@ class BaseConfig(
"""
# sadly, this is a circular import
from dbt.adapters.factory import get_config_class_by_name
dct = self.to_dict(omit_none=False, validate=False, omit_hidden=False)
dct = self.to_dict(options={'keep_none': True})
adapter_config_cls = get_config_class_by_name(adapter_type)
@@ -358,21 +321,23 @@ class BaseConfig(
dct.update(data)
# any validation failures must have come from the update
return self.from_dict(dct, validate=validate)
if validate:
self.validate(dct)
return self.from_dict(dct)
def finalize_and_validate(self: T) -> T:
# from_dict will validate for us
dct = self.to_dict(omit_none=False, validate=False)
dct = self.to_dict(options={'keep_none': True})
self.validate(dct)
return self.from_dict(dct)
def replace(self, **kwargs):
dct = self.to_dict(validate=False)
dct = self.to_dict()
mapping = self.field_mapping()
for key, value in kwargs.items():
new_key = mapping.get(key, key)
dct[new_key] = value
return self.from_dict(dct, validate=False)
return self.from_dict(dct)
@dataclass
@@ -431,12 +396,33 @@ class NodeConfig(BaseConfig):
full_refresh: Optional[bool] = None
@classmethod
def from_dict(cls, data, validate=True):
def __pre_deserialize__(cls, data, options=None):
data = super().__pre_deserialize__(data, options=options)
field_map = {'post-hook': 'post_hook', 'pre-hook': 'pre_hook'}
# create a new dict because otherwise it gets overwritten in
# tests
new_dict = {}
for key in data:
new_dict[key] = data[key]
data = new_dict
for key in hooks.ModelHookType:
if key in data:
data[key] = [hooks.get_hook_dict(h) for h in data[key]]
return super().from_dict(data, validate=validate)
for field_name in field_map:
if field_name in data:
new_name = field_map[field_name]
data[new_name] = data.pop(field_name)
return data
def __post_serialize__(self, dct, options=None):
dct = super().__post_serialize__(dct, options=options)
field_map = {'post_hook': 'post-hook', 'pre_hook': 'pre-hook'}
for field_name in field_map:
if field_name in dct:
dct[field_map[field_name]] = dct.pop(field_name)
return dct
# this is still used by jsonschema validation
@classmethod
def field_mapping(cls):
return {'post_hook': 'post-hook', 'pre_hook': 'pre-hook'}
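The pre-hook/post-hook handling above boils down to a symmetric key rename on the way in and on the way out; here is a small sketch of that round trip on plain dicts (illustrative values only).
TO_PYTHON = {'post-hook': 'post_hook', 'pre-hook': 'pre_hook'}
TO_YAML = {v: k for k, v in TO_PYTHON.items()}

def rename_keys(data, mapping):
    # build a new dict so the caller's input is untouched, as __pre_deserialize__ does
    return {mapping.get(key, key): value for key, value in data.items()}

raw = {'materialized': 'table', 'post-hook': ["grant select on {{ this }} to role reporter"]}
parsed = rename_keys(raw, TO_PYTHON)    # keys become post_hook / pre_hook
assert rename_keys(parsed, TO_YAML) == raw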
@@ -454,182 +440,49 @@ class TestConfig(NodeConfig):
severity: Severity = Severity('ERROR')
SnapshotVariants = Union[
'TimestampSnapshotConfig',
'CheckSnapshotConfig',
'GenericSnapshotConfig',
]
def _relevance_without_strategy(error: jsonschema.ValidationError):
# calculate the 'relevance' of an error the normal jsonschema way, except
# if the validator is in the 'strategy' field and it's conflicting with the
# 'enum'. This suppresses `"'timestamp' is not one of ['check']` and such
if 'strategy' in error.path and error.validator in {'enum', 'not'}:
length = 1
else:
length = -len(error.path)
validator = error.validator
return length, validator not in {'anyOf', 'oneOf'}
@dataclass
class SnapshotWrapper(JsonSchemaMixin):
"""This is a little wrapper to let us serialize/deserialize the
SnapshotVariants union.
"""
config: SnapshotVariants # mypy: ignore
@classmethod
def validate(cls, data: Any):
config = data.get('config', {})
if config.get('strategy') == 'check':
schema = _validate_schema(CheckSnapshotConfig)
to_validate = config
elif config.get('strategy') == 'timestamp':
schema = _validate_schema(TimestampSnapshotConfig)
to_validate = config
else:
h_cls = cast(Hashable, cls)
schema = _validate_schema(h_cls)
to_validate = data
validator = jsonschema.Draft7Validator(schema)
error = jsonschema.exceptions.best_match(
validator.iter_errors(to_validate),
key=_relevance_without_strategy,
)
if error is not None:
raise ValidationError.create_from(error) from error
@dataclass
class EmptySnapshotConfig(NodeConfig):
materialized: str = 'snapshot'
@dataclass(init=False)
@dataclass
class SnapshotConfig(EmptySnapshotConfig):
unique_key: str = field(init=False, metadata=dict(init_required=True))
target_schema: str = field(init=False, metadata=dict(init_required=True))
strategy: Optional[str] = None
unique_key: Optional[str] = None
target_schema: Optional[str] = None
target_database: Optional[str] = None
updated_at: Optional[str] = None
check_cols: Optional[Union[str, List[str]]] = None
def __init__(
self,
unique_key: str,
target_schema: str,
target_database: Optional[str] = None,
**kwargs
) -> None:
self.unique_key = unique_key
self.target_schema = target_schema
self.target_database = target_database
# kwargs['materialized'] = materialized
super().__init__(**kwargs)
# type hacks...
@classmethod
def _get_fields(cls) -> List[Tuple[Field, str]]: # type: ignore
fields: List[Tuple[Field, str]] = []
for old_field, name in super()._get_fields():
new_field = old_field
# tell hologram we're really an initvar
if old_field.metadata and old_field.metadata.get('init_required'):
new_field = field(init=True, metadata=old_field.metadata)
new_field.name = old_field.name
new_field.type = old_field.type
new_field._field_type = old_field._field_type # type: ignore
fields.append((new_field, name))
return fields
def validate(cls, data):
super().validate(data)
if data.get('strategy') == 'check':
if not data.get('check_cols'):
raise ValidationError(
"A snapshot configured with the check strategy must "
"specify a check_cols configuration.")
if (isinstance(data['check_cols'], str) and
data['check_cols'] != 'all'):
raise ValidationError(
f"Invalid value for 'check_cols': {data['check_cols']}. "
"Expected 'all' or a list of strings.")
def finalize_and_validate(self: 'SnapshotConfig') -> SnapshotVariants:
elif data.get('strategy') == 'timestamp':
if not data.get('updated_at'):
raise ValidationError(
"A snapshot configured with the timestamp strategy "
"must specify an updated_at configuration.")
if data.get('check_cols'):
raise ValidationError(
"A 'timestamp' snapshot should not have 'check_cols'")
# If the strategy is not 'check' or 'timestamp' it's a custom strategy,
# formerly supported with GenericSnapshotConfig
def finalize_and_validate(self):
data = self.to_dict()
return SnapshotWrapper.from_dict({'config': data}).config
@dataclass(init=False)
class GenericSnapshotConfig(SnapshotConfig):
strategy: str = field(init=False, metadata=dict(init_required=True))
def __init__(self, strategy: str, **kwargs) -> None:
self.strategy = strategy
super().__init__(**kwargs)
@classmethod
def _collect_json_schema(
cls, definitions: Dict[str, Any]
) -> Dict[str, Any]:
# this is the method you want to override in hologram if you want
# to do clever things about the json schema and have classes that
# contain instances of your JsonSchemaMixin respect the change.
schema = super()._collect_json_schema(definitions)
# Instead of just the strategy we'd calculate normally, say
# "this strategy except none of our specialization strategies".
strategies = [schema['properties']['strategy']]
for specialization in (TimestampSnapshotConfig, CheckSnapshotConfig):
strategies.append(
{'not': specialization.json_schema()['properties']['strategy']}
)
schema['properties']['strategy'] = {
'allOf': strategies
}
return schema
@dataclass(init=False)
class TimestampSnapshotConfig(SnapshotConfig):
strategy: str = field(
init=False,
metadata=dict(
restrict=[str(SnapshotStrategy.Timestamp)],
init_required=True,
),
)
updated_at: str = field(init=False, metadata=dict(init_required=True))
def __init__(
self, strategy: str, updated_at: str, **kwargs
) -> None:
self.strategy = strategy
self.updated_at = updated_at
super().__init__(**kwargs)
@dataclass(init=False)
class CheckSnapshotConfig(SnapshotConfig):
strategy: str = field(
init=False,
metadata=dict(
restrict=[str(SnapshotStrategy.Check)],
init_required=True,
),
)
# TODO: is there a way to get this to accept tuples of strings? Adding
# `Tuple[str, ...]` to the list of types results in this:
# ['email'] is valid under each of {'type': 'array', 'items':
# {'type': 'string'}}, {'type': 'array', 'items': {'type': 'string'}}
# but without it, parsing gets upset about values like `('email',)`
# maybe hologram itself should support this behavior? It's not like tuples
# are meaningful in json
check_cols: Union[All, List[str]] = field(
init=False,
metadata=dict(init_required=True),
)
def __init__(
self, strategy: str, check_cols: Union[All, List[str]],
**kwargs
) -> None:
self.strategy = strategy
self.check_cols = check_cols
super().__init__(**kwargs)
self.validate(data)
return self.from_dict(data)
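The strategy checks added to SnapshotConfig.validate read as a small rule table: 'check' requires check_cols ('all' or a list), 'timestamp' requires updated_at and forbids check_cols, and anything else is treated as a custom strategy. A standalone sketch of those rules on plain dicts (hypothetical function name):
def check_snapshot_config(data: dict) -> None:
    strategy = data.get('strategy')
    if strategy == 'check':
        check_cols = data.get('check_cols')
        if not check_cols:
            raise ValueError("the check strategy requires a check_cols configuration")
        if isinstance(check_cols, str) and check_cols != 'all':
            raise ValueError(f"invalid check_cols {check_cols!r}: expected 'all' or a list")
    elif strategy == 'timestamp':
        if not data.get('updated_at'):
            raise ValueError("the timestamp strategy requires an updated_at configuration")
        if data.get('check_cols'):
            raise ValueError("a timestamp snapshot should not set check_cols")
    # any other strategy is treated as a custom strategy and passes through

check_snapshot_config({'strategy': 'timestamp', 'updated_at': 'updated_at'})
check_snapshot_config({'strategy': 'check', 'check_cols': ['email']})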
RESOURCE_TYPES: Dict[NodeType, Type[BaseConfig]] = {

View File

@@ -13,8 +13,9 @@ from typing import (
TypeVar,
)
from hologram import JsonSchemaMixin
from hologram.helpers import ExtensibleJsonSchemaMixin
from dbt.dataclass_schema import (
dbtClassMixin, ExtensibleDbtClassMixin
)
from dbt.clients.system import write_file
from dbt.contracts.files import FileHash, MAXIMUM_SEED_SIZE_NAME
@@ -38,20 +39,14 @@ from .model_config import (
TestConfig,
SourceConfig,
EmptySnapshotConfig,
SnapshotVariants,
)
# import these 3 so the SnapshotVariants forward ref works.
from .model_config import ( # noqa
TimestampSnapshotConfig,
CheckSnapshotConfig,
GenericSnapshotConfig,
SnapshotConfig,
)
@dataclass
class ColumnInfo(
AdditionalPropertiesMixin,
ExtensibleJsonSchemaMixin,
ExtensibleDbtClassMixin,
Replaceable
):
name: str
@@ -64,7 +59,7 @@ class ColumnInfo(
@dataclass
class HasFqn(JsonSchemaMixin, Replaceable):
class HasFqn(dbtClassMixin, Replaceable):
fqn: List[str]
def same_fqn(self, other: 'HasFqn') -> bool:
@@ -72,12 +67,12 @@ class HasFqn(JsonSchemaMixin, Replaceable):
@dataclass
class HasUniqueID(JsonSchemaMixin, Replaceable):
class HasUniqueID(dbtClassMixin, Replaceable):
unique_id: str
@dataclass
class MacroDependsOn(JsonSchemaMixin, Replaceable):
class MacroDependsOn(dbtClassMixin, Replaceable):
macros: List[str] = field(default_factory=list)
# 'in' on lists is O(n) so this is O(n^2) for # of macros
@@ -96,12 +91,22 @@ class DependsOn(MacroDependsOn):
@dataclass
class HasRelationMetadata(JsonSchemaMixin, Replaceable):
class HasRelationMetadata(dbtClassMixin, Replaceable):
database: Optional[str]
schema: str
# database ought to default to None, but that breaks the subclasses
# and their default parameters, so hack it here instead
@classmethod
def __pre_deserialize__(cls, data, options=None):
data = super().__pre_deserialize__(data, options=options)
if 'database' not in data:
data['database'] = None
return data
class ParsedNodeMixins(JsonSchemaMixin):
class ParsedNodeMixins(dbtClassMixin):
resource_type: NodeType
depends_on: DependsOn
config: NodeConfig
@@ -132,8 +137,12 @@ class ParsedNodeMixins(JsonSchemaMixin):
self.meta = patch.meta
self.docs = patch.docs
if flags.STRICT_MODE:
assert isinstance(self, JsonSchemaMixin)
self.to_dict(validate=True, omit_none=False)
# It seems odd that an instance can be invalid
# Maybe there should be validation or restrictions
# elsewhere?
assert isinstance(self, dbtClassMixin)
dct = self.to_dict(options={'keep_none': True})
self.validate(dct)
def get_materialization(self):
return self.config.materialized
@@ -335,14 +344,14 @@ class ParsedSeedNode(ParsedNode):
@dataclass
class TestMetadata(JsonSchemaMixin, Replaceable):
namespace: Optional[str]
class TestMetadata(dbtClassMixin, Replaceable):
name: str
kwargs: Dict[str, Any]
kwargs: Dict[str, Any] = field(default_factory=dict)
namespace: Optional[str] = None
@dataclass
class HasTestMetadata(JsonSchemaMixin):
class HasTestMetadata(dbtClassMixin):
test_metadata: TestMetadata
@@ -394,7 +403,7 @@ class IntermediateSnapshotNode(ParsedNode):
@dataclass
class ParsedSnapshotNode(ParsedNode):
resource_type: NodeType = field(metadata={'restrict': [NodeType.Snapshot]})
config: SnapshotVariants
config: SnapshotConfig
@dataclass
@@ -443,8 +452,10 @@ class ParsedMacro(UnparsedBaseNode, HasUniqueID):
self.docs = patch.docs
self.arguments = patch.arguments
if flags.STRICT_MODE:
assert isinstance(self, JsonSchemaMixin)
self.to_dict(validate=True, omit_none=False)
# What does this actually validate?
assert isinstance(self, dbtClassMixin)
dct = self.to_dict(options={'keep_none': True})
self.validate(dct)
def same_contents(self, other: Optional['ParsedMacro']) -> bool:
if other is None:

View File

@@ -8,8 +8,9 @@ from dbt.contracts.util import (
import dbt.helper_types # noqa:F401
from dbt.exceptions import CompilationException
from hologram import JsonSchemaMixin
from hologram.helpers import StrEnum, ExtensibleJsonSchemaMixin
from dbt.dataclass_schema import (
dbtClassMixin, StrEnum, ExtensibleDbtClassMixin
)
from dataclasses import dataclass, field
from datetime import timedelta
@@ -18,7 +19,7 @@ from typing import Optional, List, Union, Dict, Any, Sequence
@dataclass
class UnparsedBaseNode(JsonSchemaMixin, Replaceable):
class UnparsedBaseNode(dbtClassMixin, Replaceable):
package_name: str
root_path: str
path: str
@@ -66,12 +67,12 @@ class UnparsedRunHook(UnparsedNode):
@dataclass
class Docs(JsonSchemaMixin, Replaceable):
class Docs(dbtClassMixin, Replaceable):
show: bool = True
@dataclass
class HasDocs(AdditionalPropertiesMixin, ExtensibleJsonSchemaMixin,
class HasDocs(AdditionalPropertiesMixin, ExtensibleDbtClassMixin,
Replaceable):
name: str
description: str = ''
@@ -100,7 +101,7 @@ class UnparsedColumn(HasTests):
@dataclass
class HasColumnDocs(JsonSchemaMixin, Replaceable):
class HasColumnDocs(dbtClassMixin, Replaceable):
columns: Sequence[HasDocs] = field(default_factory=list)
@@ -110,7 +111,7 @@ class HasColumnTests(HasColumnDocs):
@dataclass
class HasYamlMetadata(JsonSchemaMixin):
class HasYamlMetadata(dbtClassMixin):
original_file_path: str
yaml_key: str
package_name: str
@@ -127,7 +128,7 @@ class UnparsedNodeUpdate(HasColumnTests, HasTests, HasYamlMetadata):
@dataclass
class MacroArgument(JsonSchemaMixin):
class MacroArgument(dbtClassMixin):
name: str
type: Optional[str] = None
description: str = ''
@@ -148,7 +149,7 @@ class TimePeriod(StrEnum):
@dataclass
class Time(JsonSchemaMixin, Replaceable):
class Time(dbtClassMixin, Replaceable):
count: int
period: TimePeriod
@@ -159,7 +160,7 @@ class Time(JsonSchemaMixin, Replaceable):
@dataclass
class FreshnessThreshold(JsonSchemaMixin, Mergeable):
class FreshnessThreshold(dbtClassMixin, Mergeable):
warn_after: Optional[Time] = None
error_after: Optional[Time] = None
filter: Optional[str] = None
@@ -180,7 +181,7 @@ class FreshnessThreshold(JsonSchemaMixin, Mergeable):
@dataclass
class AdditionalPropertiesAllowed(
AdditionalPropertiesMixin,
ExtensibleJsonSchemaMixin
ExtensibleDbtClassMixin
):
_extra: Dict[str, Any] = field(default_factory=dict)
@@ -212,7 +213,7 @@ class ExternalTable(AdditionalPropertiesAllowed, Mergeable):
@dataclass
class Quoting(JsonSchemaMixin, Mergeable):
class Quoting(dbtClassMixin, Mergeable):
database: Optional[bool] = None
schema: Optional[bool] = None
identifier: Optional[bool] = None
@@ -230,15 +231,18 @@ class UnparsedSourceTableDefinition(HasColumnTests, HasTests):
external: Optional[ExternalTable] = None
tags: List[str] = field(default_factory=list)
def to_dict(self, omit_none=True, validate=False):
result = super().to_dict(omit_none=omit_none, validate=validate)
if omit_none and self.freshness is None:
result['freshness'] = None
return result
def __post_serialize__(self, dct, options=None):
dct = super().__post_serialize__(dct)
keep_none = False
if options and 'keep_none' in options and options['keep_none']:
keep_none = True
if not keep_none and self.freshness is None:
dct['freshness'] = None
return dct
@dataclass
class UnparsedSourceDefinition(JsonSchemaMixin, Replaceable):
class UnparsedSourceDefinition(dbtClassMixin, Replaceable):
name: str
description: str = ''
meta: Dict[str, Any] = field(default_factory=dict)
@@ -257,15 +261,18 @@ class UnparsedSourceDefinition(JsonSchemaMixin, Replaceable):
def yaml_key(self) -> 'str':
return 'sources'
def to_dict(self, omit_none=True, validate=False):
result = super().to_dict(omit_none=omit_none, validate=validate)
if omit_none and self.freshness is None:
result['freshness'] = None
return result
def __post_serialize__(self, dct, options=None):
dct = super().__post_serialize__(dct)
keep_none = False
if options and 'keep_none' in options and options['keep_none']:
keep_none = True
if not keep_none and self.freshness is None:
dct['freshness'] = None
return dct
@dataclass
class SourceTablePatch(JsonSchemaMixin):
class SourceTablePatch(dbtClassMixin):
name: str
description: Optional[str] = None
meta: Optional[Dict[str, Any]] = None
@@ -283,7 +290,7 @@ class SourceTablePatch(JsonSchemaMixin):
columns: Optional[Sequence[UnparsedColumn]] = None
def to_patch_dict(self) -> Dict[str, Any]:
dct = self.to_dict(omit_none=True)
dct = self.to_dict()
remove_keys = ('name')
for key in remove_keys:
if key in dct:
@@ -296,7 +303,7 @@ class SourceTablePatch(JsonSchemaMixin):
@dataclass
class SourcePatch(JsonSchemaMixin, Replaceable):
class SourcePatch(dbtClassMixin, Replaceable):
name: str = field(
metadata=dict(description='The name of the source to override'),
)
@@ -320,7 +327,7 @@ class SourcePatch(JsonSchemaMixin, Replaceable):
tags: Optional[List[str]] = None
def to_patch_dict(self) -> Dict[str, Any]:
dct = self.to_dict(omit_none=True)
dct = self.to_dict()
remove_keys = ('name', 'overrides', 'tables', 'path')
for key in remove_keys:
if key in dct:
@@ -340,7 +347,7 @@ class SourcePatch(JsonSchemaMixin, Replaceable):
@dataclass
class UnparsedDocumentation(JsonSchemaMixin, Replaceable):
class UnparsedDocumentation(dbtClassMixin, Replaceable):
package_name: str
root_path: str
path: str
@@ -400,13 +407,13 @@ class MaturityType(StrEnum):
@dataclass
class ExposureOwner(JsonSchemaMixin, Replaceable):
class ExposureOwner(dbtClassMixin, Replaceable):
email: str
name: Optional[str] = None
@dataclass
class UnparsedExposure(JsonSchemaMixin, Replaceable):
class UnparsedExposure(dbtClassMixin, Replaceable):
name: str
type: ExposureType
owner: ExposureOwner

View File

@@ -4,25 +4,39 @@ from dbt.helper_types import NoValue
from dbt.logger import GLOBAL_LOGGER as logger # noqa
from dbt import tracking
from dbt import ui
from hologram import JsonSchemaMixin, ValidationError
from hologram.helpers import HyphenatedJsonSchemaMixin, register_pattern, \
ExtensibleJsonSchemaMixin
from dbt.dataclass_schema import (
dbtClassMixin, ValidationError,
HyphenatedDbtClassMixin,
ExtensibleDbtClassMixin,
register_pattern, ValidatedStringMixin
)
from dataclasses import dataclass, field
from typing import Optional, List, Dict, Union, Any, NewType
from typing import Optional, List, Dict, Union, Any
from mashumaro.types import SerializableType
PIN_PACKAGE_URL = 'https://docs.getdbt.com/docs/package-management#section-specifying-package-versions' # noqa
DEFAULT_SEND_ANONYMOUS_USAGE_STATS = True
Name = NewType('Name', str)
class Name(ValidatedStringMixin):
ValidationRegex = r'^[^\d\W]\w*$'
register_pattern(Name, r'^[^\d\W]\w*$')
class SemverString(str, SerializableType):
def _serialize(self) -> str:
return self
@classmethod
def _deserialize(cls, value: str) -> 'SemverString':
return SemverString(value)
# This does not support full semver (it does not allow a trailing -fooXYZ)
# and is not restrictive enough for full semver either (it allows '1.0').
# It's 'semver lite'.
SemverString = NewType('SemverString', str)
register_pattern(
SemverString,
r'^(?:0|[1-9]\d*)\.(?:0|[1-9]\d*)(\.(?:0|[1-9]\d*))?$',
@@ -30,15 +44,15 @@ register_pattern(
@dataclass
class Quoting(JsonSchemaMixin, Mergeable):
identifier: Optional[bool]
schema: Optional[bool]
database: Optional[bool]
project: Optional[bool]
class Quoting(dbtClassMixin, Mergeable):
schema: Optional[bool] = None
database: Optional[bool] = None
project: Optional[bool] = None
identifier: Optional[bool] = None
@dataclass
class Package(Replaceable, HyphenatedJsonSchemaMixin):
class Package(Replaceable, HyphenatedDbtClassMixin):
pass
@@ -54,7 +68,7 @@ RawVersion = Union[str, float]
@dataclass
class GitPackage(Package):
git: str
revision: Optional[RawVersion]
revision: Optional[RawVersion] = None
warn_unpinned: Optional[bool] = None
def get_revisions(self) -> List[str]:
@@ -80,7 +94,7 @@ PackageSpec = Union[LocalPackage, GitPackage, RegistryPackage]
@dataclass
class PackageConfig(JsonSchemaMixin, Replaceable):
class PackageConfig(dbtClassMixin, Replaceable):
packages: List[PackageSpec]
@@ -96,13 +110,13 @@ class ProjectPackageMetadata:
@dataclass
class Downloads(ExtensibleJsonSchemaMixin, Replaceable):
class Downloads(ExtensibleDbtClassMixin, Replaceable):
tarball: str
@dataclass
class RegistryPackageMetadata(
ExtensibleJsonSchemaMixin,
ExtensibleDbtClassMixin,
ProjectPackageMetadata,
):
downloads: Downloads
@@ -154,7 +168,7 @@ BANNED_PROJECT_NAMES = {
@dataclass
class Project(HyphenatedJsonSchemaMixin, Replaceable):
class Project(HyphenatedDbtClassMixin, Replaceable):
name: Name
version: Union[SemverString, float]
config_version: int
@@ -191,18 +205,16 @@ class Project(HyphenatedJsonSchemaMixin, Replaceable):
query_comment: Optional[Union[QueryComment, NoValue, str]] = NoValue()
@classmethod
def from_dict(cls, data, validate=True) -> 'Project':
result = super().from_dict(data, validate=validate)
if result.name in BANNED_PROJECT_NAMES:
def validate(cls, data):
super().validate(data)
if data['name'] in BANNED_PROJECT_NAMES:
raise ValidationError(
f'Invalid project name: {result.name} is a reserved word'
f"Invalid project name: {data['name']} is a reserved word"
)
return result
@dataclass
class UserConfig(ExtensibleJsonSchemaMixin, Replaceable, UserConfigContract):
class UserConfig(ExtensibleDbtClassMixin, Replaceable, UserConfigContract):
send_anonymous_usage_stats: bool = DEFAULT_SEND_ANONYMOUS_USAGE_STATS
use_colors: Optional[bool] = None
partial_parse: Optional[bool] = None
@@ -222,7 +234,7 @@ class UserConfig(ExtensibleJsonSchemaMixin, Replaceable, UserConfigContract):
@dataclass
class ProfileConfig(HyphenatedJsonSchemaMixin, Replaceable):
class ProfileConfig(HyphenatedDbtClassMixin, Replaceable):
profile_name: str = field(metadata={'preserve_underscore': True})
target_name: str = field(metadata={'preserve_underscore': True})
config: UserConfig
@@ -233,10 +245,10 @@ class ProfileConfig(HyphenatedJsonSchemaMixin, Replaceable):
@dataclass
class ConfiguredQuoting(Quoting, Replaceable):
identifier: bool
schema: bool
database: Optional[bool]
project: Optional[bool]
identifier: bool = True
schema: bool = True
database: Optional[bool] = None
project: Optional[bool] = None
@dataclass
@@ -249,5 +261,5 @@ class Configuration(Project, ProfileConfig):
@dataclass
class ProjectList(JsonSchemaMixin):
class ProjectList(dbtClassMixin):
projects: Dict[str, Project]

View File

@@ -1,12 +1,11 @@
from collections.abc import Mapping
from dataclasses import dataclass, fields
from typing import (
Optional, TypeVar, Generic, Dict,
Optional, Dict,
)
from typing_extensions import Protocol
from hologram import JsonSchemaMixin
from hologram.helpers import StrEnum
from dbt.dataclass_schema import dbtClassMixin, StrEnum
from dbt import deprecations
from dbt.contracts.util import Replaceable
@@ -32,7 +31,7 @@ class HasQuoting(Protocol):
quoting: Dict[str, bool]
class FakeAPIObject(JsonSchemaMixin, Replaceable, Mapping):
class FakeAPIObject(dbtClassMixin, Replaceable, Mapping):
# override the mapping truthiness, len is always >1
def __bool__(self):
return True
@@ -58,16 +57,13 @@ class FakeAPIObject(JsonSchemaMixin, Replaceable, Mapping):
return self.from_dict(value)
T = TypeVar('T')
@dataclass
class _ComponentObject(FakeAPIObject, Generic[T]):
database: T
schema: T
identifier: T
class Policy(FakeAPIObject):
database: bool = True
schema: bool = True
identifier: bool = True
def get_part(self, key: ComponentName) -> T:
def get_part(self, key: ComponentName) -> bool:
if key == ComponentName.Database:
return self.database
elif key == ComponentName.Schema:
@@ -80,25 +76,18 @@ class _ComponentObject(FakeAPIObject, Generic[T]):
.format(key, list(ComponentName))
)
def replace_dict(self, dct: Dict[ComponentName, T]):
kwargs: Dict[str, T] = {}
def replace_dict(self, dct: Dict[ComponentName, bool]):
kwargs: Dict[str, bool] = {}
for k, v in dct.items():
kwargs[str(k)] = v
return self.replace(**kwargs)
@dataclass
class Policy(_ComponentObject[bool]):
database: bool = True
schema: bool = True
identifier: bool = True
@dataclass
class Path(_ComponentObject[Optional[str]]):
database: Optional[str]
schema: Optional[str]
identifier: Optional[str]
class Path(FakeAPIObject):
database: Optional[str] = None
schema: Optional[str] = None
identifier: Optional[str] = None
def __post_init__(self):
# handle pesky jinja2.Undefined sneaking in here and messing up rendering
@@ -120,3 +109,22 @@ class Path(_ComponentObject[Optional[str]]):
if part is not None:
part = part.lower()
return part
def get_part(self, key: ComponentName) -> Optional[str]:
if key == ComponentName.Database:
return self.database
elif key == ComponentName.Schema:
return self.schema
elif key == ComponentName.Identifier:
return self.identifier
else:
raise ValueError(
'Got a key of {}, expected one of {}'
.format(key, list(ComponentName))
)
def replace_dict(self, dct: Dict[ComponentName, str]):
kwargs: Dict[str, str] = {}
for k, v in dct.items():
kwargs[str(k)] = v
return self.replace(**kwargs)

View File

@@ -17,20 +17,21 @@ from dbt.logger import (
GLOBAL_LOGGER as logger,
)
from dbt.utils import lowercase
from hologram.helpers import StrEnum
from hologram import JsonSchemaMixin
from dbt.dataclass_schema import dbtClassMixin, StrEnum
import agate
from dataclasses import dataclass, field
from datetime import datetime
from typing import Union, Dict, List, Optional, Any, NamedTuple, Sequence
from typing import (
Union, Dict, List, Optional, Any, NamedTuple, Sequence,
)
from dbt.clients.system import write_json
@dataclass
class TimingInfo(JsonSchemaMixin):
class TimingInfo(dbtClassMixin):
name: str
started_at: Optional[datetime] = None
completed_at: Optional[datetime] = None
@@ -87,13 +88,20 @@ class FreshnessStatus(StrEnum):
@dataclass
class BaseResult(JsonSchemaMixin):
class BaseResult(dbtClassMixin):
status: Union[RunStatus, TestStatus, FreshnessStatus]
timing: List[TimingInfo]
thread_id: str
execution_time: float
message: Optional[Union[str, int]]
adapter_response: Dict[str, Any]
message: Optional[Union[str, int]]
@classmethod
def __pre_deserialize__(cls, data, options=None):
data = super().__pre_deserialize__(data, options=options)
if 'message' not in data:
data['message'] = None
return data
@dataclass
@@ -103,7 +111,11 @@ class NodeResult(BaseResult):
@dataclass
class RunResult(NodeResult):
agate_table: Optional[agate.Table] = None
agate_table: Optional[agate.Table] = field(
default=None, metadata={
'serialize': lambda x: None, 'deserialize': lambda x: None
}
)
@property
def skipped(self):
@@ -111,7 +123,7 @@ class RunResult(NodeResult):
@dataclass
class ExecutionResult(JsonSchemaMixin):
class ExecutionResult(dbtClassMixin):
results: Sequence[BaseResult]
elapsed_time: float
@@ -193,8 +205,8 @@ class RunResultsArtifact(ExecutionResult, ArtifactMixin):
args=args
)
def write(self, path: str, omit_none=False):
write_json(path, self.to_dict(omit_none=omit_none))
def write(self, path: str):
write_json(path, self.to_dict(options={'keep_none': True}))
@dataclass
@@ -253,14 +265,14 @@ class FreshnessErrorEnum(StrEnum):
@dataclass
class SourceFreshnessRuntimeError(JsonSchemaMixin):
class SourceFreshnessRuntimeError(dbtClassMixin):
unique_id: str
error: Optional[Union[str, int]]
status: FreshnessErrorEnum
@dataclass
class SourceFreshnessOutput(JsonSchemaMixin):
class SourceFreshnessOutput(dbtClassMixin):
unique_id: str
max_loaded_at: datetime
snapshotted_at: datetime
@@ -374,40 +386,40 @@ CatalogKey = NamedTuple(
@dataclass
class StatsItem(JsonSchemaMixin):
class StatsItem(dbtClassMixin):
id: str
label: str
value: Primitive
description: Optional[str]
include: bool
description: Optional[str] = None
StatsDict = Dict[str, StatsItem]
@dataclass
class ColumnMetadata(JsonSchemaMixin):
class ColumnMetadata(dbtClassMixin):
type: str
comment: Optional[str]
index: int
name: str
comment: Optional[str] = None
ColumnMap = Dict[str, ColumnMetadata]
@dataclass
class TableMetadata(JsonSchemaMixin):
class TableMetadata(dbtClassMixin):
type: str
database: Optional[str]
schema: str
name: str
comment: Optional[str]
owner: Optional[str]
database: Optional[str] = None
comment: Optional[str] = None
owner: Optional[str] = None
@dataclass
class CatalogTable(JsonSchemaMixin, Replaceable):
class CatalogTable(dbtClassMixin, Replaceable):
metadata: TableMetadata
columns: ColumnMap
stats: StatsDict
@@ -430,12 +442,18 @@ class CatalogMetadata(BaseArtifactMetadata):
@dataclass
class CatalogResults(JsonSchemaMixin):
class CatalogResults(dbtClassMixin):
nodes: Dict[str, CatalogTable]
sources: Dict[str, CatalogTable]
errors: Optional[List[str]]
errors: Optional[List[str]] = None
_compile_results: Optional[Any] = None
def __post_serialize__(self, dct, options=None):
dct = super().__post_serialize__(dct, options=options)
if '_compile_results' in dct:
del dct['_compile_results']
return dct
@dataclass
@schema_version('catalog', 1)

View File

@@ -5,8 +5,7 @@ from dataclasses import dataclass, field
from datetime import datetime, timedelta
from typing import Optional, Union, List, Any, Dict, Type, Sequence
from hologram import JsonSchemaMixin
from hologram.helpers import StrEnum
from dbt.dataclass_schema import dbtClassMixin, StrEnum
from dbt.contracts.graph.compiled import CompileResultNode
from dbt.contracts.graph.manifest import WritableManifest
@@ -34,16 +33,25 @@ TaskID = uuid.UUID
@dataclass
class RPCParameters(JsonSchemaMixin):
timeout: Optional[float]
class RPCParameters(dbtClassMixin):
task_tags: TaskTags
timeout: Optional[float]
@classmethod
def __pre_deserialize__(cls, data, options=None):
data = super().__pre_deserialize__(data, options=options)
if 'timeout' not in data:
data['timeout'] = None
if 'task_tags' not in data:
data['task_tags'] = None
return data
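The __pre_deserialize__ above exists because timeout and task_tags cannot be given dataclass defaults without colliding with required fields added by subclasses, so missing keys are backfilled instead. A minimal standalone sketch of that backfill (hypothetical names):
from dataclasses import dataclass
from typing import Any, Dict, Optional

@dataclass
class FakeParams:
    task_tags: Optional[Dict[str, Any]]
    timeout: Optional[float]
    name: str  # a subclass-style required field; defaults on the fields above would break it

def backfill(data: dict) -> dict:
    # mirror the hook: absent optional keys become explicit Nones before construction
    for key in ('task_tags', 'timeout'):
        data.setdefault(key, None)
    return data

print(FakeParams(**backfill({'name': 'run_sql'})))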
@dataclass
class RPCExecParameters(RPCParameters):
name: str
sql: str
macros: Optional[str]
macros: Optional[str] = None
@dataclass
@@ -132,7 +140,7 @@ class StatusParameters(RPCParameters):
@dataclass
class GCSettings(JsonSchemaMixin):
class GCSettings(dbtClassMixin):
# start evicting the longest-ago-ended tasks here
maxsize: int
# start evicting all tasks before now - auto_reap_age when we have this
@@ -254,7 +262,7 @@ class RemoteExecutionResult(ExecutionResult, RemoteResult):
@dataclass
class ResultTable(JsonSchemaMixin):
class ResultTable(dbtClassMixin):
column_names: List[str]
rows: List[Any]
@@ -411,21 +419,31 @@ class TaskHandlerState(StrEnum):
@dataclass
class TaskTiming(JsonSchemaMixin):
class TaskTiming(dbtClassMixin):
state: TaskHandlerState
start: Optional[datetime]
end: Optional[datetime]
elapsed: Optional[float]
# These ought to be defaults but superclass order doesn't
# allow that to work
@classmethod
def __pre_deserialize__(cls, data, options=None):
data = super().__pre_deserialize__(data, options=options)
for field_name in ('start', 'end', 'elapsed'):
if field_name not in data:
data[field_name] = None
return data
@dataclass
class TaskRow(TaskTiming):
task_id: TaskID
request_id: Union[str, int]
request_source: str
method: str
timeout: Optional[float]
tags: TaskTags
request_id: Union[str, int]
tags: TaskTags = None
timeout: Optional[float] = None
@dataclass
@@ -451,7 +469,7 @@ class KillResult(RemoteResult):
@dataclass
@schema_version('remote-manifest-result', 1)
class GetManifestResult(RemoteResult):
manifest: Optional[WritableManifest]
manifest: Optional[WritableManifest] = None
# this is kind of carefully structured: BlocksManifestTasks is implied by
@@ -475,6 +493,16 @@ class PollResult(RemoteResult, TaskTiming):
end: Optional[datetime]
elapsed: Optional[float]
# These ought to be defaults but superclass order doesn't
# allow that to work
@classmethod
def __pre_deserialize__(cls, data, options=None):
data = super().__pre_deserialize__(data, options=options)
for field_name in ('start', 'end', 'elapsed'):
if field_name not in data:
data[field_name] = None
return data
@dataclass
@schema_version('poll-remote-deps-result', 1)

View File

@@ -1,18 +1,18 @@
from dataclasses import dataclass
from hologram import JsonSchemaMixin
from dbt.dataclass_schema import dbtClassMixin
from typing import List, Dict, Any, Union
@dataclass
class SelectorDefinition(JsonSchemaMixin):
class SelectorDefinition(dbtClassMixin):
name: str
definition: Union[str, Dict[str, Any]]
description: str = ''
@dataclass
class SelectorFile(JsonSchemaMixin):
class SelectorFile(dbtClassMixin):
selectors: List[SelectorDefinition]
version: int = 2

View File

@@ -7,13 +7,12 @@ from typing import (
from dbt.clients.system import write_json, read_json
from dbt.exceptions import (
IncompatibleSchemaException,
InternalException,
RuntimeException,
)
from dbt.version import __version__
from dbt.tracking import get_invocation_id
from hologram import JsonSchemaMixin
from dbt.dataclass_schema import dbtClassMixin
MacroKey = Tuple[str, str]
SourceKey = Tuple[str, str]
@@ -57,8 +56,10 @@ class Mergeable(Replaceable):
class Writable:
def write(self, path: str, omit_none: bool = False):
write_json(path, self.to_dict(omit_none=omit_none)) # type: ignore
def write(self, path: str):
write_json(
path, self.to_dict(options={'keep_none': True}) # type: ignore
)
class AdditionalPropertiesMixin:
@@ -69,22 +70,41 @@ class AdditionalPropertiesMixin:
"""
ADDITIONAL_PROPERTIES = True
# This takes attributes in the dictionary that are
# not in the class definitions and puts them in an
# _extra dict in the class
@classmethod
def from_dict(cls, data, validate=True):
self = super().from_dict(data=data, validate=validate)
keys = self.to_dict(validate=False, omit_none=False)
def __pre_deserialize__(cls, data, options=None):
# dir() did not work because fields with
# metadata settings are not found
# The original version of this would create the
# object first and then update extra with the
# extra keys, but that won't work here, so
# we're copying the dict so we don't insert the
# _extra in the original data. This also requires
# that Mashumaro actually build the '_extra' field
cls_keys = cls._get_field_names()
new_dict = {}
for key, value in data.items():
if key not in keys:
self.extra[key] = value
return self
if key not in cls_keys and key != '_extra':
if '_extra' not in new_dict:
new_dict['_extra'] = {}
new_dict['_extra'][key] = value
else:
new_dict[key] = value
data = new_dict
data = super().__pre_deserialize__(data, options=options)
return data
def to_dict(self, omit_none=True, validate=False):
data = super().to_dict(omit_none=omit_none, validate=validate)
def __post_serialize__(self, dct, options=None):
data = super().__post_serialize__(dct, options=options)
data.update(self.extra)
if '_extra' in data:
del data['_extra']
return data
def replace(self, **kwargs):
dct = self.to_dict(omit_none=False, validate=False)
dct = self.to_dict(options={'keep_none': True})
dct.update(kwargs)
return self.from_dict(dct)
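The reworked AdditionalPropertiesMixin shuffles unknown keys into an _extra bucket before deserialization and merges them back out after serialization. A dict-only sketch of that shuffle, assuming a fixed set of known field names:
KNOWN_FIELDS = {'name', 'description'}

def stash_extra(data: dict) -> dict:
    # unknown keys move under '_extra'; known keys pass through unchanged
    out, extra = {}, {}
    for key, value in data.items():
        (out if key in KNOWN_FIELDS else extra)[key] = value
    if extra:
        out['_extra'] = extra
    return out

def merge_extra(dct: dict) -> dict:
    # on the way back out, flatten '_extra' into the top level again
    out = {key: value for key, value in dct.items() if key != '_extra'}
    out.update(dct.get('_extra', {}))
    return out

raw = {'name': 'events', 'description': '', 'loader': 'fivetran'}
assert merge_extra(stash_extra(raw)) == raw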
@@ -135,7 +155,7 @@ def get_metadata_env() -> Dict[str, str]:
@dataclasses.dataclass
class BaseArtifactMetadata(JsonSchemaMixin):
class BaseArtifactMetadata(dbtClassMixin):
dbt_schema_version: str
dbt_version: str = __version__
generated_at: datetime = dataclasses.field(
@@ -158,7 +178,7 @@ def schema_version(name: str, version: int):
@dataclasses.dataclass
class VersionedSchema(JsonSchemaMixin):
class VersionedSchema(dbtClassMixin):
dbt_schema_version: ClassVar[SchemaVersion]
@classmethod
@@ -180,18 +200,9 @@ class ArtifactMixin(VersionedSchema, Writable, Readable):
metadata: BaseArtifactMetadata
@classmethod
def from_dict(
cls: Type[T], data: Dict[str, Any], validate: bool = True
) -> T:
def validate(cls, data):
super().validate(data)
if cls.dbt_schema_version is None:
raise InternalException(
'Cannot call from_dict with no schema version!'
)
if validate:
expected = str(cls.dbt_schema_version)
found = data.get('metadata', {}).get(SCHEMA_VERSION_KEY)
if found != expected:
raise IncompatibleSchemaException(expected, found)
return super().from_dict(data=data, validate=validate)

View File

@@ -0,0 +1,170 @@
from typing import (
Type, ClassVar, Dict, cast, TypeVar
)
import re
from dataclasses import fields
from enum import Enum
from datetime import datetime
from dateutil.parser import parse
from hologram import JsonSchemaMixin, FieldEncoder, ValidationError
from mashumaro import DataClassDictMixin
from mashumaro.types import SerializableEncoder, SerializableType
class DateTimeSerializableEncoder(SerializableEncoder[datetime]):
@classmethod
def _serialize(cls, value: datetime) -> str:
out = value.isoformat()
# Assume UTC if timezone is missing
if value.tzinfo is None:
out = out + "Z"
return out
@classmethod
def _deserialize(cls, value: str) -> datetime:
return (
value if isinstance(value, datetime) else parse(cast(str, value))
)
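The encoder above writes datetimes with isoformat, appending 'Z' when the value is naive, and reads them back with dateutil. A standalone round-trip sketch of that convention (assumes python-dateutil is installed):
from datetime import datetime, timezone
from dateutil.parser import parse

def encode_dt(value: datetime) -> str:
    out = value.isoformat()
    # naive datetimes are assumed to be UTC
    return out + 'Z' if value.tzinfo is None else out

def decode_dt(value: str) -> datetime:
    return parse(value)

stamp = datetime(2021, 2, 9, 17, 43, 13)
encoded = encode_dt(stamp)    # '2021-02-09T17:43:13Z'
assert decode_dt(encoded) == stamp.replace(tzinfo=timezone.utc)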
TV = TypeVar("TV")
# This class pulls in both JsonSchemaMixin from Hologram and
# DataClassDictMixin from our fork of Mashumaro. The 'to_dict'
# and 'from_dict' methods come from Mashumaro. Building
# jsonschemas for every class and the 'validate' method
# come from Hologram.
class dbtClassMixin(DataClassDictMixin, JsonSchemaMixin):
"""Mixin which adds methods to generate a JSON schema and
convert to and from JSON encodable dicts with validation
against the schema
"""
_serializable_encoders: ClassVar[Dict[str, SerializableEncoder]] = {
'datetime.datetime': DateTimeSerializableEncoder(),
}
_hyphenated: ClassVar[bool] = False
ADDITIONAL_PROPERTIES: ClassVar[bool] = False
# This is called by the mashumaro to_dict in order to handle
# nested classes.
# Munges the dict that's returned.
def __post_serialize__(self, dct, options=None):
keep_none = False
if options and 'keep_none' in options and options['keep_none']:
keep_none = True
if not keep_none: # remove attributes that are None
new_dict = {k: v for k, v in dct.items() if v is not None}
dct = new_dict
if self._hyphenated:
new_dict = {}
for key in dct:
if '_' in key:
new_key = key.replace('_', '-')
new_dict[new_key] = dct[key]
else:
new_dict[key] = dct[key]
dct = new_dict
return dct
# This is called by the mashumaro _from_dict method, before
# performing the conversion from a dict
@classmethod
def __pre_deserialize__(cls, data, options=None):
if cls._hyphenated:
new_dict = {}
for key in data:
if '-' in key:
new_key = key.replace('-', '_')
new_dict[new_key] = data[key]
else:
new_dict[key] = data[key]
data = new_dict
return data
# This is used in the hologram._encode_field method, which calls
# a 'to_dict' method which does not have the same parameters in
# hologram and in mashumaro.
def _local_to_dict(self, **kwargs):
args = {}
if 'omit_none' in kwargs and kwargs['omit_none'] is False:
args['options'] = {'keep_none': True}
return self.to_dict(**args)
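The __post_serialize__ behaviour above reduces to two passes over the output dict: drop None values unless keep_none is requested, then swap underscores for hyphens on hyphenated classes. A standalone sketch of that munging on a plain dict:
def post_serialize(dct: dict, keep_none: bool = False, hyphenated: bool = False) -> dict:
    if not keep_none:
        # default behaviour: omit attributes that are None
        dct = {key: value for key, value in dct.items() if value is not None}
    if hyphenated:
        # hyphenated classes (e.g. project files) rewrite keys on the way out
        dct = {key.replace('_', '-'): value for key, value in dct.items()}
    return dct

row = {'config_version': 2, 'query_comment': None}
assert post_serialize(row) == {'config_version': 2}
assert post_serialize(row, keep_none=True, hyphenated=True) == {
    'config-version': 2, 'query-comment': None,
}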
class ValidatedStringMixin(str, SerializableType):
ValidationRegex = ''
@classmethod
def _deserialize(cls, value: str) -> 'ValidatedStringMixin':
cls.validate(value)
return ValidatedStringMixin(value)
def _serialize(self) -> str:
return str(self)
@classmethod
def validate(cls, value):
res = re.match(cls.ValidationRegex, value)
if res is None:
raise ValidationError(f"Invalid value: {value}") # TODO
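ValidatedStringMixin ties a regex to a string subtype so deserialization rejects bad values up front. A minimal standalone sketch (hypothetical PackageName type, plain ValueError in place of the ValidationError used above), reusing the Name regex from the project contracts:
import re

class PackageName(str):
    VALIDATION_REGEX = r'^[^\d\W]\w*$'

    @classmethod
    def deserialize(cls, value: str) -> 'PackageName':
        # reject values that don't match before constructing the subtype
        if re.match(cls.VALIDATION_REGEX, value) is None:
            raise ValueError(f'Invalid value: {value}')
        return cls(value)

assert PackageName.deserialize('dbt_utils') == 'dbt_utils'
try:
    PackageName.deserialize('1bad-name')
except ValueError as exc:
    print(exc)    # Invalid value: 1bad-name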
# These classes must be in this order or it doesn't work
class StrEnum(str, SerializableType, Enum):
def __str__(self):
return self.value
# https://docs.python.org/3.6/library/enum.html#using-automatic-values
def _generate_next_value_(name, *_):
return name
def _serialize(self) -> str:
return self.value
@classmethod
def _deserialize(cls, value: str):
return cls(value)
class HyphenatedDbtClassMixin(dbtClassMixin):
# used by from_dict/to_dict
_hyphenated: ClassVar[bool] = True
# used by jsonschema validation, _get_fields
@classmethod
def field_mapping(cls):
result = {}
for field in fields(cls):
skip = field.metadata.get("preserve_underscore")
if skip:
continue
if "_" in field.name:
result[field.name] = field.name.replace("_", "-")
return result
class ExtensibleDbtClassMixin(dbtClassMixin):
ADDITIONAL_PROPERTIES = True
# This is used by Hologram in jsonschema validation
def register_pattern(base_type: Type, pattern: str) -> None:
"""base_type should be a typing.NewType that should always have the given
regex pattern. That means that its underlying type ('__supertype__') had
better be a str!
"""
class PatternEncoder(FieldEncoder):
@property
def json_schema(self):
return {"type": "string", "pattern": pattern}
dbtClassMixin.register_field_encoders({base_type: PatternEncoder()})

View File

@@ -7,14 +7,14 @@ from dbt.node_types import NodeType
from dbt import flags
from dbt.ui import line_wrap_message
import hologram
import dbt.dataclass_schema
def validator_error_message(exc):
"""Given a hologram.ValidationError (which is basically a
"""Given a dbt.dataclass_schema.ValidationError (which is basically a
jsonschema.ValidationError), return the relevant parts as a string
"""
if not isinstance(exc, hologram.ValidationError):
if not isinstance(exc, dbt.dataclass_schema.ValidationError):
return str(exc)
path = "[%s]" % "][".join(map(repr, exc.relative_path))
return 'at path {}: {}'.format(path, exc.message)

View File

@@ -1,6 +1,6 @@
# special support for CLI argument parsing.
import itertools
import yaml
from dbt.clients.yaml_helper import yaml, Loader, Dumper # noqa: F401
from typing import (
Dict, List, Optional, Tuple, Any, Union
@@ -236,7 +236,7 @@ def parse_dict_definition(definition: Dict[str, Any]) -> SelectionSpec:
)
# if key isn't a valid method name, this will raise
base = SelectionCriteria.from_dict(definition, dct)
base = SelectionCriteria.selection_criteria_from_dict(definition, dct)
if diff_arg is None:
return base
else:

View File

@@ -3,7 +3,7 @@ from itertools import chain
from pathlib import Path
from typing import Set, List, Dict, Iterator, Tuple, Any, Union, Type, Optional
from hologram.helpers import StrEnum
from dbt.dataclass_schema import StrEnum
from .graph import UniqueId

View File

@@ -102,7 +102,9 @@ class SelectionCriteria:
return method_name, method_arguments
@classmethod
def from_dict(cls, raw: Any, dct: Dict[str, Any]) -> 'SelectionCriteria':
def selection_criteria_from_dict(
cls, raw: Any, dct: Dict[str, Any]
) -> 'SelectionCriteria':
if 'value' not in dct:
raise RuntimeException(
f'Invalid node spec "{raw}" - no search value!'
@@ -150,7 +152,7 @@ class SelectionCriteria:
# bad spec!
raise RuntimeException(f'Invalid selector spec "{raw}"')
return cls.from_dict(raw, result.groupdict())
return cls.selection_criteria_from_dict(raw, result.groupdict())
class BaseSelectionGroup(Iterable[SelectionSpec], metaclass=ABCMeta):

View File

@@ -2,14 +2,27 @@
from dataclasses import dataclass
from datetime import timedelta
from pathlib import Path
from typing import NewType, Tuple, AbstractSet
from typing import Tuple, AbstractSet, Union
from hologram import (
FieldEncoder, JsonSchemaMixin, JsonDict, ValidationError
from dbt.dataclass_schema import (
dbtClassMixin, ValidationError, StrEnum,
)
from hologram.helpers import StrEnum
from hologram import FieldEncoder, JsonDict
from mashumaro.types import SerializableType
Port = NewType('Port', int)
class Port(int, SerializableType):
@classmethod
def _deserialize(cls, value: Union[int, str]) -> 'Port':
try:
value = int(value)
except ValueError:
raise ValidationError(f'Cannot encode {value} into port number')
return Port(value)
def _serialize(self) -> int:
return self
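Port simply coerces str-or-int input to int and raises a validation error when that fails; a tiny standalone sketch of the same coercion (plain ValueError in place of ValidationError):
def to_port(value) -> int:
    try:
        return int(value)
    except ValueError:
        raise ValueError(f'Cannot encode {value} into port number')

assert to_port('5432') == 5432
assert to_port(8080) == 8080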
class PortEncoder(FieldEncoder):
@@ -66,12 +79,12 @@ class NVEnum(StrEnum):
@dataclass
class NoValue(JsonSchemaMixin):
class NoValue(dbtClassMixin):
"""Sometimes, you want a way to say none that isn't None"""
novalue: NVEnum = NVEnum.novalue
JsonSchemaMixin.register_field_encoders({
dbtClassMixin.register_field_encoders({
Port: PortEncoder(),
timedelta: TimeDeltaFieldEncoder(),
Path: PathEncoder(),

View File

@@ -1,4 +1,4 @@
from hologram.helpers import StrEnum
from dbt.dataclass_schema import StrEnum
import json
from typing import Union, Dict, Any

View File

@@ -13,7 +13,7 @@ from typing import Optional, List, ContextManager, Callable, Dict, Any, Set
import colorama
import logbook
from hologram import JsonSchemaMixin
from dbt.dataclass_schema import dbtClassMixin
# Colorama needs some help on windows because we're using logger.info
# instead of print(). If the Windows env doesn't have a TERM var set,
@@ -45,11 +45,10 @@ DEBUG_LOG_FORMAT = (
ExceptionInformation = str
Extras = Dict[str, Any]
@dataclass
class LogMessage(JsonSchemaMixin):
class LogMessage(dbtClassMixin):
timestamp: datetime
message: str
channel: str
@@ -57,7 +56,7 @@ class LogMessage(JsonSchemaMixin):
levelname: str
thread_name: str
process: int
extra: Optional[Extras] = None
extra: Optional[Dict[str, Any]] = None
exc_info: Optional[ExceptionInformation] = None
@classmethod
@@ -215,7 +214,7 @@ class TextOnly(logbook.Processor):
class TimingProcessor(logbook.Processor):
def __init__(self, timing_info: Optional[JsonSchemaMixin] = None):
def __init__(self, timing_info: Optional[dbtClassMixin] = None):
self.timing_info = timing_info
super().__init__()

View File

@@ -1,6 +1,6 @@
from typing import List
from hologram.helpers import StrEnum
from dbt.dataclass_schema import StrEnum
class NodeType(StrEnum):

View File

@@ -13,7 +13,9 @@ class AnalysisParser(SimpleSQLParser[ParsedAnalysisNode]):
)
def parse_from_dict(self, dct, validate=True) -> ParsedAnalysisNode:
return ParsedAnalysisNode.from_dict(dct, validate=validate)
if validate:
ParsedAnalysisNode.validate(dct)
return ParsedAnalysisNode.from_dict(dct)
@property
def resource_type(self) -> NodeType:

View File

@@ -5,7 +5,7 @@ from typing import (
List, Dict, Any, Iterable, Generic, TypeVar
)
from hologram import ValidationError
from dbt.dataclass_schema import ValidationError
from dbt import utils
from dbt.clients.jinja import MacroGenerator
@@ -23,7 +23,7 @@ from dbt.context.context_config import (
from dbt.contracts.files import (
SourceFile, FilePath, FileHash
)
from dbt.contracts.graph.manifest import Manifest
from dbt.contracts.graph.manifest import MacroManifest
from dbt.contracts.graph.parsed import HasUniqueID
from dbt.contracts.graph.unparsed import UnparsedNode
from dbt.exceptions import (
@@ -99,7 +99,7 @@ class Parser(BaseParser[FinalValue], Generic[FinalValue]):
results: ParseResult,
project: Project,
root_project: RuntimeConfig,
macro_manifest: Manifest,
macro_manifest: MacroManifest,
) -> None:
super().__init__(results, project)
self.root_project = root_project
@@ -108,9 +108,10 @@ class Parser(BaseParser[FinalValue], Generic[FinalValue]):
class RelationUpdate:
def __init__(
self, config: RuntimeConfig, manifest: Manifest, component: str
self, config: RuntimeConfig, macro_manifest: MacroManifest,
component: str
) -> None:
macro = manifest.find_generate_macro_by_name(
macro = macro_manifest.find_generate_macro_by_name(
component=component,
root_project_name=config.project_name,
)
@@ -120,7 +121,7 @@ class RelationUpdate:
)
root_context = generate_generate_component_name_macro(
macro, config, manifest
macro, config, macro_manifest
)
self.updater = MacroGenerator(macro, root_context)
self.component = component
@@ -144,18 +145,21 @@ class ConfiguredParser(
results: ParseResult,
project: Project,
root_project: RuntimeConfig,
macro_manifest: Manifest,
macro_manifest: MacroManifest,
) -> None:
super().__init__(results, project, root_project, macro_manifest)
self._update_node_database = RelationUpdate(
manifest=macro_manifest, config=root_project, component='database'
macro_manifest=macro_manifest, config=root_project,
component='database'
)
self._update_node_schema = RelationUpdate(
manifest=macro_manifest, config=root_project, component='schema'
macro_manifest=macro_manifest, config=root_project,
component='schema'
)
self._update_node_alias = RelationUpdate(
manifest=macro_manifest, config=root_project, component='alias'
macro_manifest=macro_manifest, config=root_project,
component='alias'
)
@abc.abstractclassmethod
@@ -252,7 +256,7 @@ class ConfiguredParser(
}
dct.update(kwargs)
try:
return self.parse_from_dict(dct)
return self.parse_from_dict(dct, validate=True)
except ValidationError as exc:
msg = validator_error_message(exc)
# this is a bit silly, but build an UnparsedNode just for error
@@ -275,20 +279,24 @@ class ConfiguredParser(
def render_with_context(
self, parsed_node: IntermediateNode, config: ContextConfig
) -> None:
"""Given the parsed node and a ContextConfig to use during parsing,
render the node's sql wtih macro capture enabled.
# Given the parsed node and a ContextConfig to use during parsing,
# render the node's sql wtih macro capture enabled.
# Note: this mutates the config object when config calls are rendered.
Note: this mutates the config object when config() calls are rendered.
"""
# during parsing, we don't have a connection, but we might need one, so
# we have to acquire it.
with get_adapter(self.root_project).connection_for(parsed_node):
context = self._context_for(parsed_node, config)
# this goes through the process of rendering, but just throws away
# the rendered result. The "macro capture" is the point?
get_rendered(
parsed_node.raw_sql, context, parsed_node, capture_macros=True
)
# This is taking the original config for the node, converting it to a dict,
# updating the config with new config passed in, then re-creating the
# config from the dict in the node.
def update_parsed_node_config(
self, parsed_node: IntermediateNode, config_dict: Dict[str, Any]
) -> None:

View File

@@ -12,7 +12,9 @@ class DataTestParser(SimpleSQLParser[ParsedDataTestNode]):
)
def parse_from_dict(self, dct, validate=True) -> ParsedDataTestNode:
return ParsedDataTestNode.from_dict(dct, validate=validate)
if validate:
ParsedDataTestNode.validate(dct)
return ParsedDataTestNode.from_dict(dct)
@property
def resource_type(self) -> NodeType:

View File

@@ -79,7 +79,9 @@ class HookParser(SimpleParser[HookBlock, ParsedHookNode]):
return [path]
def parse_from_dict(self, dct, validate=True) -> ParsedHookNode:
return ParsedHookNode.from_dict(dct, validate=validate)
if validate:
ParsedHookNode.validate(dct)
return ParsedHookNode.from_dict(dct)
@classmethod
def get_compiled_path(cls, block: HookBlock):

View File

@@ -23,7 +23,9 @@ from dbt.config import Project, RuntimeConfig
from dbt.context.docs import generate_runtime_docs
from dbt.contracts.files import FilePath, FileHash
from dbt.contracts.graph.compiled import ManifestNode
from dbt.contracts.graph.manifest import Manifest, Disabled
from dbt.contracts.graph.manifest import (
Manifest, MacroManifest, AnyManifest, Disabled
)
from dbt.contracts.graph.parsed import (
ParsedSourceDefinition, ParsedNode, ParsedMacro, ColumnInfo, ParsedExposure
)
@@ -51,7 +53,7 @@ from dbt.parser.sources import patch_sources
from dbt.ui import warning_tag
from dbt.version import __version__
from hologram import JsonSchemaMixin
from dbt.dataclass_schema import dbtClassMixin
PARTIAL_PARSE_FILE_NAME = 'partial_parse.pickle'
PARSING_STATE = DbtProcessState('parsing')
@@ -59,14 +61,14 @@ DEFAULT_PARTIAL_PARSE = False
@dataclass
class ParserInfo(JsonSchemaMixin):
class ParserInfo(dbtClassMixin):
parser: str
elapsed: float
path_count: int = 0
@dataclass
class ProjectLoaderInfo(JsonSchemaMixin):
class ProjectLoaderInfo(dbtClassMixin):
project_name: str
elapsed: float
parsers: List[ParserInfo]
@@ -74,7 +76,7 @@ class ProjectLoaderInfo(JsonSchemaMixin):
@dataclass
class ManifestLoaderInfo(JsonSchemaMixin, Writable):
class ManifestLoaderInfo(dbtClassMixin, Writable):
path_count: int = 0
is_partial_parse_enabled: Optional[bool] = None
parse_project_elapsed: Optional[float] = None
@@ -137,16 +139,19 @@ class ManifestLoader:
self,
root_project: RuntimeConfig,
all_projects: Mapping[str, Project],
macro_hook: Optional[Callable[[Manifest], Any]] = None,
macro_hook: Optional[Callable[[AnyManifest], Any]] = None,
) -> None:
self.root_project: RuntimeConfig = root_project
self.all_projects: Mapping[str, Project] = all_projects
self.macro_hook: Callable[[Manifest], Any]
self.macro_hook: Callable[[AnyManifest], Any]
if macro_hook is None:
self.macro_hook = lambda m: None
else:
self.macro_hook = macro_hook
# results holds all of the nodes created by parsing,
# in dictionaries: nodes, sources, docs, macros, exposures,
# macro_patches, patches, source_patches, files, etc
self.results: ParseResult = make_parse_result(
root_project, all_projects,
)
@@ -210,7 +215,7 @@ class ManifestLoader:
def parse_project(
self,
project: Project,
macro_manifest: Manifest,
macro_manifest: MacroManifest,
old_results: Optional[ParseResult],
) -> None:
parsers: List[Parser] = []
@@ -252,7 +257,7 @@ class ManifestLoader:
self._perf_info.path_count + total_path_count
)
def load_only_macros(self) -> Manifest:
def load_only_macros(self) -> MacroManifest:
old_results = self.read_parse_results()
for project in self.all_projects.values():
@@ -261,17 +266,20 @@ class ManifestLoader:
self.parse_with_cache(path, parser, old_results)
# make a manifest with just the macros to get the context
macro_manifest = Manifest.from_macros(
macro_manifest = MacroManifest(
macros=self.results.macros,
files=self.results.files
)
self.macro_hook(macro_manifest)
return macro_manifest
def load(self, macro_manifest: Manifest):
# This is where the main action happens
def load(self, macro_manifest: MacroManifest):
# if partial parse is enabled, load old results
old_results = self.read_parse_results()
if old_results is not None:
logger.debug('Got an acceptable cached parse result')
# store the macros & files from the adapter macro manifest
self.results.macros.update(macro_manifest.macros)
self.results.files.update(macro_manifest.files)
@@ -423,8 +431,8 @@ class ManifestLoader:
def load_all(
cls,
root_config: RuntimeConfig,
macro_manifest: Manifest,
macro_hook: Callable[[Manifest], Any],
macro_manifest: MacroManifest,
macro_hook: Callable[[AnyManifest], Any],
) -> Manifest:
with PARSING_STATE:
start_load_all = time.perf_counter()
@@ -449,8 +457,8 @@ class ManifestLoader:
def load_macros(
cls,
root_config: RuntimeConfig,
macro_hook: Callable[[Manifest], Any],
) -> Manifest:
macro_hook: Callable[[AnyManifest], Any],
) -> MacroManifest:
with PARSING_STATE:
projects = root_config.load_dependencies()
loader = cls(root_config, projects, macro_hook)
@@ -841,14 +849,14 @@ def process_node(
def load_macro_manifest(
config: RuntimeConfig,
macro_hook: Callable[[Manifest], Any],
) -> Manifest:
macro_hook: Callable[[AnyManifest], Any],
) -> MacroManifest:
return ManifestLoader.load_macros(config, macro_hook)
def load_manifest(
config: RuntimeConfig,
macro_manifest: Manifest,
macro_hook: Callable[[Manifest], Any],
macro_manifest: MacroManifest,
macro_hook: Callable[[AnyManifest], Any],
) -> Manifest:
return ManifestLoader.load_all(config, macro_manifest, macro_hook)

View File

@@ -11,7 +11,9 @@ class ModelParser(SimpleSQLParser[ParsedModelNode]):
)
def parse_from_dict(self, dct, validate=True) -> ParsedModelNode:
return ParsedModelNode.from_dict(dct, validate=validate)
if validate:
ParsedModelNode.validate(dct)
return ParsedModelNode.from_dict(dct)
@property
def resource_type(self) -> NodeType:

View File

@@ -1,7 +1,7 @@
from dataclasses import dataclass, field
from typing import TypeVar, MutableMapping, Mapping, Union, List
from hologram import JsonSchemaMixin
from dbt.dataclass_schema import dbtClassMixin
from dbt.contracts.files import RemoteFile, FileHash, SourceFile
from dbt.contracts.graph.compiled import CompileResultNode
@@ -62,7 +62,7 @@ def dict_field():
@dataclass
class ParseResult(JsonSchemaMixin, Writable, Replaceable):
class ParseResult(dbtClassMixin, Writable, Replaceable):
vars_hash: FileHash
profile_hash: FileHash
project_hashes: MutableMapping[str, FileHash]

View File

@@ -26,7 +26,9 @@ class RPCCallParser(SimpleSQLParser[ParsedRPCNode]):
return []
def parse_from_dict(self, dct, validate=True) -> ParsedRPCNode:
return ParsedRPCNode.from_dict(dct, validate=validate)
if validate:
ParsedRPCNode.validate(dct)
return ParsedRPCNode.from_dict(dct)
@property
def resource_type(self) -> NodeType:

View File

@@ -179,6 +179,7 @@ class TestBuilder(Generic[Testable]):
- or it may not be namespaced (test)
"""
# The 'test_name' is used to find the 'macro' that implements the test
TEST_NAME_PATTERN = re.compile(
r'((?P<test_namespace>([a-zA-Z_][0-9a-zA-Z_]*))\.)?'
r'(?P<test_name>([a-zA-Z_][0-9a-zA-Z_]*))'
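Note: a short, self-contained illustration of how the namespace/name portion of TEST_NAME_PATTERN splits a test name. Only the two groups shown in the hunk above are used here; the full pattern in the source continues past this excerpt.

import re

# Only the namespace/name portion of TEST_NAME_PATTERN from the hunk above;
# the real pattern also goes on to match the test arguments.
name_pattern = re.compile(
    r'((?P<test_namespace>([a-zA-Z_][0-9a-zA-Z_]*))\.)?'
    r'(?P<test_name>([a-zA-Z_][0-9a-zA-Z_]*))'
)

for raw in ('unique', 'my_package.equal_rowcount'):
    match = name_pattern.match(raw)
    # test_name is what gets prefixed with 'test_' to find the macro;
    # test_namespace is the optional package qualifier.
    print(raw, match.group('test_namespace'), match.group('test_name'))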
@@ -302,6 +303,8 @@ class TestBuilder(Generic[Testable]):
name = '{}_{}'.format(self.namespace, name)
return get_nice_schema_test_name(name, self.target.name, self.args)
# this is the 'raw_sql' that's used in 'render_update' and execution
# of the test macro
def build_raw_sql(self) -> str:
return (
"{{{{ config(severity='{severity}') }}}}"

View File

@@ -6,9 +6,9 @@ from typing import (
Iterable, Dict, Any, Union, List, Optional, Generic, TypeVar, Type
)
from hologram import ValidationError, JsonSchemaMixin
from dbt.dataclass_schema import ValidationError, dbtClassMixin
from dbt.adapters.factory import get_adapter
from dbt.adapters.factory import get_adapter, get_adapter_package_names
from dbt.clients.jinja import get_rendered, add_rendered_test_kwargs
from dbt.clients.yaml_helper import load_yaml_text
from dbt.config.renderer import SchemaYamlRenderer
@@ -20,7 +20,10 @@ from dbt.context.context_config import (
)
from dbt.context.configured import generate_schema_yml
from dbt.context.target import generate_target_context
from dbt.context.providers import generate_parse_exposure
from dbt.context.providers import (
generate_parse_exposure, generate_test_context
)
from dbt.context.macro_resolver import MacroResolver
from dbt.contracts.files import FileHash
from dbt.contracts.graph.manifest import SourceFile
from dbt.contracts.graph.model_config import SourceConfig
@@ -173,6 +176,15 @@ class SchemaParser(SimpleParser[SchemaTestBlock, ParsedSchemaTestNode]):
self.raw_renderer = SchemaYamlRenderer(ctx)
internal_package_names = get_adapter_package_names(
self.root_project.credentials.type
)
self.macro_resolver = MacroResolver(
self.macro_manifest.macros,
self.root_project.project_name,
internal_package_names
)
@classmethod
def get_compiled_path(cls, block: FileBlock) -> str:
# should this raise an error?
@@ -202,9 +214,11 @@ class SchemaParser(SimpleParser[SchemaTestBlock, ParsedSchemaTestNode]):
)
def parse_from_dict(self, dct, validate=True) -> ParsedSchemaTestNode:
return ParsedSchemaTestNode.from_dict(dct, validate=validate)
if validate:
ParsedSchemaTestNode.validate(dct)
return ParsedSchemaTestNode.from_dict(dct)
def _parse_format_version(
def _check_format_version(
self, yaml: YamlBlock
) -> None:
path = yaml.path.relative_path
@@ -374,7 +388,8 @@ class SchemaParser(SimpleParser[SchemaTestBlock, ParsedSchemaTestNode]):
'checksum': FileHash.empty().to_dict(),
}
try:
return self.parse_from_dict(dct)
ParsedSchemaTestNode.validate(dct)
return ParsedSchemaTestNode.from_dict(dct)
except ValidationError as exc:
msg = validator_error_message(exc)
# this is a bit silly, but build an UnparsedNode just for error
@@ -387,6 +402,7 @@ class SchemaParser(SimpleParser[SchemaTestBlock, ParsedSchemaTestNode]):
)
raise CompilationException(msg, node=node) from exc
# lots of time spent in this method
def _parse_generic_test(
self,
target: Testable,
@@ -425,6 +441,7 @@ class SchemaParser(SimpleParser[SchemaTestBlock, ParsedSchemaTestNode]):
# is not necessarily this package's name
fqn = self.get_fqn(fqn_path, builder.fqn_name)
# this is the config that is used in render_update
config = self.initial_config(fqn)
metadata = {
@@ -447,9 +464,53 @@ class SchemaParser(SimpleParser[SchemaTestBlock, ParsedSchemaTestNode]):
column_name=column_name,
test_metadata=metadata,
)
self.render_update(node, config)
self.render_test_update(node, config, builder)
return node
# This does special shortcut processing for the two
# most common internal macros, not_null and unique,
# which avoids jinja rendering to resolve config,
# variables, etc. that might be in the macro.
# In the future we will look at generalizing this
# more to handle additional macros or to use static
# parsing to avoid jinja overhead.
def render_test_update(self, node, config, builder):
macro_unique_id = self.macro_resolver.get_macro_id(
node.package_name, 'test_' + builder.name)
# Add the depends_on here so we can limit the macros added
# to the context in rendering processing
node.depends_on.add_macro(macro_unique_id)
if (macro_unique_id in
['macro.dbt.test_not_null', 'macro.dbt.test_unique']):
self.update_parsed_node(node, config)
node.unrendered_config['severity'] = builder.severity()
node.config['severity'] = builder.severity()
# source node tests are processed at patch_source time
if isinstance(builder.target, UnpatchedSourceDefinition):
sources = [builder.target.fqn[-2], builder.target.fqn[-1]]
node.sources.append(sources)
else: # all other nodes
node.refs.append([builder.target.name])
else:
try:
# make a base context that doesn't have the magic kwargs field
context = generate_test_context(
node, self.root_project, self.macro_manifest, config,
self.macro_resolver,
)
# update with rendered test kwargs (which collects any refs)
add_rendered_test_kwargs(context, node, capture_macros=True)
# the parsed node is not rendered in the native context.
get_rendered(
node.raw_sql, context, node, capture_macros=True
)
self.update_parsed_node(node, config)
except ValidationError as exc:
# we got a ValidationError - probably bad types in config()
msg = validator_error_message(exc)
raise CompilationException(msg, node=node) from exc
def parse_source_test(
self,
target: UnpatchedSourceDefinition,
@@ -561,10 +622,13 @@ class SchemaParser(SimpleParser[SchemaTestBlock, ParsedSchemaTestNode]):
def parse_file(self, block: FileBlock) -> None:
dct = self._yaml_from_file(block.file)
# mark the file as seen, even if there are no macros in it
# mark the file as seen, in ParseResult.files
self.results.get_file(block.file)
if dct:
try:
# This does a deep_map to check for circular references
dct = self.raw_renderer.render_data(dct)
except CompilationException as exc:
raise CompilationException(
@@ -572,28 +636,58 @@ class SchemaParser(SimpleParser[SchemaTestBlock, ParsedSchemaTestNode]):
f'project {self.project.project_name}: {exc}'
) from exc
# contains the FileBlock and the data (dictionary)
yaml_block = YamlBlock.from_file_block(block, dct)
self._parse_format_version(yaml_block)
# checks version
self._check_format_version(yaml_block)
parser: YamlDocsReader
for key in NodeType.documentable():
plural = key.pluralize()
if key == NodeType.Source:
parser = SourceParser(self, yaml_block, plural)
elif key == NodeType.Macro:
parser = MacroPatchParser(self, yaml_block, plural)
elif key == NodeType.Analysis:
parser = AnalysisPatchParser(self, yaml_block, plural)
elif key == NodeType.Exposure:
# handle exposures separately, but they are
# technically still "documentable"
continue
else:
parser = TestablePatchParser(self, yaml_block, plural)
# There are 7 kinds of parsers:
# Model, Seed, Snapshot, Source, Macro, Analysis, Exposures
# TestablePatchParser is a variety of NodePatchParser;
# its parse() is inherited from NonSourceParser
if 'models' in dct:
parser = TestablePatchParser(self, yaml_block, 'models')
for test_block in parser.parse():
self.parse_tests(test_block)
self.parse_exposures(yaml_block)
# NonSourceParser.parse()
if 'seeds' in dct:
parser = TestablePatchParser(self, yaml_block, 'seeds')
for test_block in parser.parse():
self.parse_tests(test_block)
# NonSourceParser.parse()
if 'snapshots' in dct:
parser = TestablePatchParser(self, yaml_block, 'snapshots')
for test_block in parser.parse():
self.parse_tests(test_block)
# This parser uses SourceParser.parse() which doesn't return
# any test blocks. Source tests are handled at a later point
# in the process.
if 'sources' in dct:
parser = SourceParser(self, yaml_block, 'sources')
parser.parse()
# NonSourceParser.parse()
if 'macros' in dct:
parser = MacroPatchParser(self, yaml_block, 'macros')
for test_block in parser.parse():
self.parse_tests(test_block)
# NonSourceParser.parse()
if 'analyses' in dct:
parser = AnalysisPatchParser(self, yaml_block, 'analyses')
for test_block in parser.parse():
self.parse_tests(test_block)
# parse exposures
if 'exposures' in dct:
self.parse_exposures(yaml_block)
Parsed = TypeVar(
@@ -610,11 +704,14 @@ NonSourceTarget = TypeVar(
)
# abstract base class (ABCMeta)
class YamlReader(metaclass=ABCMeta):
def __init__(
self, schema_parser: SchemaParser, yaml: YamlBlock, key: str
) -> None:
self.schema_parser = schema_parser
# key: models, seeds, snapshots, sources, macros,
# analyses, exposures
self.key = key
self.yaml = yaml
@@ -634,6 +731,9 @@ class YamlReader(metaclass=ABCMeta):
def root_project(self):
return self.schema_parser.root_project
# for the different schema subparsers ('models', 'source', etc)
# get the list of dicts pointed to by the key in the yaml config,
# ensure that the dicts have string keys
def get_key_dicts(self) -> Iterable[Dict[str, Any]]:
data = self.yaml.data.get(self.key, [])
if not isinstance(data, list):
@@ -643,7 +743,10 @@ class YamlReader(metaclass=ABCMeta):
)
path = self.yaml.path.original_file_path
# for each dict in the data (which is a list of dicts)
for entry in data:
# check that entry is a dict and that all of its
# keys are strings
if coerce_dict_str(entry) is not None:
yield entry
else:
@@ -659,19 +762,22 @@ class YamlDocsReader(YamlReader):
raise NotImplementedError('parse is abstract')
T = TypeVar('T', bound=JsonSchemaMixin)
T = TypeVar('T', bound=dbtClassMixin)
class SourceParser(YamlDocsReader):
def _target_from_dict(self, cls: Type[T], data: Dict[str, Any]) -> T:
path = self.yaml.path.original_file_path
try:
cls.validate(data)
return cls.from_dict(data)
except (ValidationError, JSONValidationException) as exc:
msg = error_context(path, self.key, data, exc)
raise CompilationException(msg) from exc
# the other parse method returns TestBlocks. This one doesn't.
def parse(self) -> List[TestBlock]:
# get a verified list of dicts for the key handled by this parser
for data in self.get_key_dicts():
data = self.project.credentials.translate_aliases(
data, recurse=True
@@ -714,10 +820,12 @@ class SourceParser(YamlDocsReader):
self.results.add_source(self.yaml.file, result)
# This class has three main subclasses: TestablePatchParser (models,
# seeds, snapshots), MacroPatchParser, and AnalysisPatchParser
class NonSourceParser(YamlDocsReader, Generic[NonSourceTarget, Parsed]):
@abstractmethod
def _target_type(self) -> Type[NonSourceTarget]:
raise NotImplementedError('_unsafe_from_dict not implemented')
raise NotImplementedError('_target_type not implemented')
@abstractmethod
def get_block(self, node: NonSourceTarget) -> TargetBlock:
@@ -732,33 +840,55 @@ class NonSourceParser(YamlDocsReader, Generic[NonSourceTarget, Parsed]):
def parse(self) -> List[TestBlock]:
node: NonSourceTarget
test_blocks: List[TestBlock] = []
# get list of 'node' objects
# UnparsedNodeUpdate (TestablePatchParser, models, seeds, snapshots)
# = HasColumnTests, HasTests
# UnparsedAnalysisUpdate (AnalysisPatchParser, analyses)
# = HasColumnDocs, HasDocs
# UnparsedMacroUpdate (MacroPatchParser, 'macros')
# = HasDocs
# correspond to this parser's 'key'
for node in self.get_unparsed_target():
# node_block is a TargetBlock (Macro or Analysis)
# or a TestBlock (all of the others)
node_block = self.get_block(node)
if isinstance(node_block, TestBlock):
# TestablePatchParser = models, seeds, snapshots
test_blocks.append(node_block)
if isinstance(node, (HasColumnDocs, HasColumnTests)):
# UnparsedNodeUpdate and UnparsedAnalysisUpdate
refs: ParserRef = ParserRef.from_target(node)
else:
refs = ParserRef()
# This adds the node_block to self.results (a ParseResult
# object) as a ParsedNodePatch or ParsedMacroPatch
self.parse_patch(node_block, refs)
return test_blocks
def get_unparsed_target(self) -> Iterable[NonSourceTarget]:
path = self.yaml.path.original_file_path
for data in self.get_key_dicts():
# get verified list of dicts for the 'key' that this
# parser handles
key_dicts = self.get_key_dicts()
for data in key_dicts:
# add extra data to each dict. This updates the dicts
# in the parser yaml
data.update({
'original_file_path': path,
'yaml_key': self.key,
'package_name': self.project.project_name,
})
try:
model = self._target_type().from_dict(data)
# target_type: UnparsedNodeUpdate, UnparsedAnalysisUpdate,
# or UnparsedMacroUpdate
self._target_type().validate(data)
node = self._target_type().from_dict(data)
except (ValidationError, JSONValidationException) as exc:
msg = error_context(path, self.key, data, exc)
raise CompilationException(msg) from exc
else:
yield model
yield node
class NodePatchParser(
@@ -866,6 +996,7 @@ class ExposureParser(YamlReader):
def parse(self) -> Iterable[ParsedExposure]:
for data in self.get_key_dicts():
try:
UnparsedExposure.validate(data)
unparsed = UnparsedExposure.from_dict(data)
except (ValidationError, JSONValidationException) as exc:
msg = error_context(self.yaml.path, self.key, data, exc)

View File

@@ -13,7 +13,9 @@ class SeedParser(SimpleSQLParser[ParsedSeedNode]):
)
def parse_from_dict(self, dct, validate=True) -> ParsedSeedNode:
return ParsedSeedNode.from_dict(dct, validate=validate)
if validate:
ParsedSeedNode.validate(dct)
return ParsedSeedNode.from_dict(dct)
@property
def resource_type(self) -> NodeType:

View File

@@ -1,7 +1,7 @@
import os
from typing import List
from hologram import ValidationError
from dbt.dataclass_schema import ValidationError
from dbt.contracts.graph.parsed import (
IntermediateSnapshotNode, ParsedSnapshotNode
@@ -26,7 +26,9 @@ class SnapshotParser(
)
def parse_from_dict(self, dct, validate=True) -> IntermediateSnapshotNode:
return IntermediateSnapshotNode.from_dict(dct, validate=validate)
if validate:
IntermediateSnapshotNode.validate(dct)
return IntermediateSnapshotNode.from_dict(dct)
@property
def resource_type(self) -> NodeType:

View File

@@ -6,7 +6,7 @@ from typing import (
Set,
)
from dbt.config import RuntimeConfig
from dbt.contracts.graph.manifest import Manifest, SourceKey
from dbt.contracts.graph.manifest import MacroManifest, SourceKey
from dbt.contracts.graph.parsed import (
UnpatchedSourceDefinition,
ParsedSourceDefinition,
@@ -33,7 +33,7 @@ class SourcePatcher:
) -> None:
self.results = results
self.root_project = root_project
self.macro_manifest = Manifest.from_macros(
self.macro_manifest = MacroManifest(
macros=self.results.macros,
files=self.results.files
)

View File

@@ -3,7 +3,7 @@ little bit too much to go anywhere else.
"""
from dbt.adapters.factory import get_adapter
from dbt.parser.manifest import load_manifest
from dbt.contracts.graph.manifest import Manifest
from dbt.contracts.graph.manifest import Manifest, MacroManifest
from dbt.config import RuntimeConfig
@@ -23,10 +23,10 @@ def get_full_manifest(
config.clear_dependencies()
adapter.clear_macro_manifest()
internal: Manifest = adapter.load_macro_manifest()
macro_manifest: MacroManifest = adapter.load_macro_manifest()
return load_manifest(
config,
internal,
macro_manifest,
adapter.connections.set_query_header,
)

View File

@@ -1,8 +1,7 @@
import logbook
import logbook.queues
from jsonrpc.exceptions import JSONRPCError
from hologram import JsonSchemaMixin
from hologram.helpers import StrEnum
from dbt.dataclass_schema import StrEnum
from dataclasses import dataclass, field
from datetime import datetime, timedelta
@@ -25,8 +24,11 @@ class QueueMessageType(StrEnum):
terminating = frozenset((Error, Result, Timeout))
# This class was subclassed from JsonSchemaMixin, but it
# doesn't appear to be necessary, and Mashumaro does not
# handle logbook.LogRecord
@dataclass
class QueueMessage(JsonSchemaMixin):
class QueueMessage:
message_type: QueueMessageType

View File

@@ -3,7 +3,7 @@ from abc import abstractmethod
from copy import deepcopy
from typing import List, Optional, Type, TypeVar, Generic, Dict, Any
from hologram import JsonSchemaMixin, ValidationError
from dbt.dataclass_schema import dbtClassMixin, ValidationError
from dbt.contracts.rpc import RPCParameters, RemoteResult, RemoteMethodFlags
from dbt.exceptions import NotImplementedException, InternalException
@@ -109,7 +109,7 @@ class RemoteBuiltinMethod(RemoteMethod[Parameters, Result]):
'the run() method on builtins should never be called'
)
def __call__(self, **kwargs: Dict[str, Any]) -> JsonSchemaMixin:
def __call__(self, **kwargs: Dict[str, Any]) -> dbtClassMixin:
try:
params = self.get_parameters().from_dict(kwargs)
except ValidationError as exc:

View File

@@ -1,7 +1,7 @@
import json
from typing import Callable, Dict, Any
from hologram import JsonSchemaMixin
from dbt.dataclass_schema import dbtClassMixin
from jsonrpc.exceptions import (
JSONRPCParseError,
JSONRPCInvalidRequestException,
@@ -90,11 +90,14 @@ class ResponseManager(JSONRPCResponseManager):
@classmethod
def _get_responses(cls, requests, dispatcher):
for output in super()._get_responses(requests, dispatcher):
# if it's a result, check if it's a JsonSchemaMixin and if so call
# if it's a result, check if it's a dbtClassMixin and if so call
# to_dict
if hasattr(output, 'result'):
if isinstance(output.result, JsonSchemaMixin):
output.result = output.result.to_dict(omit_none=False)
if isinstance(output.result, dbtClassMixin):
# Note: errors in to_dict do not show up anywhere in
# the output and all you get is a generic 500 error
output.result = \
output.result.to_dict(options={'keep_none': True})
yield output
@classmethod

View File

@@ -9,7 +9,7 @@ from typing import (
)
from typing_extensions import Protocol
from hologram import JsonSchemaMixin, ValidationError
from dbt.dataclass_schema import dbtClassMixin, ValidationError
import dbt.exceptions
import dbt.flags
@@ -283,7 +283,7 @@ class RequestTaskHandler(threading.Thread, TaskHandlerProtocol):
# - The actual thread that this represents, which writes its data to
# the result and logs. The atomicity of list.append() and item
# assignment means we don't need a lock.
self.result: Optional[JsonSchemaMixin] = None
self.result: Optional[dbtClassMixin] = None
self.error: Optional[RPCException] = None
self.state: TaskHandlerState = TaskHandlerState.NotStarted
self.logs: List[LogMessage] = []
@@ -453,6 +453,7 @@ class RequestTaskHandler(threading.Thread, TaskHandlerProtocol):
)
try:
cls.validate(self.task_kwargs)
return cls.from_dict(self.task_kwargs)
except ValidationError as exc:
# raise a TypeError to indicate invalid parameters so we get a nice

View File

@@ -14,11 +14,11 @@ from dbt.contracts.rpc import (
class TaskHandlerProtocol(Protocol):
started: Optional[datetime]
ended: Optional[datetime]
state: TaskHandlerState
task_id: TaskID
process: Optional[multiprocessing.Process]
state: TaskHandlerState
started: Optional[datetime] = None
ended: Optional[datetime] = None
process: Optional[multiprocessing.Process] = None
@property
def request_id(self) -> Union[str, int]:

View File

@@ -4,8 +4,7 @@ import re
from dbt.exceptions import VersionsNotCompatibleException
import dbt.utils
from hologram import JsonSchemaMixin
from hologram.helpers import StrEnum
from dbt.dataclass_schema import dbtClassMixin, StrEnum
from typing import Optional
@@ -18,12 +17,12 @@ class Matchers(StrEnum):
@dataclass
class VersionSpecification(JsonSchemaMixin):
major: Optional[str]
minor: Optional[str]
patch: Optional[str]
prerelease: Optional[str]
build: Optional[str]
class VersionSpecification(dbtClassMixin):
major: Optional[str] = None
minor: Optional[str] = None
patch: Optional[str] = None
prerelease: Optional[str] = None
build: Optional[str] = None
matcher: Matchers = Matchers.EXACT

View File

@@ -48,7 +48,6 @@ Check your database credentials and try again. For more information, visit:
{url}
'''.lstrip()
MISSING_PROFILE_MESSAGE = '''
dbt looked for a profiles.yml file in {path}, but did
not find one. For more information on configuring your profile, consult the
@@ -90,6 +89,7 @@ class DebugTask(BaseTask):
self.profile_name: Optional[str] = None
self.project: Optional[Project] = None
self.project_fail_details = ''
self.any_failure = False
self.messages: List[str] = []
@property
@@ -111,7 +111,7 @@ class DebugTask(BaseTask):
def run(self):
if self.args.config_dir:
self.path_info()
return
return not self.any_failure
version = get_installed_version().to_version_string(skip_matcher=True)
print('dbt version: {}'.format(version))
@@ -129,6 +129,11 @@ class DebugTask(BaseTask):
print(message)
print('')
return not self.any_failure
def interpret_results(self, results):
return results
def _load_project(self):
if not os.path.exists(self.project_path):
self.project_fail_details = FILE_NOT_FOUND
@@ -245,6 +250,7 @@ class DebugTask(BaseTask):
self.messages.append(MISSING_PROFILE_MESSAGE.format(
path=self.profile_path, url=ProfileConfigDocs
))
self.any_failure = True
return red('ERROR not found')
try:
@@ -283,6 +289,7 @@ class DebugTask(BaseTask):
dbt.clients.system.run_cmd(os.getcwd(), ['git', '--help'])
except dbt.exceptions.ExecutableError as exc:
self.messages.append('Error from git --help: {!s}'.format(exc))
self.any_failure = True
return red('ERROR')
return green('OK found')
@@ -310,6 +317,8 @@ class DebugTask(BaseTask):
def _log_project_fail(self):
if not self.project_fail_details:
return
self.any_failure = True
if self.project_fail_details == FILE_NOT_FOUND:
return
print('Project loading failed for the following reason:')
@@ -319,6 +328,8 @@ class DebugTask(BaseTask):
def _log_profile_fail(self):
if not self.profile_fail_details:
return
self.any_failure = True
if self.profile_fail_details == FILE_NOT_FOUND:
return
print('Profile loading failed for the following reason:')
@@ -347,6 +358,7 @@ class DebugTask(BaseTask):
result = self.attempt_connection(self.profile)
if result is not None:
self.messages.append(result)
self.any_failure = True
return red('ERROR')
return green('OK connection ok')

View File

@@ -3,7 +3,7 @@ import shutil
from datetime import datetime
from typing import Dict, List, Any, Optional, Tuple, Set
from hologram import ValidationError
from dbt.dataclass_schema import ValidationError
from .compile import CompileTask

View File

@@ -110,7 +110,7 @@ class ListTask(GraphRunnableTask):
for node in self._iterate_selected_nodes():
yield json.dumps({
k: v
for k, v in node.to_dict(omit_none=False).items()
for k, v in node.to_dict(options={'keep_none': True}).items()
if k in self.ALLOWED_KEYS
})

View File

@@ -8,12 +8,17 @@
# snakeviz dbt.cprof
from dbt.task.base import ConfiguredTask
from dbt.adapters.factory import get_adapter
from dbt.parser.manifest import Manifest, ManifestLoader, _check_manifest
from dbt.parser.manifest import (
Manifest, MacroManifest, ManifestLoader, _check_manifest
)
from dbt.logger import DbtProcessState, print_timestamped_line
from dbt.clients.system import write_file
from dbt.graph import Graph
import time
from typing import Optional
import os
import json
import dbt.utils
MANIFEST_FILE_NAME = 'manifest.json'
PERF_INFO_FILE_NAME = 'perf_info.json'
@@ -33,7 +38,8 @@ class ParseTask(ConfiguredTask):
def write_perf_info(self):
path = os.path.join(self.config.target_path, PERF_INFO_FILE_NAME)
self.loader._perf_info.write(path)
write_file(path, json.dumps(self.loader._perf_info,
cls=dbt.utils.JSONEncoder, indent=4))
print_timestamped_line(f"Performance info: {path}")
# This method takes code that normally exists in other files
@@ -47,7 +53,7 @@ class ParseTask(ConfiguredTask):
def get_full_manifest(self):
adapter = get_adapter(self.config) # type: ignore
macro_manifest: Manifest = adapter.load_macro_manifest()
macro_manifest: MacroManifest = adapter.load_macro_manifest()
print_timestamped_line("Macro manifest loaded")
root_config = self.config
macro_hook = adapter.connections.set_query_header

View File

@@ -1,6 +1,6 @@
import abc
import shlex
import yaml
from dbt.clients.yaml_helper import Dumper, yaml # noqa: F401
from typing import Type, Optional

View File

@@ -3,7 +3,7 @@ import threading
import time
from typing import List, Dict, Any, Iterable, Set, Tuple, Optional, AbstractSet
from hologram import JsonSchemaMixin
from dbt.dataclass_schema import dbtClassMixin
from .compile import CompileRunner, CompileTask
@@ -96,6 +96,7 @@ def get_hooks_by_tags(
def get_hook(source, index):
hook_dict = get_hook_dict(source)
hook_dict.setdefault('index', index)
Hook.validate(hook_dict)
return Hook.from_dict(hook_dict)
@@ -191,7 +192,7 @@ class ModelRunner(CompileRunner):
def _build_run_model_result(self, model, context):
result = context['load_result']('main')
adapter_response = {}
if isinstance(result.response, JsonSchemaMixin):
if isinstance(result.response, dbtClassMixin):
adapter_response = result.response.to_dict()
return RunResult(
node=model,

View File

@@ -1,6 +1,8 @@
from typing import Optional
from dbt.clients import yaml_helper
from dbt.clients.yaml_helper import ( # noqa:F401
yaml, safe_load, Loader, Dumper,
)
from dbt.logger import GLOBAL_LOGGER as logger
from dbt import version as dbt_version
from snowplow_tracker import Subject, Tracker, Emitter, logger as sp_logger
@@ -12,25 +14,37 @@ import pytz
import platform
import uuid
import requests
import yaml
import os
from extensions import tracking
sp_logger.setLevel(100)
COLLECTOR_URL = "fishtownanalytics.sinter-collect.com"
COLLECTOR_PROTOCOL = "https"
# COLLECTOR_URL = "fishtownanalytics.sinter-collect.com"
# COLLECTOR_PROTOCOL = "https"
INVOCATION_SPEC = 'iglu:com.dbt/invocation/jsonschema/1-0-1'
PLATFORM_SPEC = 'iglu:com.dbt/platform/jsonschema/1-0-0'
RUN_MODEL_SPEC = 'iglu:com.dbt/run_model/jsonschema/1-0-1'
INVOCATION_ENV_SPEC = 'iglu:com.dbt/invocation_env/jsonschema/1-0-0'
PACKAGE_INSTALL_SPEC = 'iglu:com.dbt/package_install/jsonschema/1-0-0'
RPC_REQUEST_SPEC = 'iglu:com.dbt/rpc_request/jsonschema/1-0-1'
DEPRECATION_WARN_SPEC = 'iglu:com.dbt/deprecation_warn/jsonschema/1-0-0'
LOAD_ALL_TIMING_SPEC = 'iglu:com.dbt/load_all_timing/jsonschema/1-0-0'
# INVOCATION_SPEC = 'iglu:com.dbt/invocation/jsonschema/1-0-1'
# PLATFORM_SPEC = 'iglu:com.dbt/platform/jsonschema/1-0-0'
# RUN_MODEL_SPEC = 'iglu:com.dbt/run_model/jsonschema/1-0-1'
# INVOCATION_ENV_SPEC = 'iglu:com.dbt/invocation_env/jsonschema/1-0-0'
# PACKAGE_INSTALL_SPEC = 'iglu:com.dbt/package_install/jsonschema/1-0-0'
# RPC_REQUEST_SPEC = 'iglu:com.dbt/rpc_request/jsonschema/1-0-1'
# DEPRECATION_WARN_SPEC = 'iglu:com.dbt/deprecation_warn/jsonschema/1-0-0'
# LOAD_ALL_TIMING_SPEC = 'iglu:com.dbt/load_all_timing/jsonschema/1-0-0'
DBT_INVOCATION_ENV = 'DBT_INVOCATION_ENV'
# DBT_INVOCATION_ENV = 'DBT_INVOCATION_ENV'
COLLECTOR_URL = tracking.collector_url()
COLLECTOR_PROTOCOL = tracking.collector_protocol()
INVOCATION_SPEC = tracking.invocation_spec()
PLATFORM_SPEC = tracking.platform_spec()
RUN_MODEL_SPEC = tracking.run_model_spec()
INVOCATION_ENV_SPEC = tracking.invocation_env_spec()
PACKAGE_INSTALL_SPEC = tracking.package_install_spec()
RPC_REQUEST_SPEC = tracking.rpc_request_spec()
DEPRECATION_WARN_SPEC = tracking.deprecation_warn_spec()
LOAD_ALL_TIMING_SPEC = tracking.load_all_timing_spec()
DBT_INVOCATION_ENV = tracking.dbt_invocation_env()
class TimeoutEmitter(Emitter):
def __init__(self):
@@ -147,7 +161,7 @@ class User:
else:
with open(self.cookie_path, "r") as fh:
try:
user = yaml_helper.safe_load(fh)
user = safe_load(fh)
if user is None:
user = self.set_cookie()
except yaml.reader.ReaderError:

View File

@@ -298,7 +298,7 @@ def filter_null_values(input: Dict[K_T, Optional[V_T]]) -> Dict[K_T, V_T]:
def add_ephemeral_model_prefix(s: str) -> str:
return '__dbt__CTE__{}'.format(s)
return '__dbt__cte__{}'.format(s)
def timestring() -> str:
@@ -415,7 +415,7 @@ def restrict_to(*restrictions):
def coerce_dict_str(value: Any) -> Optional[Dict[str, Any]]:
"""For annoying mypy reasons, this helper makes dealing with nested dicts
easier. You get either `None` if it's not a Dict[str, Any], or the
Dict[str, Any] you expected (to pass it to JsonSchemaMixin.from_dict(...)).
Dict[str, Any] you expected (to pass it to dbtClassMixin.from_dict(...)).
"""
if (isinstance(value, dict) and all(isinstance(k, str) for k in value)):
return value
@@ -539,7 +539,9 @@ def fqn_search(
level_config = root.get(level, None)
if not isinstance(level_config, dict):
break
yield copy.deepcopy(level_config)
# This used to do a 'deepcopy',
# but it didn't seem to be necessary
yield level_config
root = level_config

View File

@@ -0,0 +1 @@
from . import tracking

View File

@@ -0,0 +1,15 @@
[package]
name = "extensions-tracking"
version = "0.1.0"
publish = false
edition = "2018"
workspace = "../.."
[lib]
crate-type = [ "cdylib",]
name = "extensions_tracking"
path = "lib.rs"
[dependencies.pyo3]
version = "0.13.1"
features = [ "extension-module",]

View File

@@ -0,0 +1,45 @@
use pyo3::prelude::*;
use pyo3::wrap_pyfunction;
#[pyfunction]
pub fn collector_url() -> PyResult<String> { Ok("fishtownanalytics.sinter-collect.com".to_owned()) }
#[pyfunction]
pub fn collector_protocol() -> PyResult<String> { Ok("https".to_owned()) }
#[pyfunction]
pub fn invocation_spec() -> PyResult<String> { Ok("iglu:com.dbt/invocation/jsonschema/1-0-1".to_owned()) }
#[pyfunction]
pub fn platform_spec() -> PyResult<String> { Ok("iglu:com.dbt/platform/jsonschema/1-0-0".to_owned()) }
#[pyfunction]
pub fn run_model_spec() -> PyResult<String> { Ok("iglu:com.dbt/run_model/jsonschema/1-0-1".to_owned()) }
#[pyfunction]
pub fn invocation_env_spec() -> PyResult<String> { Ok("iglu:com.dbt/invocation_env/jsonschema/1-0-0".to_owned()) }
#[pyfunction]
pub fn package_install_spec() -> PyResult<String> { Ok("iglu:com.dbt/package_install/jsonschema/1-0-0".to_owned()) }
#[pyfunction]
pub fn rpc_request_spec() -> PyResult<String> { Ok("iglu:com.dbt/rpc_request/jsonschema/1-0-1".to_owned()) }
#[pyfunction]
pub fn deprecation_warn_spec() -> PyResult<String> { Ok("iglu:com.dbt/deprecation_warn/jsonschema/1-0-0".to_owned()) }
#[pyfunction]
pub fn load_all_timing_spec() -> PyResult<String> { Ok("iglu:com.dbt/load_all_timing/jsonschema/1-0-0".to_owned()) }
#[pyfunction]
pub fn dbt_invocation_env() -> PyResult<String> { Ok("DBT_INVOCATION_ENV".to_owned()) }
/// This module is a python module implemented in Rust.
/// the function name must match the library name in Cargo.toml
#[pymodule]
fn tracking(_: Python, m: &PyModule) -> PyResult<()> {
m.add_wrapped(wrap_pyfunction!(collector_url))?;
m.add_wrapped(wrap_pyfunction!(collector_protocol))?;
m.add_wrapped(wrap_pyfunction!(invocation_spec))?;
m.add_wrapped(wrap_pyfunction!(platform_spec))?;
m.add_wrapped(wrap_pyfunction!(run_model_spec))?;
m.add_wrapped(wrap_pyfunction!(invocation_env_spec))?;
m.add_wrapped(wrap_pyfunction!(package_install_spec))?;
m.add_wrapped(wrap_pyfunction!(rpc_request_spec))?;
m.add_wrapped(wrap_pyfunction!(deprecation_warn_spec))?;
m.add_wrapped(wrap_pyfunction!(load_all_timing_spec))?;
m.add_wrapped(wrap_pyfunction!(dbt_invocation_env))?;
Ok(())
}
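Note: an illustrative way to exercise the compiled module from Python, assuming the cdylib above has been built (see the setuptools-rust changes to core/setup.py below) and is importable as extensions.tracking, matching the import added to dbt/tracking.py above:

# Illustrative usage; each helper just returns the constant that used to be
# hard-coded in dbt/tracking.py.
from extensions import tracking

print(tracking.collector_url())       # "fishtownanalytics.sinter-collect.com"
print(tracking.invocation_spec())     # "iglu:com.dbt/invocation/jsonschema/1-0-1"
print(tracking.dbt_invocation_env())  # "DBT_INVOCATION_ENV"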

core/pyproject.toml Normal file
View File

@@ -0,0 +1,2 @@
[build-system]
requires = ["setuptools", "wheel", "setuptools-rust"]

View File

@@ -1,6 +1,7 @@
#!/usr/bin/env python
import os
import sys
import setuptools_rust as rust
if sys.version_info < (3, 6):
print('Error: dbt does not support this version of Python.')
@@ -37,6 +38,9 @@ setup(
author="Fishtown Analytics",
author_email="info@fishtownanalytics.com",
url="https://github.com/fishtown-analytics/dbt",
rust_extensions=rust.find_rust_extensions(
binding=rust.Binding.PyO3, strip=rust.Strip.Debug
),
packages=find_namespace_packages(include=['dbt', 'dbt.*']),
package_data={
'dbt': [
@@ -70,7 +74,7 @@ setup(
'json-rpc>=1.12,<2',
'werkzeug>=0.15,<2.0',
'dataclasses==0.6;python_version<"3.7"',
'hologram==0.0.12',
# 'hologram==0.0.12', # must be updated prior to release
'logbook>=1.5,<1.6',
'typing-extensions>=3.7.4,<3.8',
# the following are all to match snowflake-connector-python

View File

@@ -13,3 +13,6 @@ mypy==0.782
wheel
twine
pytest-logbook>=1.2.0,<1.3
git+https://github.com/fishtown-analytics/hologram.git@mashumaro-support
git+https://github.com/fishtown-analytics/dbt-mashumaro.git@dbt-customizations
jsonschema

View File

@@ -27,7 +27,7 @@ from dbt.adapters.base import BaseConnectionManager, Credentials
from dbt.logger import GLOBAL_LOGGER as logger
from dbt.version import __version__ as dbt_version
from hologram.helpers import StrEnum
from dbt.dataclass_schema import StrEnum
BQ_QUERY_JOB_SPLIT = '-----Query Job SQL Follows-----'

View File

@@ -1,6 +1,6 @@
from dataclasses import dataclass
from typing import Dict, List, Optional, Any, Set, Union
from hologram import JsonSchemaMixin, ValidationError
from dbt.dataclass_schema import dbtClassMixin, ValidationError
import dbt.deprecations
import dbt.exceptions
@@ -47,7 +47,7 @@ def sql_escape(string):
@dataclass
class PartitionConfig(JsonSchemaMixin):
class PartitionConfig(dbtClassMixin):
field: str
data_type: str = 'date'
granularity: str = 'day'
@@ -69,6 +69,7 @@ class PartitionConfig(JsonSchemaMixin):
if raw_partition_by is None:
return None
try:
cls.validate(raw_partition_by)
return cls.from_dict(raw_partition_by)
except ValidationError as exc:
msg = dbt.exceptions.validator_error_message(exc)
@@ -84,7 +85,7 @@ class PartitionConfig(JsonSchemaMixin):
@dataclass
class GrantTarget(JsonSchemaMixin):
class GrantTarget(dbtClassMixin):
dataset: str
project: str
@@ -111,6 +112,8 @@ class BigqueryConfig(AdapterConfig):
partitions: Optional[List[str]] = None
grant_access_to: Optional[List[Dict[str, str]]] = None
hours_to_expiration: Optional[int] = None
require_partition_filter: Optional[bool] = None
partition_expiration_days: Optional[int] = None
class BigQueryAdapter(BaseAdapter):
@@ -788,6 +791,14 @@ class BigQueryAdapter(BaseAdapter):
labels = config.get('labels', {})
opts['labels'] = list(labels.items())
if config.get('require_partition_filter'):
opts['require_partition_filter'] = config.get(
'require_partition_filter')
if config.get('partition_expiration_days') is not None:
opts['partition_expiration_days'] = config.get(
'partition_expiration_days')
return opts
@available.parse_none
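Note: the two new branches above behave slightly differently: require_partition_filter is emitted only when truthy (so False is simply omitted), while partition_expiration_days is emitted for any non-None value, including 0. A standalone sketch of that behavior, not the adapter's actual method:

# Illustrative only: mirrors the option handling above for a model configured
# with {{ config(require_partition_filter=true, partition_expiration_days=7) }}.
config = {
    'require_partition_filter': True,
    'partition_expiration_days': 7,
}

opts = {}
if config.get('require_partition_filter'):
    # truthy check: require_partition_filter=False is dropped
    opts['require_partition_filter'] = config.get('require_partition_filter')
if config.get('partition_expiration_days') is not None:
    # explicit None check: 0 would still be passed through
    opts['partition_expiration_days'] = config.get('partition_expiration_days')

print(opts)  # {'require_partition_filter': True, 'partition_expiration_days': 7}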
@@ -798,6 +809,7 @@ class BigQueryAdapter(BaseAdapter):
conn = self.connections.get_thread_connection()
client = conn.handle
GrantTarget.validate(grant_target_dict)
grant_target = GrantTarget.from_dict(grant_target_dict)
dataset = client.get_dataset(
self.connections.dataset_from_id(grant_target.render())

View File

@@ -54,6 +54,7 @@
set _dbt_max_partition = (
select max({{ partition_by.field }}) from {{ this }}
where {{ partition_by.field }} is not null
);
-- 1. create a temp table

View File

@@ -17,9 +17,9 @@ from typing import Optional
class PostgresCredentials(Credentials):
host: str
user: str
role: Optional[str]
port: Port
password: str # on postgres the password is mandatory
role: Optional[str] = None
search_path: Optional[str] = None
keepalives_idle: int = 0 # 0 means to use the default value
sslmode: Optional[str] = None

View File

@@ -10,8 +10,7 @@ import dbt.flags
import boto3
from hologram import FieldEncoder, JsonSchemaMixin
from hologram.helpers import StrEnum
from dbt.dataclass_schema import FieldEncoder, dbtClassMixin, StrEnum
from dataclasses import dataclass, field
from typing import Optional, List
@@ -28,7 +27,7 @@ class IAMDurationEncoder(FieldEncoder):
return {'type': 'integer', 'minimum': 0, 'maximum': 65535}
JsonSchemaMixin.register_field_encoders({IAMDuration: IAMDurationEncoder()})
dbtClassMixin.register_field_encoders({IAMDuration: IAMDurationEncoder()})
class RedshiftConnectionMethod(StrEnum):

View File

@@ -30,16 +30,16 @@ _TOKEN_REQUEST_URL = 'https://{}.snowflakecomputing.com/oauth/token-request'
class SnowflakeCredentials(Credentials):
account: str
user: str
warehouse: Optional[str]
role: Optional[str]
password: Optional[str]
authenticator: Optional[str]
private_key_path: Optional[str]
private_key_passphrase: Optional[str]
token: Optional[str]
oauth_client_id: Optional[str]
oauth_client_secret: Optional[str]
query_tag: Optional[str]
warehouse: Optional[str] = None
role: Optional[str] = None
password: Optional[str] = None
authenticator: Optional[str] = None
private_key_path: Optional[str] = None
private_key_passphrase: Optional[str] = None
token: Optional[str] = None
oauth_client_id: Optional[str] = None
oauth_client_secret: Optional[str] = None
query_tag: Optional[str] = None
client_session_keep_alive: bool = False
def __post_init__(self):
@@ -305,7 +305,9 @@ class SnowflakeConnectionManager(SQLConnectionManager):
# empty queries. this avoids using exceptions as flow control,
# and also allows us to return the status of the last cursor
without_comments = re.sub(
re.compile('^.*(--.*)$', re.MULTILINE),
re.compile(
r'(\".*?\"|\'.*?\')|(/\*.*?\*/|--[^\r\n]*$)', re.MULTILINE
),
'', individual_query).strip()
if without_comments == "":
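Note: a small, self-contained sketch of what the new pattern changes: '--' inside a string literal no longer marks the rest of the statement as a comment, so only genuinely comment-only statements are treated as empty. The is_comment_only helper here is illustrative, not part of the connection manager.

import re

# The new pattern from the hunk above: quoted strings are consumed before
# comment markers are considered, then everything matched is stripped.
comment_pattern = re.compile(
    r'(\".*?\"|\'.*?\')|(/\*.*?\*/|--[^\r\n]*$)', re.MULTILINE
)

def is_comment_only(query: str) -> bool:
    return re.sub(comment_pattern, '', query).strip() == ''

# '--' inside a string literal no longer hides the rest of the statement
print(is_comment_only("select 'snow--flake' as dashes"))        # False

# a statement that really is only comments is still detected as empty
print(is_comment_only("-- a note\n/* and a block comment */"))  # True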

scripts/check_libyaml.py Executable file
View File

@@ -0,0 +1,17 @@
#!/usr/bin/env python
try:
from yaml import (
CLoader as Loader,
CSafeLoader as SafeLoader,
CDumper as Dumper
)
except ImportError:
from yaml import (
Loader, SafeLoader, Dumper
)
if Loader.__name__ == 'CLoader':
print("libyaml is working")
elif Loader.__name__ == 'Loader':
print("libyaml is not working")
print("Check the python executable and pyyaml for libyaml support")

View File

@@ -3,7 +3,7 @@ from dataclasses import dataclass
from typing import Dict, Any
import json
from hologram import JsonSchemaMixin
from dbt.dataclass_schema import dbtClassMixin
from dbt.contracts.graph.manifest import WritableManifest
from dbt.contracts.results import (
CatalogArtifact, RunResultsArtifact, FreshnessExecutionResultArtifact
@@ -11,7 +11,7 @@ from dbt.contracts.results import (
@dataclass
class Schemas(JsonSchemaMixin):
class Schemas(dbtClassMixin):
manifest: Dict[str, Any]
catalog: Dict[str, Any]
run_results: Dict[str, Any]

View File

@@ -4,7 +4,7 @@ import inspect
import json
from dataclasses import dataclass
from typing import List, Optional, Iterable, Union, Dict, Any
from hologram import JsonSchemaMixin
from dbt.dataclass_schema import dbtClassMixin
from dbt.context.base import BaseContext
@@ -21,20 +21,20 @@ CONTEXTS_MAP = {
@dataclass
class ContextValue(JsonSchemaMixin):
class ContextValue(dbtClassMixin):
name: str
value: str # a type description
doc: Optional[str]
@dataclass
class MethodArgument(JsonSchemaMixin):
class MethodArgument(dbtClassMixin):
name: str
value: str # a type description
@dataclass
class ContextMethod(JsonSchemaMixin):
class ContextMethod(dbtClassMixin):
name: str
args: List[MethodArgument]
result: str # a type description
@@ -42,7 +42,7 @@ class ContextMethod(JsonSchemaMixin):
@dataclass
class Unknown(JsonSchemaMixin):
class Unknown(dbtClassMixin):
name: str
value: str
doc: Optional[str]
@@ -96,7 +96,7 @@ def collect(cls):
@dataclass
class ContextCatalog(JsonSchemaMixin):
class ContextCatalog(dbtClassMixin):
base: List[ContextMember]
target: List[ContextMember]
model: List[ContextMember]

View File

@@ -7,10 +7,11 @@ import subprocess
import sys
# Python version defaults to 3.6
# To run postgres integration tests: `dtr.py -i --pg` (this is the default)
# To run postgres integration tests, clearing `dbt.log` beforehand: `dtr.py -il --pg`
# To run postgres + redshift integration tests: `dtr.py -i --pg --rs`
# To drop to pdb on failure, add `--pdb`
# To run postgres integration tests: `dtr.py -i -t pg` (this is the default)
# To run postgres integration tests, clearing `dbt.log` beforehand: `dtr.py -il -t pg`
# To run a single test directory: `dtr.py -i -t pg -a test/integration/029_docs_generate_tests`
# To run postgres + redshift integration tests: `dtr.py -i -t pg -t rs`
# To drop to pdb on failure, add `--pdb` or `-p`
# To run mypy tests: `dtr.py -m`.
# To run flake8 test: `dtr.py -f`.
# To run unit tests: `dtr.py -u`
@@ -82,12 +83,12 @@ def parse_args(argv):
)
parser.add_argument('-v', '--python-version',
default='36', choices=['27', '36', '37', '38'],
default='38', choices=['36', '37', '38'],
help='what python version to run')
parser.add_argument(
'-t', '--types',
default=None,
help='The types of tests to run, if this is an integration run, as csv'
help='The types of tests to run, if this is an integration run'
)
parser.add_argument(
'-c', '--continue',

View File

@@ -2,6 +2,7 @@
import os
import sys
if sys.version_info < (3, 6):
print('Error: dbt does not support this version of Python.')
print('Please upgrade to Python 3.6 or higher.')
@@ -36,11 +37,9 @@ setup(
description=description,
long_description=long_description,
long_description_content_type='text/markdown',
author="Fishtown Analytics",
author_email="info@fishtownanalytics.com",
url="https://github.com/fishtown-analytics/dbt",
packages=[],
install_requires=[
'dbt-core=={}'.format(package_version),
'dbt-postgres=={}'.format(package_version),

View File

@@ -531,7 +531,7 @@ class TestBadSnapshot(DBTIntegrationTest):
with self.assertRaises(dbt.exceptions.CompilationException) as exc:
self.run_dbt(['compile'], expect_pass=False)
self.assertIn('target_schema', str(exc.exception))
self.assertIn('Compilation Error in model ref_snapshot', str(exc.exception))
class TestCheckCols(TestSimpleSnapshotFiles):

View File

@@ -29,15 +29,15 @@ class TestEphemeralMulti(DBTIntegrationTest):
sql_file = re.sub(r'\d+', '', sql_file)
expected_sql = ('create view "dbt"."test_ephemeral_"."double_dependent__dbt_tmp" as ('
'with __dbt__CTE__base as ('
'with __dbt__cte__base as ('
'select * from test_ephemeral_.seed'
'), __dbt__CTE__base_copy as ('
'select * from __dbt__CTE__base'
'), __dbt__cte__base_copy as ('
'select * from __dbt__cte__base'
')-- base_copy just pulls from base. Make sure the listed'
'-- graph of CTEs all share the same dbt_cte__base cte'
"select * from __dbt__CTE__base where gender = 'Male'"
"select * from __dbt__cte__base where gender = 'Male'"
'union all'
"select * from __dbt__CTE__base_copy where gender = 'Female'"
"select * from __dbt__cte__base_copy where gender = 'Female'"
');')
sql_file = "".join(sql_file.split())
expected_sql = "".join(expected_sql.split())
@@ -79,11 +79,11 @@ class TestEphemeralNested(DBTIntegrationTest):
sql_file = re.sub(r'\d+', '', sql_file)
expected_sql = (
'create view "dbt"."test_ephemeral_"."root_view__dbt_tmp" as ('
'with __dbt__CTE__ephemeral_level_two as ('
'with __dbt__cte__ephemeral_level_two as ('
'select * from "dbt"."test_ephemeral_"."source_table"'
'), __dbt__CTE__ephemeral as ('
'select * from __dbt__CTE__ephemeral_level_two'
')select * from __dbt__CTE__ephemeral'
'), __dbt__cte__ephemeral as ('
'select * from __dbt__cte__ephemeral_level_two'
')select * from __dbt__cte__ephemeral'
');')
sql_file = "".join(sql_file.split())

View File

@@ -4,7 +4,9 @@
config(
materialized="table",
partition_by=var('partition_by'),
cluster_by=var('cluster_by')
cluster_by=var('cluster_by'),
partition_expiration_days=var('partition_expiration_days'),
require_partition_filter=var('require_partition_filter')
)
}}

View File

@@ -1,4 +1,6 @@
""""Test adapter specific config options."""
from pprint import pprint
from test.integration.base import DBTIntegrationTest, use_profile
import textwrap
import yaml
@@ -32,7 +34,7 @@ class TestBigqueryAdapterSpecific(DBTIntegrationTest):
@use_profile('bigquery')
def test_bigquery_hours_to_expiration(self):
_, stdout = self.run_dbt_and_capture(['--debug', 'run'])
self.assertIn(
'expiration_timestamp=TIMESTAMP_ADD(CURRENT_TIMESTAMP(), INTERVAL '
'4 hour)', stdout)

View File

@@ -29,45 +29,91 @@ class TestChangingPartitions(DBTIntegrationTest):
@use_profile('bigquery')
def test_bigquery_add_partition(self):
before = {"partition_by": None, "cluster_by": None}
after = {"partition_by": {'field': 'cur_time',
'data_type': 'timestamp'}, "cluster_by": None}
before = {"partition_by": None,
"cluster_by": None,
'partition_expiration_days': None,
'require_partition_filter': None}
after = {"partition_by": {'field': 'cur_time', 'data_type': 'timestamp'},
"cluster_by": None,
'partition_expiration_days': 7,
'require_partition_filter': True}
self.run_changes(before, after)
self.test_partitions({"expected": 1})
@use_profile('bigquery')
def test_bigquery_add_partition_year(self):
before = {"partition_by": None, "cluster_by": None}
after = {"partition_by": {'field': 'cur_time', 'data_type': 'timestamp', 'granularity': 'year'}, "cluster_by": None}
before = {"partition_by": None,
"cluster_by": None,
'partition_expiration_days': None,
'require_partition_filter': None}
after = {"partition_by": {'field': 'cur_time', 'data_type': 'timestamp', 'granularity': 'year'},
"cluster_by": None,
'partition_expiration_days': None,
'require_partition_filter': None}
self.run_changes(before, after)
self.test_partitions({"expected": 1})
@use_profile('bigquery')
def test_bigquery_add_partition_month(self):
before = {"partition_by": None, "cluster_by": None}
after = {"partition_by": {'field': 'cur_time', 'data_type': 'timestamp', 'granularity': 'month'}, "cluster_by": None}
before = {"partition_by": None,
"cluster_by": None,
'partition_expiration_days': None,
'require_partition_filter': None}
after = {"partition_by": {'field': 'cur_time', 'data_type': 'timestamp', 'granularity': 'month'},
"cluster_by": None,
'partition_expiration_days': None,
'require_partition_filter': None}
self.run_changes(before, after)
self.test_partitions({"expected": 1})
@use_profile('bigquery')
def test_bigquery_add_partition_hour(self):
before = {"partition_by": None, "cluster_by": None}
after = {"partition_by": {'field': 'cur_time', 'data_type': 'timestamp', 'granularity': 'hour'}, "cluster_by": None}
before = {"partition_by": None,
"cluster_by": None,
'partition_expiration_days': None,
'require_partition_filter': None}
after = {"partition_by": {'field': 'cur_time', 'data_type': 'timestamp', 'granularity': 'hour'},
"cluster_by": None,
'partition_expiration_days': None,
'require_partition_filter': None}
self.run_changes(before, after)
self.test_partitions({"expected": 1})
@use_profile('bigquery')
def test_bigquery_change_partition_day_to_hour(self):
before = {"partition_by": {'field': 'cur_time', 'data_type': 'timestamp', 'granularity': 'day'},
"cluster_by": None,
'partition_expiration_days': None,
'require_partition_filter': None}
after = {"partition_by": {'field': 'cur_time', 'data_type': 'timestamp', 'granularity': 'hour'},
"cluster_by": None,
'partition_expiration_days': None,
'require_partition_filter': None}
self.run_changes(before, after)
self.test_partitions({"expected": 1})
@use_profile('bigquery')
def test_bigquery_remove_partition(self):
before = {"partition_by": {'field': 'cur_time',
'data_type': 'timestamp'}, "cluster_by": None}
after = {"partition_by": None, "cluster_by": None}
before = {"partition_by": {'field': 'cur_time', 'data_type': 'timestamp'},
"cluster_by": None,
'partition_expiration_days': None,
'require_partition_filter': None}
after = {"partition_by": None,
"cluster_by": None,
'partition_expiration_days': None,
'require_partition_filter': None}
self.run_changes(before, after)
@use_profile('bigquery')
def test_bigquery_change_partitions(self):
before = {"partition_by": {'field': 'cur_time',
'data_type': 'timestamp'}, "cluster_by": None}
after = {"partition_by": {'field': "cur_date"}, "cluster_by": None}
before = {"partition_by": {'field': 'cur_time', 'data_type': 'timestamp'},
"cluster_by": None,
'partition_expiration_days': None,
'require_partition_filter': None}
after = {"partition_by": {'field': "cur_date"},
"cluster_by": None,
'partition_expiration_days': 7,
'require_partition_filter': True}
self.run_changes(before, after)
self.test_partitions({"expected": 1})
self.run_changes(after, before)
@@ -75,10 +121,14 @@ class TestChangingPartitions(DBTIntegrationTest):
@use_profile('bigquery')
def test_bigquery_change_partitions_from_int(self):
before = {"partition_by": {"field": "id", "data_type": "int64", "range": {
"start": 0, "end": 10, "interval": 1}}, "cluster_by": None}
after = {"partition_by": {"field": "cur_date",
"data_type": "date"}, "cluster_by": None}
before = {"partition_by": {"field": "id", "data_type": "int64", "range": {"start": 0, "end": 10, "interval": 1}},
"cluster_by": None,
'partition_expiration_days': None,
'require_partition_filter': None}
after = {"partition_by": {"field": "cur_date", "data_type": "date"},
"cluster_by": None,
'partition_expiration_days': None,
'require_partition_filter': None}
self.run_changes(before, after)
self.test_partitions({"expected": 1})
self.run_changes(after, before)
@@ -86,29 +136,48 @@ class TestChangingPartitions(DBTIntegrationTest):
@use_profile('bigquery')
def test_bigquery_add_clustering(self):
before = {"partition_by": {'field': 'cur_time',
'data_type': 'timestamp'}, "cluster_by": None}
after = {"partition_by": {'field': "cur_date"}, "cluster_by": "id"}
before = {"partition_by": {'field': 'cur_time', 'data_type': 'timestamp'},
"cluster_by": None,
'partition_expiration_days': None,
'require_partition_filter': None}
after = {"partition_by": {'field': "cur_date"},
"cluster_by": "id",
'partition_expiration_days': None,
'require_partition_filter': None}
self.run_changes(before, after)
@use_profile('bigquery')
def test_bigquery_remove_clustering(self):
before = {"partition_by": {'field': 'cur_time',
'data_type': 'timestamp'}, "cluster_by": "id"}
after = {"partition_by": {'field': "cur_date"}, "cluster_by": None}
before = {"partition_by": {'field': 'cur_time', 'data_type': 'timestamp'},
"cluster_by": "id",
'partition_expiration_days': None,
'require_partition_filter': None}
after = {"partition_by": {'field': "cur_date"},
"cluster_by": None,
'partition_expiration_days': None,
'require_partition_filter': None}
self.run_changes(before, after)
@use_profile('bigquery')
def test_bigquery_change_clustering(self):
before = {"partition_by": {'field': 'cur_time',
'data_type': 'timestamp'}, "cluster_by": "id"}
after = {"partition_by": {'field': "cur_date"}, "cluster_by": "name"}
before = {"partition_by": {'field': 'cur_time', 'data_type': 'timestamp'},
"cluster_by": "id",
'partition_expiration_days': None,
'require_partition_filter': None}
after = {"partition_by": {'field': "cur_date"},
"cluster_by": "name",
'partition_expiration_days': None,
'require_partition_filter': None}
self.run_changes(before, after)
@use_profile('bigquery')
def test_bigquery_change_clustering_strict(self):
before = {'partition_by': {'field': 'cur_time',
'data_type': 'timestamp'}, 'cluster_by': 'id'}
after = {'partition_by': {'field': 'cur_date',
'data_type': 'date'}, 'cluster_by': 'name'}
before = {'partition_by': {'field': 'cur_time', 'data_type': 'timestamp'},
'cluster_by': 'id',
'partition_expiration_days': None,
'require_partition_filter': None}
after = {'partition_by': {'field': 'cur_date', 'data_type': 'date'},
'cluster_by': 'name',
'partition_expiration_days': None,
'require_partition_filter': None}
self.run_changes(before, after)

View File

@@ -92,6 +92,7 @@ class TestStrictUndefined(DBTIntegrationTest):
'database': None,
'schema': None,
'alias': None,
'check_cols': None,
},
'alias': 'my_snapshot',
'resource_type': 'snapshot',

View File

@@ -66,17 +66,17 @@ class TestDebug(DBTIntegrationTest):
@use_profile('postgres')
def test_postgres_nopass(self):
self.run_dbt(['debug', '--target', 'nopass'])
self.run_dbt(['debug', '--target', 'nopass'], expect_pass=False)
self.assertGotValue(re.compile(r'\s+profiles\.yml file'), 'ERROR invalid')
@use_profile('postgres')
def test_postgres_wronguser(self):
self.run_dbt(['debug', '--target', 'wronguser'])
self.run_dbt(['debug', '--target', 'wronguser'], expect_pass=False)
self.assertGotValue(re.compile(r'\s+Connection test'), 'ERROR')
@use_profile('postgres')
def test_postgres_empty_target(self):
self.run_dbt(['debug', '--target', 'none_target'])
self.run_dbt(['debug', '--target', 'none_target'], expect_pass=False)
self.assertGotValue(re.compile(r"\s+output 'none_target'"), 'misconfigured')
@@ -110,7 +110,7 @@ class TestDebugInvalidProject(DBTIntegrationTest):
def test_postgres_empty_project(self):
with open('dbt_project.yml', 'w') as f:
pass
self.run_dbt(['debug', '--profile', 'test'])
self.run_dbt(['debug', '--profile', 'test'], expect_pass=False)
splitout = self.capsys.readouterr().out.split('\n')
for line in splitout:
if line.strip().startswith('dbt_project.yml file'):
@@ -124,7 +124,7 @@ class TestDebugInvalidProject(DBTIntegrationTest):
self.use_default_project(overrides={
'invalid-key': 'not a valid key so this is bad project',
})
self.run_dbt(['debug', '--profile', 'test'])
self.run_dbt(['debug', '--profile', 'test'], expect_pass=False)
splitout = self.capsys.readouterr().out.split('\n')
for line in splitout:
if line.strip().startswith('dbt_project.yml file'):
@@ -134,7 +134,7 @@ class TestDebugInvalidProject(DBTIntegrationTest):
@use_profile('postgres')
def test_postgres_not_found_project_dir(self):
self.run_dbt(['debug', '--project-dir', 'nopass'])
self.run_dbt(['debug', '--project-dir', 'nopass'], expect_pass=False)
splitout = self.capsys.readouterr().out.split('\n')
for line in splitout:
if line.strip().startswith('dbt_project.yml file'):
@@ -151,7 +151,7 @@ class TestDebugInvalidProject(DBTIntegrationTest):
os.makedirs('custom', exist_ok=True)
with open("custom/dbt_project.yml", 'w') as f:
yaml.safe_dump(project_config, f, default_flow_style=True)
self.run_dbt(['debug', '--project-dir', 'custom'])
self.run_dbt(['debug', '--project-dir', 'custom'], expect_pass=False)
splitout = self.capsys.readouterr().out.split('\n')
for line in splitout:
if line.strip().startswith('dbt_project.yml file'):

View File

@@ -66,7 +66,7 @@ class ServerProcess(dbt.flags.MP_CONTEXT.Process):
def start(self):
super().start()
for _ in range(180):
for _ in range(240):
if self.is_up():
break
time.sleep(0.5)
@@ -87,11 +87,11 @@ def query_url(url, query):
return requests.post(url, headers=headers, data=json.dumps(query))
_select_from_ephemeral = '''with __dbt__CTE__ephemeral_model as (
_select_from_ephemeral = '''with __dbt__cte__ephemeral_model as (
select 1 as id
)select * from __dbt__CTE__ephemeral_model'''
)select * from __dbt__cte__ephemeral_model'''
def addr_in_use(err, *args):

View File

@@ -6,7 +6,7 @@ from contextlib import contextmanager
from requests.exceptions import ConnectionError
from unittest.mock import patch, MagicMock, Mock, create_autospec, ANY
import hologram
import dbt.dataclass_schema
import dbt.flags as flags
@@ -19,6 +19,7 @@ from dbt.adapters.base.query_headers import MacroQueryStringSetter
from dbt.clients import agate_helper
import dbt.exceptions
from dbt.logger import GLOBAL_LOGGER as logger # noqa
from dbt.context.providers import RuntimeConfigObject
import google.cloud.bigquery
@@ -364,7 +365,7 @@ class TestBigQueryRelation(unittest.TestCase):
'identifier': False
}
}
BigQueryRelation.from_dict(kwargs)
BigQueryRelation.validate(kwargs)
def test_view_relation(self):
kwargs = {
@@ -379,7 +380,7 @@ class TestBigQueryRelation(unittest.TestCase):
'schema': True
}
}
BigQueryRelation.from_dict(kwargs)
BigQueryRelation.validate(kwargs)
def test_table_relation(self):
kwargs = {
@@ -394,7 +395,7 @@ class TestBigQueryRelation(unittest.TestCase):
'schema': True
}
}
BigQueryRelation.from_dict(kwargs)
BigQueryRelation.validate(kwargs)
def test_external_source_relation(self):
kwargs = {
@@ -409,7 +410,7 @@ class TestBigQueryRelation(unittest.TestCase):
'schema': True
}
}
BigQueryRelation.from_dict(kwargs)
BigQueryRelation.validate(kwargs)
def test_invalid_relation(self):
kwargs = {
@@ -424,8 +425,8 @@ class TestBigQueryRelation(unittest.TestCase):
'schema': True
}
}
with self.assertRaises(hologram.ValidationError):
BigQueryRelation.from_dict(kwargs)
with self.assertRaises(dbt.dataclass_schema.ValidationError):
BigQueryRelation.validate(kwargs)
class TestBigQueryInformationSchema(unittest.TestCase):
@@ -451,6 +452,7 @@ class TestBigQueryInformationSchema(unittest.TestCase):
'identifier': True,
}
}
BigQueryRelation.validate(kwargs)
relation = BigQueryRelation.from_dict(kwargs)
info_schema = relation.information_schema()
@@ -808,7 +810,7 @@ class TestBigQueryAdapter(BaseTestBigQueryAdapter):
def test_hours_to_expiration(self):
adapter = self.get_adapter('oauth')
mock_config = create_autospec(
dbt.context.providers.RuntimeConfigObject)
RuntimeConfigObject)
config = {'hours_to_expiration': 4}
mock_config.get.side_effect = lambda name: config.get(name)
@@ -822,7 +824,7 @@ class TestBigQueryAdapter(BaseTestBigQueryAdapter):
def test_hours_to_expiration_temporary(self):
adapter = self.get_adapter('oauth')
mock_config = create_autospec(
dbt.context.providers.RuntimeConfigObject)
RuntimeConfigObject)
config={'hours_to_expiration': 4}
mock_config.get.side_effect = lambda name: config.get(name)
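Note: the hunks above replace hologram.ValidationError with dbt.dataclass_schema.ValidationError and validate the input dict separately from constructing the relation. A hedged sketch of that validate-then-build pattern; the import path and the deliberately invalid payload are assumptions, not taken from these tests.

    # Hedged sketch of the validate-then-construct pattern used above.
    import dbt.dataclass_schema
    from dbt.adapters.bigquery import BigQueryRelation    # import path assumed

    kwargs = {'type': 'not-a-real-relation-type'}          # deliberately invalid
    try:
        BigQueryRelation.validate(kwargs)                  # schema check, may raise
        relation = BigQueryRelation.from_dict(kwargs)      # construction happens only after
    except dbt.dataclass_schema.ValidationError as exc:
        print(f'rejected: {exc}')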

View File

@@ -80,7 +80,7 @@ class CompilerTest(unittest.TestCase):
def mock_generate_runtime_model_context(model, config, manifest):
def ref(name):
result = f'__dbt__CTE__{name}'
result = f'__dbt__cte__{name}'
unique_id = f'model.root.{name}'
model.extra_ctes.append(InjectedCTE(id=unique_id, sql=None))
return result
@@ -121,7 +121,7 @@ class CompilerTest(unittest.TestCase):
extra_ctes=[InjectedCTE(id='model.root.ephemeral', sql='select * from source_table')],
compiled_sql=(
'with cte as (select * from something_else) '
'select * from __dbt__CTE__ephemeral'),
'select * from __dbt__cte__ephemeral'),
checksum=FileHash.from_contents(''),
),
'model.root.ephemeral': CompiledModelNode(
@@ -168,10 +168,10 @@ class CompilerTest(unittest.TestCase):
self.assertEqual(result.extra_ctes_injected, True)
self.assertEqualIgnoreWhitespace(
result.compiled_sql,
('with __dbt__CTE__ephemeral as ('
('with __dbt__cte__ephemeral as ('
'select * from source_table'
'), cte as (select * from something_else) '
'select * from __dbt__CTE__ephemeral'))
'select * from __dbt__cte__ephemeral'))
self.assertEqual(
manifest.nodes['model.root.ephemeral'].extra_ctes_injected,
@@ -296,7 +296,7 @@ class CompilerTest(unittest.TestCase):
compiled=True,
extra_ctes_injected=False,
extra_ctes=[InjectedCTE(id='model.root.ephemeral', sql='select * from source_table')],
compiled_sql='select * from __dbt__CTE__ephemeral',
compiled_sql='select * from __dbt__cte__ephemeral',
checksum=FileHash.from_contents(''),
),
'model.root.ephemeral': CompiledModelNode(
@@ -345,10 +345,10 @@ class CompilerTest(unittest.TestCase):
self.assertTrue(result.extra_ctes_injected)
self.assertEqualIgnoreWhitespace(
result.compiled_sql,
('with __dbt__CTE__ephemeral as ('
('with __dbt__cte__ephemeral as ('
'select * from source_table'
') '
'select * from __dbt__CTE__ephemeral'))
'select * from __dbt__cte__ephemeral'))
print(f"\n---- line 349 ----")
self.assertFalse(manifest.nodes['model.root.ephemeral'].extra_ctes_injected)
@@ -423,7 +423,7 @@ class CompilerTest(unittest.TestCase):
compiled=True,
extra_ctes_injected=False,
extra_ctes=[InjectedCTE(id='model.root.ephemeral', sql='select * from source_table')],
compiled_sql='select * from __dbt__CTE__ephemeral',
compiled_sql='select * from __dbt__cte__ephemeral',
checksum=FileHash.from_contents(''),
),
'model.root.ephemeral': parsed_ephemeral,
@@ -454,10 +454,10 @@ class CompilerTest(unittest.TestCase):
self.assertTrue(result.extra_ctes_injected)
self.assertEqualIgnoreWhitespace(
result.compiled_sql,
('with __dbt__CTE__ephemeral as ('
('with __dbt__cte__ephemeral as ('
'select * from source_table'
') '
'select * from __dbt__CTE__ephemeral'))
'select * from __dbt__cte__ephemeral'))
self.assertTrue(manifest.nodes['model.root.ephemeral'].extra_ctes_injected)
@@ -488,7 +488,7 @@ class CompilerTest(unittest.TestCase):
compiled=True,
extra_ctes_injected=False,
extra_ctes=[InjectedCTE(id='model.root.ephemeral', sql=None)],
compiled_sql='select * from __dbt__CTE__ephemeral',
compiled_sql='select * from __dbt__cte__ephemeral',
checksum=FileHash.from_contents(''),
),
@@ -552,12 +552,12 @@ class CompilerTest(unittest.TestCase):
self.assertTrue(result.extra_ctes_injected)
self.assertEqualIgnoreWhitespace(
result.compiled_sql,
('with __dbt__CTE__ephemeral_level_two as ('
('with __dbt__cte__ephemeral_level_two as ('
'select * from source_table'
'), __dbt__CTE__ephemeral as ('
'select * from __dbt__CTE__ephemeral_level_two'
'), __dbt__cte__ephemeral as ('
'select * from __dbt__cte__ephemeral_level_two'
') '
'select * from __dbt__CTE__ephemeral'))
'select * from __dbt__cte__ephemeral'))
self.assertTrue(manifest.nodes['model.root.ephemeral'].compiled)
self.assertTrue(manifest.nodes['model.root.ephemeral_level_two'].compiled)
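Note: every expected compiled_sql string in this test now uses the lowercase __dbt__cte__ prefix. Collected in one place, the nested-ephemeral shape the final assertion checks for looks roughly like this (whitespace simplified):

    # The new expected compiled SQL from the assertion above, as one string.
    expected = (
        'with __dbt__cte__ephemeral_level_two as ('
        'select * from source_table'
        '), __dbt__cte__ephemeral as ('
        'select * from __dbt__cte__ephemeral_level_two'
        ') '
        'select * from __dbt__cte__ephemeral'
    )
    print(expected)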

View File

@@ -415,16 +415,6 @@ def test_query_header_context(config, manifest_fx):
assert_has_keys(REQUIRED_QUERY_HEADER_KEYS, MAYBE_KEYS, ctx)
def test_macro_parse_context(config, manifest_fx, get_adapter, get_include_paths):
ctx = providers.generate_parser_macro(
macro=manifest_fx.macros['macro.root.macro_a'],
config=config,
manifest=manifest_fx,
package_name='root',
)
assert_has_keys(REQUIRED_MACRO_KEYS, MAYBE_KEYS, ctx)
def test_macro_runtime_context(config, manifest_fx, get_adapter, get_include_paths):
ctx = providers.generate_runtime_macro(
macro=manifest_fx.macros['macro.root.macro_a'],

View File

@@ -16,6 +16,7 @@ from .utils import (
assert_fails_validation,
dict_replace,
replace_config,
compare_dicts,
)
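Note: the only change above is adding compare_dicts to the test-utils imports. A purely hypothetical sketch of what a helper with that name could do; the real implementation in .utils is not shown in this diff and may differ.

    # Purely hypothetical dict-diffing test helper; not the real compare_dicts from .utils.
    def compare_dicts(dict1, dict2):
        for key in sorted(dict1.keys() | dict2.keys()):
            if dict1.get(key) != dict2.get(key):
                print(f'{key}: {dict1.get(key)!r} != {dict2.get(key)!r}')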

Some files were not shown because too many files have changed in this diff.