mirror of
https://github.com/dbt-labs/dbt-core
synced 2025-12-18 22:11:27 +00:00
Compare commits
5 Commits
enable-pos
...
experiment
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b7cfd70ba3 | ||
|
|
1ba1fb964b | ||
|
|
98f573ef29 | ||
|
|
c744cca96f | ||
|
|
76571819f3 |
@@ -5,10 +5,11 @@ from hologram import JsonSchemaMixin
|
||||
from dbt.exceptions import RuntimeException
|
||||
|
||||
from typing import Dict, ClassVar, Any, Optional
|
||||
from dbt.contracts.jsonschema import dbtClassMixin
|
||||
|
||||
|
||||
@dataclass
|
||||
class Column(JsonSchemaMixin):
|
||||
class Column(dbtClassMixin):
|
||||
TYPE_LABELS: ClassVar[Dict[str, str]] = {
|
||||
'STRING': 'TEXT',
|
||||
'TIMESTAMP': 'TIMESTAMP',
|
||||
|
||||
@@ -47,13 +47,16 @@ class BaseRelation(FakeAPIObject, Hashable):
|
||||
return False
|
||||
return self.to_dict() == other.to_dict()
|
||||
|
||||
# TODO : Unclear why we're leveraging a type system to implement inheritance?
|
||||
@classmethod
|
||||
def get_default_quote_policy(cls) -> Policy:
|
||||
return cls._get_field_named('quote_policy').default
|
||||
#return cls._get_field_named('quote_policy').default
|
||||
return Policy()
|
||||
|
||||
@classmethod
|
||||
def get_default_include_policy(cls) -> Policy:
|
||||
return cls._get_field_named('include_policy').default
|
||||
#return cls._get_field_named('include_policy').default
|
||||
return Policy()
|
||||
|
||||
def get(self, key, default=None):
|
||||
"""Override `.get` to return a metadata object so we don't break
|
||||
|
||||
@@ -137,7 +137,7 @@ class Profile(HasCredentials):
|
||||
def validate(self):
|
||||
try:
|
||||
if self.credentials:
|
||||
self.credentials.to_dict(validate=True)
|
||||
self.credentials.serialize(validate=True)
|
||||
ProfileConfig.from_dict(
|
||||
self.to_profile_info(serialize_credentials=True)
|
||||
)
|
||||
|
||||
@@ -306,7 +306,7 @@ class PartialProject(RenderComponents):
|
||||
)
|
||||
|
||||
try:
|
||||
cfg = ProjectContract.from_dict(rendered.project_dict)
|
||||
cfg = ProjectContract.deserialize(rendered.project_dict)
|
||||
except ValidationError as e:
|
||||
raise DbtProjectError(validator_error_message(e)) from e
|
||||
# name/version are required in the Project definition, so we can assume
|
||||
@@ -586,7 +586,9 @@ class Project:
|
||||
|
||||
def validate(self):
|
||||
try:
|
||||
ProjectContract.from_dict(self.to_project_config())
|
||||
# TODO : Jank; need to do this to handle aliasing between hyphens and underscores
|
||||
as_dict = self.to_project_config()
|
||||
ProjectContract.deserialize(as_dict)
|
||||
except ValidationError as e:
|
||||
raise DbtProjectError(validator_error_message(e)) from e
|
||||
|
||||
|
||||
@@ -174,7 +174,7 @@ class RuntimeConfig(Project, Profile, AdapterRequiredConfig):
|
||||
:raises DbtProjectError: If the configuration fails validation.
|
||||
"""
|
||||
try:
|
||||
Configuration.from_dict(self.serialize())
|
||||
Configuration.deserialize(self.serialize())
|
||||
except ValidationError as e:
|
||||
raise DbtProjectError(validator_error_message(e)) from e
|
||||
|
||||
|
||||
@@ -165,7 +165,9 @@ class ContextConfigGenerator(BaseContextConfigGenerator[C]):
|
||||
# Calculate the defaults. We don't want to validate the defaults,
|
||||
# because it might be invalid in the case of required config members
|
||||
# (such as on snapshots!)
|
||||
result = config_cls.from_dict({}, validate=False)
|
||||
result = config_cls.from_dict({})
|
||||
# TODO - why validate here?
|
||||
# result.validate()
|
||||
return result
|
||||
|
||||
def _update_from_config(
|
||||
|
||||
@@ -17,9 +17,12 @@ from dbt.utils import translate_aliases
|
||||
|
||||
from dbt.logger import GLOBAL_LOGGER as logger
|
||||
|
||||
from dbt.contracts.jsonschema import dbtClassMixin, ValidatedStringMixin
|
||||
from mashumaro.types import SerializableType
|
||||
|
||||
Identifier = NewType('Identifier', str)
|
||||
register_pattern(Identifier, r'^[A-Za-z_][A-Za-z0-9_]+$')
|
||||
|
||||
class Identifier(ValidatedStringMixin):
|
||||
ValidationRegex = r'^[A-Za-z_][A-Za-z0-9_]+$'
|
||||
|
||||
|
||||
class ConnectionState(StrEnum):
|
||||
@@ -28,22 +31,30 @@ class ConnectionState(StrEnum):
|
||||
CLOSED = 'closed'
|
||||
FAIL = 'fail'
|
||||
|
||||
# I think that... this is not the right way to do this
|
||||
# TODO!!
|
||||
class DoNotSerializeType(SerializableType):
|
||||
def _serialize(self) -> None:
|
||||
return None
|
||||
|
||||
# TODO : Figure out ExtensibleJsonSchemaMixin?
|
||||
@dataclass(init=False)
|
||||
class Connection(ExtensibleJsonSchemaMixin, Replaceable):
|
||||
class Connection(dbtClassMixin, Replaceable):
|
||||
type: Identifier
|
||||
name: Optional[str]
|
||||
state: ConnectionState = ConnectionState.INIT
|
||||
transaction_open: bool = False
|
||||
# prevent serialization
|
||||
_handle: Optional[Any] = None
|
||||
_credentials: JsonSchemaMixin = field(init=False)
|
||||
#_handle: Optional[Any] = None
|
||||
#_credentials: dbtClassMixin = field(init=False)
|
||||
_handle: Optional[DoNotSerializeType] = None
|
||||
_credentials: Optional[DoNotSerializeType] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
type: Identifier,
|
||||
name: Optional[str],
|
||||
credentials: JsonSchemaMixin,
|
||||
credentials: dbtClassMixin,
|
||||
state: ConnectionState = ConnectionState.INIT,
|
||||
transaction_open: bool = False,
|
||||
handle: Optional[Any] = None,
|
||||
@@ -102,7 +113,8 @@ class LazyHandle:
|
||||
# will work.
|
||||
@dataclass # type: ignore
|
||||
class Credentials(
|
||||
ExtensibleJsonSchemaMixin,
|
||||
# ExtensibleJsonSchemaMixin,
|
||||
dbtClassMixin,
|
||||
Replaceable,
|
||||
metaclass=abc.ABCMeta
|
||||
):
|
||||
@@ -121,7 +133,8 @@ class Credentials(
|
||||
) -> Iterable[Tuple[str, Any]]:
|
||||
"""Return an ordered iterator of key/value pairs for pretty-printing.
|
||||
"""
|
||||
as_dict = self.to_dict(omit_none=False, with_aliases=with_aliases)
|
||||
# TODO: Does this... work?
|
||||
as_dict = self.serialize(omit_none=False, with_aliases=with_aliases)
|
||||
connection_keys = set(self._connection_keys())
|
||||
aliases: List[str] = []
|
||||
if with_aliases:
|
||||
@@ -136,10 +149,11 @@ class Credentials(
|
||||
def _connection_keys(self) -> Tuple[str, ...]:
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data):
|
||||
data = cls.translate_aliases(data)
|
||||
return super().from_dict(data)
|
||||
# TODO TODO TODO
|
||||
# @classmethod
|
||||
# def from_dict(cls, data):
|
||||
# data = cls.translate_aliases(data)
|
||||
# return super().from_dict(data)
|
||||
|
||||
@classmethod
|
||||
def translate_aliases(
|
||||
@@ -147,15 +161,16 @@ class Credentials(
|
||||
) -> Dict[str, Any]:
|
||||
return translate_aliases(kwargs, cls._ALIASES, recurse)
|
||||
|
||||
def to_dict(self, omit_none=True, validate=False, *, with_aliases=False):
|
||||
serialized = super().to_dict(omit_none=omit_none, validate=validate)
|
||||
if with_aliases:
|
||||
serialized.update({
|
||||
new_name: serialized[canonical_name]
|
||||
for new_name, canonical_name in self._ALIASES.items()
|
||||
if canonical_name in serialized
|
||||
})
|
||||
return serialized
|
||||
# TODO TODO TODO
|
||||
# def to_dict(self, omit_none=True, validate=False, *, with_aliases=False):
|
||||
# serialized = super().to_dict(omit_none=omit_none, validate=validate)
|
||||
# if with_aliases:
|
||||
# serialized.update({
|
||||
# new_name: serialized[canonical_name]
|
||||
# for new_name, canonical_name in self._ALIASES.items()
|
||||
# if canonical_name in serialized
|
||||
# })
|
||||
# return serialized
|
||||
|
||||
|
||||
class UserConfigContract(Protocol):
|
||||
@@ -205,7 +220,7 @@ DEFAULT_QUERY_COMMENT = '''
|
||||
|
||||
|
||||
@dataclass
|
||||
class QueryComment(JsonSchemaMixin):
|
||||
class QueryComment(dbtClassMixin):
|
||||
comment: str = DEFAULT_QUERY_COMMENT
|
||||
append: bool = False
|
||||
|
||||
|
||||
@@ -9,13 +9,15 @@ from dbt.exceptions import InternalException
|
||||
|
||||
from .util import MacroKey, SourceKey
|
||||
|
||||
from dbt.contracts.jsonschema import dbtClassMixin
|
||||
|
||||
|
||||
MAXIMUM_SEED_SIZE = 1 * 1024 * 1024
|
||||
MAXIMUM_SEED_SIZE_NAME = '1MB'
|
||||
|
||||
|
||||
@dataclass
|
||||
class FilePath(JsonSchemaMixin):
|
||||
class FilePath(dbtClassMixin):
|
||||
searched_path: str
|
||||
relative_path: str
|
||||
project_root: str
|
||||
@@ -51,7 +53,7 @@ class FilePath(JsonSchemaMixin):
|
||||
|
||||
|
||||
@dataclass
|
||||
class FileHash(JsonSchemaMixin):
|
||||
class FileHash(dbtClassMixin):
|
||||
name: str # the hash type name
|
||||
checksum: str # the hashlib.hash_type().hexdigest() of the file contents
|
||||
|
||||
@@ -91,7 +93,7 @@ class FileHash(JsonSchemaMixin):
|
||||
|
||||
|
||||
@dataclass
|
||||
class RemoteFile(JsonSchemaMixin):
|
||||
class RemoteFile(dbtClassMixin):
|
||||
@property
|
||||
def searched_path(self) -> str:
|
||||
return 'from remote system'
|
||||
@@ -110,7 +112,7 @@ class RemoteFile(JsonSchemaMixin):
|
||||
|
||||
|
||||
@dataclass
|
||||
class SourceFile(JsonSchemaMixin):
|
||||
class SourceFile(dbtClassMixin):
|
||||
"""Define a source file in dbt"""
|
||||
path: Union[FilePath, RemoteFile] # the path information
|
||||
checksum: FileHash
|
||||
|
||||
@@ -23,15 +23,17 @@ from hologram import JsonSchemaMixin
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional, List, Union, Dict, Type
|
||||
|
||||
from dbt.contracts.jsonschema import dbtClassMixin
|
||||
|
||||
|
||||
@dataclass
|
||||
class InjectedCTE(JsonSchemaMixin, Replaceable):
|
||||
class InjectedCTE(dbtClassMixin, Replaceable):
|
||||
id: str
|
||||
sql: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class CompiledNodeMixin(JsonSchemaMixin):
|
||||
class CompiledNodeMixin(dbtClassMixin):
|
||||
# this is a special mixin class to provide a required argument. If a node
|
||||
# is missing a `compiled` flag entirely, it must not be a CompiledNode.
|
||||
compiled: bool
|
||||
@@ -179,7 +181,8 @@ def parsed_instance_for(compiled: CompiledNode) -> ParsedResource:
|
||||
.format(compiled.resource_type))
|
||||
|
||||
# validate=False to allow extra keys from compiling
|
||||
return cls.from_dict(compiled.to_dict(), validate=False)
|
||||
# TODO : This is probably not going to do the right thing....
|
||||
return cls.deserialize(compiled.to_dict(), validate=False)
|
||||
|
||||
|
||||
NonSourceCompiledNode = Union[
|
||||
|
||||
@@ -508,10 +508,10 @@ class Manifest:
|
||||
"""
|
||||
self.flat_graph = {
|
||||
'nodes': {
|
||||
k: v.to_dict(omit_none=False) for k, v in self.nodes.items()
|
||||
k: v.serialize(omit_none=False) for k, v in self.nodes.items()
|
||||
},
|
||||
'sources': {
|
||||
k: v.to_dict(omit_none=False) for k, v in self.sources.items()
|
||||
k: v.serialize(omit_none=False) for k, v in self.sources.items()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -764,7 +764,7 @@ class Manifest:
|
||||
)
|
||||
|
||||
def to_dict(self, omit_none=True, validate=False):
|
||||
return self.writable_manifest().to_dict(
|
||||
return self.writable_manifest().serialize(
|
||||
omit_none=omit_none, validate=validate
|
||||
)
|
||||
|
||||
|
||||
@@ -21,6 +21,9 @@ from dbt.contracts.util import Replaceable, list_str
|
||||
from dbt import hooks
|
||||
from dbt.node_types import NodeType
|
||||
|
||||
from dbt.contracts.jsonschema import dbtClassMixin
|
||||
from mashumaro.types import SerializableType
|
||||
|
||||
|
||||
M = TypeVar('M', bound='Metadata')
|
||||
|
||||
@@ -170,9 +173,19 @@ def insensitive_patterns(*patterns: str):
|
||||
return '^({})$'.format('|'.join(lowercased))
|
||||
|
||||
|
||||
# TODO?
|
||||
Severity = NewType('Severity', str)
|
||||
|
||||
register_pattern(Severity, insensitive_patterns('warn', 'error'))
|
||||
class Severity(str, SerializableType):
|
||||
@classmethod
|
||||
def _deserialize(cls, value: str) -> 'Severity':
|
||||
# TODO : Validate here?
|
||||
return Severity(value)
|
||||
|
||||
def _serialize(self) -> str:
|
||||
# TODO : Validate here?
|
||||
return self
|
||||
|
||||
|
||||
class SnapshotStrategy(StrEnum):
|
||||
@@ -185,7 +198,7 @@ class All(StrEnum):
|
||||
|
||||
|
||||
@dataclass
|
||||
class Hook(JsonSchemaMixin, Replaceable):
|
||||
class Hook(dbtClassMixin, Replaceable):
|
||||
sql: str
|
||||
transaction: bool = True
|
||||
index: Optional[int] = None
|
||||
@@ -196,7 +209,7 @@ T = TypeVar('T', bound='BaseConfig')
|
||||
|
||||
@dataclass
|
||||
class BaseConfig(
|
||||
AdditionalPropertiesAllowed, Replaceable, MutableMapping[str, Any]
|
||||
dbtClassMixin, Replaceable, MutableMapping[str, Any]
|
||||
):
|
||||
# Implement MutableMapping so this config will behave as some macros expect
|
||||
# during parsing (notably, syntax like `{{ node.config['schema'] }}`)
|
||||
@@ -294,23 +307,25 @@ class BaseConfig(
|
||||
"""
|
||||
result = {}
|
||||
|
||||
for fld, target_field in cls._get_fields():
|
||||
if target_field not in data:
|
||||
continue
|
||||
# TODO : This is not correct.... must implement without reflection
|
||||
|
||||
data_attr = data.pop(target_field)
|
||||
if target_field not in src:
|
||||
result[target_field] = data_attr
|
||||
continue
|
||||
# for fld, target_field in cls._get_fields():
|
||||
# if target_field not in data:
|
||||
# continue
|
||||
|
||||
merge_behavior = MergeBehavior.from_field(fld)
|
||||
self_attr = src[target_field]
|
||||
# data_attr = data.pop(target_field)
|
||||
# if target_field not in src:
|
||||
# result[target_field] = data_attr
|
||||
# continue
|
||||
|
||||
result[target_field] = _merge_field_value(
|
||||
merge_behavior=merge_behavior,
|
||||
self_value=self_attr,
|
||||
other_value=data_attr,
|
||||
)
|
||||
# merge_behavior = MergeBehavior.from_field(fld)
|
||||
# self_attr = src[target_field]
|
||||
|
||||
# result[target_field] = _merge_field_value(
|
||||
# merge_behavior=merge_behavior,
|
||||
# self_value=self_attr,
|
||||
# other_value=data_attr,
|
||||
# )
|
||||
return result
|
||||
|
||||
def to_dict(
|
||||
@@ -320,7 +335,7 @@ class BaseConfig(
|
||||
*,
|
||||
omit_hidden: bool = True,
|
||||
) -> Dict[str, Any]:
|
||||
result = super().to_dict(omit_none=omit_none, validate=validate)
|
||||
result = super().serialize(omit_none=omit_none, validate=validate)
|
||||
if omit_hidden and not omit_none:
|
||||
for fld, target_field in self._get_fields():
|
||||
if target_field not in result:
|
||||
@@ -344,7 +359,9 @@ class BaseConfig(
|
||||
"""
|
||||
# sadly, this is a circular import
|
||||
from dbt.adapters.factory import get_config_class_by_name
|
||||
dct = self.to_dict(omit_none=False, validate=False, omit_hidden=False)
|
||||
# TODO : omit_hidden?
|
||||
# dct = self.serialize(omit_none=False, validate=False, omit_hidden=False)
|
||||
dct = self.serialize(omit_none=False, validate=False)
|
||||
|
||||
adapter_config_cls = get_config_class_by_name(adapter_type)
|
||||
|
||||
@@ -358,11 +375,11 @@ class BaseConfig(
|
||||
dct.update(data)
|
||||
|
||||
# any validation failures must have come from the update
|
||||
return self.from_dict(dct, validate=validate)
|
||||
return self.deserialize(dct, validate=validate)
|
||||
|
||||
def finalize_and_validate(self: T) -> T:
|
||||
# from_dict will validate for us
|
||||
dct = self.to_dict(omit_none=False, validate=False)
|
||||
dct = self.serialize(omit_none=False, validate=False)
|
||||
return self.from_dict(dct)
|
||||
|
||||
def replace(self, **kwargs):
|
||||
@@ -372,7 +389,7 @@ class BaseConfig(
|
||||
for key, value in kwargs.items():
|
||||
new_key = mapping.get(key, key)
|
||||
dct[new_key] = value
|
||||
return self.from_dict(dct, validate=False)
|
||||
return self.deserialize(dct, validate=False)
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -435,12 +452,15 @@ class NodeConfig(BaseConfig):
|
||||
for key in hooks.ModelHookType:
|
||||
if key in data:
|
||||
data[key] = [hooks.get_hook_dict(h) for h in data[key]]
|
||||
return super().from_dict(data, validate=validate)
|
||||
return super().deserialize(data, validate=validate)
|
||||
|
||||
@classmethod
|
||||
def field_mapping(cls):
|
||||
return {'post_hook': 'post-hook', 'pre_hook': 'pre-hook'}
|
||||
|
||||
def validate(self):
|
||||
# TODO : Not implemented!
|
||||
pass
|
||||
|
||||
@dataclass
|
||||
class SeedConfig(NodeConfig):
|
||||
@@ -454,63 +474,10 @@ class TestConfig(NodeConfig):
|
||||
severity: Severity = Severity('ERROR')
|
||||
|
||||
|
||||
SnapshotVariants = Union[
|
||||
'TimestampSnapshotConfig',
|
||||
'CheckSnapshotConfig',
|
||||
'GenericSnapshotConfig',
|
||||
]
|
||||
|
||||
|
||||
def _relevance_without_strategy(error: jsonschema.ValidationError):
|
||||
# calculate the 'relevance' of an error the normal jsonschema way, except
|
||||
# if the validator is in the 'strategy' field and its conflicting with the
|
||||
# 'enum'. This suppresses `"'timestamp' is not one of ['check']` and such
|
||||
if 'strategy' in error.path and error.validator in {'enum', 'not'}:
|
||||
length = 1
|
||||
else:
|
||||
length = -len(error.path)
|
||||
validator = error.validator
|
||||
return length, validator not in {'anyOf', 'oneOf'}
|
||||
|
||||
|
||||
@dataclass
|
||||
class SnapshotWrapper(JsonSchemaMixin):
|
||||
"""This is a little wrapper to let us serialize/deserialize the
|
||||
SnapshotVariants union.
|
||||
"""
|
||||
config: SnapshotVariants # mypy: ignore
|
||||
|
||||
@classmethod
|
||||
def validate(cls, data: Any):
|
||||
config = data.get('config', {})
|
||||
|
||||
if config.get('strategy') == 'check':
|
||||
schema = _validate_schema(CheckSnapshotConfig)
|
||||
to_validate = config
|
||||
|
||||
elif config.get('strategy') == 'timestamp':
|
||||
schema = _validate_schema(TimestampSnapshotConfig)
|
||||
to_validate = config
|
||||
|
||||
else:
|
||||
h_cls = cast(Hashable, cls)
|
||||
schema = _validate_schema(h_cls)
|
||||
to_validate = data
|
||||
|
||||
validator = jsonschema.Draft7Validator(schema)
|
||||
|
||||
error = jsonschema.exceptions.best_match(
|
||||
validator.iter_errors(to_validate),
|
||||
key=_relevance_without_strategy,
|
||||
)
|
||||
|
||||
if error is not None:
|
||||
raise ValidationError.create_from(error) from error
|
||||
|
||||
|
||||
@dataclass
|
||||
class EmptySnapshotConfig(NodeConfig):
|
||||
materialized: str = 'snapshot'
|
||||
strategy: str = None
|
||||
|
||||
|
||||
@dataclass(init=False)
|
||||
@@ -519,117 +486,17 @@ class SnapshotConfig(EmptySnapshotConfig):
|
||||
target_schema: str = field(init=False, metadata=dict(init_required=True))
|
||||
target_database: Optional[str] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
unique_key: str,
|
||||
target_schema: str,
|
||||
target_database: Optional[str] = None,
|
||||
**kwargs
|
||||
) -> None:
|
||||
self.unique_key = unique_key
|
||||
self.target_schema = target_schema
|
||||
self.target_database = target_database
|
||||
# kwargs['materialized'] = materialized
|
||||
super().__init__(**kwargs)
|
||||
|
||||
# type hacks...
|
||||
@classmethod
|
||||
def _get_fields(cls) -> List[Tuple[Field, str]]: # type: ignore
|
||||
fields: List[Tuple[Field, str]] = []
|
||||
for old_field, name in super()._get_fields():
|
||||
new_field = old_field
|
||||
# tell hologram we're really an initvar
|
||||
if old_field.metadata and old_field.metadata.get('init_required'):
|
||||
new_field = field(init=True, metadata=old_field.metadata)
|
||||
new_field.name = old_field.name
|
||||
new_field.type = old_field.type
|
||||
new_field._field_type = old_field._field_type # type: ignore
|
||||
fields.append((new_field, name))
|
||||
return fields
|
||||
|
||||
def finalize_and_validate(self: 'SnapshotConfig') -> SnapshotVariants:
|
||||
data = self.to_dict()
|
||||
return SnapshotWrapper.from_dict({'config': data}).config
|
||||
|
||||
|
||||
@dataclass(init=False)
|
||||
class GenericSnapshotConfig(SnapshotConfig):
|
||||
strategy: str = field(init=False, metadata=dict(init_required=True))
|
||||
|
||||
def __init__(self, strategy: str, **kwargs) -> None:
|
||||
self.strategy = strategy
|
||||
super().__init__(**kwargs)
|
||||
|
||||
@classmethod
|
||||
def _collect_json_schema(
|
||||
cls, definitions: Dict[str, Any]
|
||||
) -> Dict[str, Any]:
|
||||
# this is the method you want to override in hologram if you want
|
||||
# to do clever things about the json schema and have classes that
|
||||
# contain instances of your JsonSchemaMixin respect the change.
|
||||
schema = super()._collect_json_schema(definitions)
|
||||
|
||||
# Instead of just the strategy we'd calculate normally, say
|
||||
# "this strategy except none of our specialization strategies".
|
||||
strategies = [schema['properties']['strategy']]
|
||||
for specialization in (TimestampSnapshotConfig, CheckSnapshotConfig):
|
||||
strategies.append(
|
||||
{'not': specialization.json_schema()['properties']['strategy']}
|
||||
)
|
||||
|
||||
schema['properties']['strategy'] = {
|
||||
'allOf': strategies
|
||||
}
|
||||
return schema
|
||||
|
||||
|
||||
@dataclass(init=False)
|
||||
class TimestampSnapshotConfig(SnapshotConfig):
|
||||
strategy: str = field(
|
||||
init=False,
|
||||
metadata=dict(
|
||||
restrict=[str(SnapshotStrategy.Timestamp)],
|
||||
init_required=True,
|
||||
),
|
||||
)
|
||||
updated_at: str = field(init=False, metadata=dict(init_required=True))
|
||||
|
||||
def __init__(
|
||||
self, strategy: str, updated_at: str, **kwargs
|
||||
) -> None:
|
||||
self.strategy = strategy
|
||||
self.updated_at = updated_at
|
||||
super().__init__(**kwargs)
|
||||
updated_at: str
|
||||
|
||||
|
||||
@dataclass(init=False)
|
||||
class CheckSnapshotConfig(SnapshotConfig):
|
||||
strategy: str = field(
|
||||
init=False,
|
||||
metadata=dict(
|
||||
restrict=[str(SnapshotStrategy.Check)],
|
||||
init_required=True,
|
||||
),
|
||||
)
|
||||
# TODO: is there a way to get this to accept tuples of strings? Adding
|
||||
# `Tuple[str, ...]` to the list of types results in this:
|
||||
# ['email'] is valid under each of {'type': 'array', 'items':
|
||||
# {'type': 'string'}}, {'type': 'array', 'items': {'type': 'string'}}
|
||||
# but without it, parsing gets upset about values like `('email',)`
|
||||
# maybe hologram itself should support this behavior? It's not like tuples
|
||||
# are meaningful in json
|
||||
check_cols: Union[All, List[str]] = field(
|
||||
init=False,
|
||||
metadata=dict(init_required=True),
|
||||
)
|
||||
check_cols: Union[All, List[str]]
|
||||
|
||||
def __init__(
|
||||
self, strategy: str, check_cols: Union[All, List[str]],
|
||||
**kwargs
|
||||
) -> None:
|
||||
self.strategy = strategy
|
||||
self.check_cols = check_cols
|
||||
super().__init__(**kwargs)
|
||||
|
||||
class CheckSnapshotConfig(SnapshotConfig):
|
||||
pass
|
||||
|
||||
|
||||
RESOURCE_TYPES: Dict[NodeType, Type[BaseConfig]] = {
|
||||
|
||||
@@ -31,6 +31,8 @@ from dbt.logger import GLOBAL_LOGGER as logger # noqa
|
||||
from dbt import flags
|
||||
from dbt.node_types import NodeType
|
||||
|
||||
from dbt.contracts.jsonschema import dbtClassMixin
|
||||
|
||||
|
||||
from .model_config import (
|
||||
NodeConfig,
|
||||
@@ -38,20 +40,15 @@ from .model_config import (
|
||||
TestConfig,
|
||||
SourceConfig,
|
||||
EmptySnapshotConfig,
|
||||
SnapshotVariants,
|
||||
)
|
||||
# import these 3 so the SnapshotVariants forward ref works.
|
||||
from .model_config import ( # noqa
|
||||
TimestampSnapshotConfig,
|
||||
CheckSnapshotConfig,
|
||||
GenericSnapshotConfig,
|
||||
SnapshotConfig,
|
||||
)
|
||||
|
||||
|
||||
# TODO : Figure out AdditionalPropertiesMixin and ExtensibleJsonSchemaMixin
|
||||
@dataclass
|
||||
class ColumnInfo(
|
||||
AdditionalPropertiesMixin,
|
||||
ExtensibleJsonSchemaMixin,
|
||||
dbtClassMixin,
|
||||
#AdditionalPropertiesMixin,
|
||||
#ExtensibleJsonSchemaMixin,
|
||||
Replaceable
|
||||
):
|
||||
name: str
|
||||
@@ -64,7 +61,7 @@ class ColumnInfo(
|
||||
|
||||
|
||||
@dataclass
|
||||
class HasFqn(JsonSchemaMixin, Replaceable):
|
||||
class HasFqn(dbtClassMixin, Replaceable):
|
||||
fqn: List[str]
|
||||
|
||||
def same_fqn(self, other: 'HasFqn') -> bool:
|
||||
@@ -72,12 +69,12 @@ class HasFqn(JsonSchemaMixin, Replaceable):
|
||||
|
||||
|
||||
@dataclass
|
||||
class HasUniqueID(JsonSchemaMixin, Replaceable):
|
||||
class HasUniqueID(dbtClassMixin, Replaceable):
|
||||
unique_id: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class MacroDependsOn(JsonSchemaMixin, Replaceable):
|
||||
class MacroDependsOn(dbtClassMixin, Replaceable):
|
||||
macros: List[str] = field(default_factory=list)
|
||||
|
||||
# 'in' on lists is O(n) so this is O(n^2) for # of macros
|
||||
@@ -96,12 +93,12 @@ class DependsOn(MacroDependsOn):
|
||||
|
||||
|
||||
@dataclass
|
||||
class HasRelationMetadata(JsonSchemaMixin, Replaceable):
|
||||
class HasRelationMetadata(dbtClassMixin, Replaceable):
|
||||
database: Optional[str]
|
||||
schema: str
|
||||
|
||||
|
||||
class ParsedNodeMixins(JsonSchemaMixin):
|
||||
class ParsedNodeMixins(dbtClassMixin):
|
||||
resource_type: NodeType
|
||||
depends_on: DependsOn
|
||||
config: NodeConfig
|
||||
@@ -132,8 +129,8 @@ class ParsedNodeMixins(JsonSchemaMixin):
|
||||
self.meta = patch.meta
|
||||
self.docs = patch.docs
|
||||
if flags.STRICT_MODE:
|
||||
assert isinstance(self, JsonSchemaMixin)
|
||||
self.to_dict(validate=True, omit_none=False)
|
||||
assert isinstance(self, dbtClassMixin)
|
||||
self.serialize(validate=True, omit_none=False)
|
||||
|
||||
def get_materialization(self):
|
||||
return self.config.materialized
|
||||
@@ -335,14 +332,14 @@ class ParsedSeedNode(ParsedNode):
|
||||
|
||||
|
||||
@dataclass
|
||||
class TestMetadata(JsonSchemaMixin, Replaceable):
|
||||
class TestMetadata(dbtClassMixin, Replaceable):
|
||||
namespace: Optional[str]
|
||||
name: str
|
||||
kwargs: Dict[str, Any]
|
||||
|
||||
|
||||
@dataclass
|
||||
class HasTestMetadata(JsonSchemaMixin):
|
||||
class HasTestMetadata(dbtClassMixin):
|
||||
test_metadata: TestMetadata
|
||||
|
||||
|
||||
@@ -394,7 +391,7 @@ class IntermediateSnapshotNode(ParsedNode):
|
||||
@dataclass
|
||||
class ParsedSnapshotNode(ParsedNode):
|
||||
resource_type: NodeType = field(metadata={'restrict': [NodeType.Snapshot]})
|
||||
config: SnapshotVariants
|
||||
config: SnapshotConfig
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -443,8 +440,8 @@ class ParsedMacro(UnparsedBaseNode, HasUniqueID):
|
||||
self.docs = patch.docs
|
||||
self.arguments = patch.arguments
|
||||
if flags.STRICT_MODE:
|
||||
assert isinstance(self, JsonSchemaMixin)
|
||||
self.to_dict(validate=True, omit_none=False)
|
||||
assert isinstance(self, dbtClassMixin)
|
||||
self.serialize(validate=True, omit_none=False)
|
||||
|
||||
def same_contents(self, other: Optional['ParsedMacro']) -> bool:
|
||||
if other is None:
|
||||
|
||||
@@ -16,9 +16,11 @@ from datetime import timedelta
|
||||
from pathlib import Path
|
||||
from typing import Optional, List, Union, Dict, Any, Sequence
|
||||
|
||||
from dbt.contracts.jsonschema import dbtClassMixin
|
||||
|
||||
|
||||
@dataclass
|
||||
class UnparsedBaseNode(JsonSchemaMixin, Replaceable):
|
||||
class UnparsedBaseNode(dbtClassMixin, Replaceable):
|
||||
package_name: str
|
||||
root_path: str
|
||||
path: str
|
||||
@@ -66,18 +68,20 @@ class UnparsedRunHook(UnparsedNode):
|
||||
|
||||
|
||||
@dataclass
|
||||
class Docs(JsonSchemaMixin, Replaceable):
|
||||
class Docs(dbtClassMixin, Replaceable):
|
||||
show: bool = True
|
||||
|
||||
|
||||
# TODO : This should have AdditionalPropertiesMixin and ExtensibleJsonSchemaMixin
|
||||
@dataclass
|
||||
class HasDocs(AdditionalPropertiesMixin, ExtensibleJsonSchemaMixin,
|
||||
Replaceable):
|
||||
class HasDocs(dbtClassMixin, Replaceable):
|
||||
name: str
|
||||
description: str = ''
|
||||
meta: Dict[str, Any] = field(default_factory=dict)
|
||||
data_type: Optional[str] = None
|
||||
docs: Docs = field(default_factory=Docs)
|
||||
|
||||
# TODO : How do we handle these additional fields with mashurmaro?
|
||||
_extra: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
@@ -100,7 +104,7 @@ class UnparsedColumn(HasTests):
|
||||
|
||||
|
||||
@dataclass
|
||||
class HasColumnDocs(JsonSchemaMixin, Replaceable):
|
||||
class HasColumnDocs(dbtClassMixin, Replaceable):
|
||||
columns: Sequence[HasDocs] = field(default_factory=list)
|
||||
|
||||
|
||||
@@ -110,7 +114,7 @@ class HasColumnTests(HasColumnDocs):
|
||||
|
||||
|
||||
@dataclass
|
||||
class HasYamlMetadata(JsonSchemaMixin):
|
||||
class HasYamlMetadata(dbtClassMixin):
|
||||
original_file_path: str
|
||||
yaml_key: str
|
||||
package_name: str
|
||||
@@ -127,7 +131,7 @@ class UnparsedNodeUpdate(HasColumnTests, HasTests, HasYamlMetadata):
|
||||
|
||||
|
||||
@dataclass
|
||||
class MacroArgument(JsonSchemaMixin):
|
||||
class MacroArgument(dbtClassMixin):
|
||||
name: str
|
||||
type: Optional[str] = None
|
||||
description: str = ''
|
||||
@@ -148,7 +152,7 @@ class TimePeriod(StrEnum):
|
||||
|
||||
|
||||
@dataclass
|
||||
class Time(JsonSchemaMixin, Replaceable):
|
||||
class Time(dbtClassMixin, Replaceable):
|
||||
count: int
|
||||
period: TimePeriod
|
||||
|
||||
@@ -159,7 +163,7 @@ class Time(JsonSchemaMixin, Replaceable):
|
||||
|
||||
|
||||
@dataclass
|
||||
class FreshnessThreshold(JsonSchemaMixin, Mergeable):
|
||||
class FreshnessThreshold(dbtClassMixin, Mergeable):
|
||||
warn_after: Optional[Time] = None
|
||||
error_after: Optional[Time] = None
|
||||
filter: Optional[str] = None
|
||||
@@ -212,7 +216,7 @@ class ExternalTable(AdditionalPropertiesAllowed, Mergeable):
|
||||
|
||||
|
||||
@dataclass
|
||||
class Quoting(JsonSchemaMixin, Mergeable):
|
||||
class Quoting(dbtClassMixin, Mergeable):
|
||||
database: Optional[bool] = None
|
||||
schema: Optional[bool] = None
|
||||
identifier: Optional[bool] = None
|
||||
@@ -231,14 +235,15 @@ class UnparsedSourceTableDefinition(HasColumnTests, HasTests):
|
||||
tags: List[str] = field(default_factory=list)
|
||||
|
||||
def to_dict(self, omit_none=True, validate=False):
|
||||
result = super().to_dict(omit_none=omit_none, validate=validate)
|
||||
result = super().serialize(omit_none=omit_none, validate=validate)
|
||||
|
||||
if omit_none and self.freshness is None:
|
||||
result['freshness'] = None
|
||||
return result
|
||||
|
||||
|
||||
@dataclass
|
||||
class UnparsedSourceDefinition(JsonSchemaMixin, Replaceable):
|
||||
class UnparsedSourceDefinition(dbtClassMixin, Replaceable):
|
||||
name: str
|
||||
description: str = ''
|
||||
meta: Dict[str, Any] = field(default_factory=dict)
|
||||
@@ -258,14 +263,14 @@ class UnparsedSourceDefinition(JsonSchemaMixin, Replaceable):
|
||||
return 'sources'
|
||||
|
||||
def to_dict(self, omit_none=True, validate=False):
|
||||
result = super().to_dict(omit_none=omit_none, validate=validate)
|
||||
result = super().serialize(omit_none=omit_none, validate=validate)
|
||||
if omit_none and self.freshness is None:
|
||||
result['freshness'] = None
|
||||
return result
|
||||
|
||||
|
||||
@dataclass
|
||||
class SourceTablePatch(JsonSchemaMixin):
|
||||
class SourceTablePatch(dbtClassMixin):
|
||||
name: str
|
||||
description: Optional[str] = None
|
||||
meta: Optional[Dict[str, Any]] = None
|
||||
@@ -283,7 +288,7 @@ class SourceTablePatch(JsonSchemaMixin):
|
||||
columns: Optional[Sequence[UnparsedColumn]] = None
|
||||
|
||||
def to_patch_dict(self) -> Dict[str, Any]:
|
||||
dct = self.to_dict(omit_none=True)
|
||||
dct = self.serialize(omit_none=True)
|
||||
remove_keys = ('name')
|
||||
for key in remove_keys:
|
||||
if key in dct:
|
||||
@@ -296,7 +301,7 @@ class SourceTablePatch(JsonSchemaMixin):
|
||||
|
||||
|
||||
@dataclass
|
||||
class SourcePatch(JsonSchemaMixin, Replaceable):
|
||||
class SourcePatch(dbtClassMixin, Replaceable):
|
||||
name: str = field(
|
||||
metadata=dict(description='The name of the source to override'),
|
||||
)
|
||||
@@ -320,7 +325,7 @@ class SourcePatch(JsonSchemaMixin, Replaceable):
|
||||
tags: Optional[List[str]] = None
|
||||
|
||||
def to_patch_dict(self) -> Dict[str, Any]:
|
||||
dct = self.to_dict(omit_none=True)
|
||||
dct = self.serialize(omit_none=True)
|
||||
remove_keys = ('name', 'overrides', 'tables', 'path')
|
||||
for key in remove_keys:
|
||||
if key in dct:
|
||||
@@ -340,7 +345,7 @@ class SourcePatch(JsonSchemaMixin, Replaceable):
|
||||
|
||||
|
||||
@dataclass
|
||||
class UnparsedDocumentation(JsonSchemaMixin, Replaceable):
|
||||
class UnparsedDocumentation(dbtClassMixin, Replaceable):
|
||||
package_name: str
|
||||
root_path: str
|
||||
path: str
|
||||
@@ -400,7 +405,7 @@ class MaturityType(StrEnum):
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExposureOwner(JsonSchemaMixin, Replaceable):
|
||||
class ExposureOwner(dbtClassMixin, Replaceable):
|
||||
email: str
|
||||
name: Optional[str] = None
|
||||
|
||||
|
||||
66
core/dbt/contracts/jsonschema.py
Normal file
66
core/dbt/contracts/jsonschema.py
Normal file
@@ -0,0 +1,66 @@
|
||||
|
||||
from dataclasses import dataclass, fields, Field
|
||||
from typing import (
|
||||
Optional, TypeVar, Generic, Dict, get_type_hints, List, Tuple
|
||||
)
|
||||
from mashumaro import DataClassDictMixin
|
||||
from mashumaro.types import SerializableType
|
||||
import re
|
||||
|
||||
|
||||
"""
|
||||
This is a throwaway shim to match the JsonSchemaMixin interface
|
||||
|
||||
that downstream consumers (ie. PostgresRelation) is expecting.
|
||||
I imagine that we would try to remove code that depends on this type
|
||||
reflection if we pursue an approach like the one shown here
|
||||
"""
|
||||
class dbtClassMixin(DataClassDictMixin):
|
||||
@classmethod
|
||||
def field_mapping(cls) -> Dict[str, str]:
|
||||
"""Defines the mapping of python field names to JSON field names.
|
||||
|
||||
The main use-case is to allow JSON field names which are Python keywords
|
||||
"""
|
||||
return {}
|
||||
|
||||
def serialize(self, omit_none=False, validate=False, with_aliases: Optional[Dict[str, str]]=None):
|
||||
dct = self.to_dict()
|
||||
|
||||
if with_aliases:
|
||||
# TODO : Mutating these dicts is a TERRIBLE idea - remove this
|
||||
for aliased_name, canonical_name in self._ALIASES.items():
|
||||
if aliased_name in dct:
|
||||
dct[canonical_name] = dct.pop(aliased_name)
|
||||
|
||||
return dct
|
||||
|
||||
@classmethod
|
||||
def deserialize(cls, data, validate=False, with_aliases=False):
|
||||
if with_aliases:
|
||||
# TODO : Mutating these dicts is a TERRIBLE idea - remove this
|
||||
for aliased_name, canonical_name in cls._ALIASES.items():
|
||||
if aliased_name in data:
|
||||
data[canonical_name] = data.pop(aliased_name)
|
||||
|
||||
# TODO .... implement these?
|
||||
return cls.from_dict(data)
|
||||
|
||||
|
||||
class ValidatedStringMixin(str, SerializableType):
|
||||
ValidationRegex = None
|
||||
|
||||
@classmethod
|
||||
def _deserialize(cls, value: str) -> 'ValidatedStringMixin':
|
||||
cls.validate(value)
|
||||
return ValidatedStringMixin(value)
|
||||
|
||||
def _serialize(self) -> str:
|
||||
return str(self)
|
||||
|
||||
@classmethod
|
||||
def validate(cls, value) -> str:
|
||||
res = re.match(cls.ValidationRegex, value)
|
||||
|
||||
if res is None:
|
||||
raise ValidationError(f"Invalid value: {value}") # TODO
|
||||
@@ -12,12 +12,16 @@ from hologram.helpers import HyphenatedJsonSchemaMixin, register_pattern, \
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional, List, Dict, Union, Any, NewType
|
||||
|
||||
from dbt.contracts.jsonschema import dbtClassMixin, ValidatedStringMixin
|
||||
from mashumaro.types import SerializableType
|
||||
|
||||
PIN_PACKAGE_URL = 'https://docs.getdbt.com/docs/package-management#section-specifying-package-versions' # noqa
|
||||
DEFAULT_SEND_ANONYMOUS_USAGE_STATS = True
|
||||
|
||||
|
||||
Name = NewType('Name', str)
|
||||
register_pattern(Name, r'^[^\d\W]\w*$')
|
||||
class Name(ValidatedStringMixin):
|
||||
ValidationRegex = r'^[^\d\W]\w*$'
|
||||
|
||||
|
||||
# this does not support the full semver (does not allow a trailing -fooXYZ) and
|
||||
# is not restrictive enough for full semver, (allows '1.0'). But it's like
|
||||
@@ -28,17 +32,26 @@ register_pattern(
|
||||
r'^(?:0|[1-9]\d*)\.(?:0|[1-9]\d*)(\.(?:0|[1-9]\d*))?$',
|
||||
)
|
||||
|
||||
class SemverString(str, SerializableType):
|
||||
def _serialize(self) -> str:
|
||||
return self
|
||||
|
||||
@dataclass
|
||||
class Quoting(JsonSchemaMixin, Mergeable):
|
||||
identifier: Optional[bool]
|
||||
schema: Optional[bool]
|
||||
database: Optional[bool]
|
||||
project: Optional[bool]
|
||||
@classmethod
|
||||
def _deserialize(cls, value: str) -> 'SemverString':
|
||||
return SemverString(value)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Package(Replaceable, HyphenatedJsonSchemaMixin):
|
||||
class Quoting(dbtClassMixin, Mergeable):
|
||||
identifier: Optional[bool] = None
|
||||
schema: Optional[bool] = None
|
||||
database: Optional[bool] = None
|
||||
project: Optional[bool] = None
|
||||
|
||||
|
||||
# TODO .... hyphenation.... what.... why
|
||||
@dataclass
|
||||
class Package(dbtClassMixin, Replaceable):
|
||||
pass
|
||||
|
||||
|
||||
@@ -80,7 +93,7 @@ PackageSpec = Union[LocalPackage, GitPackage, RegistryPackage]
|
||||
|
||||
|
||||
@dataclass
|
||||
class PackageConfig(JsonSchemaMixin, Replaceable):
|
||||
class PackageConfig(dbtClassMixin, Replaceable):
|
||||
packages: List[PackageSpec]
|
||||
|
||||
|
||||
@@ -96,13 +109,13 @@ class ProjectPackageMetadata:
|
||||
|
||||
|
||||
@dataclass
|
||||
class Downloads(ExtensibleJsonSchemaMixin, Replaceable):
|
||||
class Downloads(dbtClassMixin, Replaceable):
|
||||
tarball: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class RegistryPackageMetadata(
|
||||
ExtensibleJsonSchemaMixin,
|
||||
dbtClassMixin,
|
||||
ProjectPackageMetadata,
|
||||
):
|
||||
downloads: Downloads
|
||||
@@ -153,7 +166,7 @@ BANNED_PROJECT_NAMES = {
|
||||
|
||||
|
||||
@dataclass
|
||||
class Project(HyphenatedJsonSchemaMixin, Replaceable):
|
||||
class Project(dbtClassMixin, Replaceable):
|
||||
name: Name
|
||||
version: Union[SemverString, float]
|
||||
config_version: int
|
||||
@@ -189,9 +202,30 @@ class Project(HyphenatedJsonSchemaMixin, Replaceable):
|
||||
packages: List[PackageSpec] = field(default_factory=list)
|
||||
query_comment: Optional[Union[QueryComment, NoValue, str]] = NoValue()
|
||||
|
||||
_ALIASES = {
|
||||
'config-version': 'config_version',
|
||||
'source-paths': 'source_paths',
|
||||
'macro-paths': 'macro_paths',
|
||||
'data-paths': 'data_paths',
|
||||
'test-paths': 'test_paths',
|
||||
'analysis-paths': 'analysis_paths',
|
||||
'docs-paths': 'docs_paths',
|
||||
'asset-paths': 'asset_paths',
|
||||
'target-path': 'target_path',
|
||||
'snapshot-paths': 'snapshot_paths',
|
||||
'clean-targets': 'clean_targets',
|
||||
'log-path': 'log_path',
|
||||
'modules-path': 'modules_path',
|
||||
'on-run-start': 'on_run_start',
|
||||
'on-run-end': 'on_run_end',
|
||||
'require-dbt-version': 'require_dbt_version',
|
||||
'project-root': 'project_root',
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data, validate=True) -> 'Project':
|
||||
result = super().from_dict(data, validate=validate)
|
||||
def deserialize(cls, data) -> 'Project':
|
||||
# TODO : Deserialize this with aliases - right now, only implemented for serializer?
|
||||
result = super().deserialize(data, with_aliases=True)
|
||||
if result.name in BANNED_PROJECT_NAMES:
|
||||
raise ValidationError(
|
||||
f'Invalid project name: {result.name} is a reserved word'
|
||||
@@ -200,8 +234,9 @@ class Project(HyphenatedJsonSchemaMixin, Replaceable):
|
||||
return result
|
||||
|
||||
|
||||
# TODO : Make extensible?
|
||||
@dataclass
|
||||
class UserConfig(ExtensibleJsonSchemaMixin, Replaceable, UserConfigContract):
|
||||
class UserConfig(dbtClassMixin, Replaceable, UserConfigContract):
|
||||
send_anonymous_usage_stats: bool = DEFAULT_SEND_ANONYMOUS_USAGE_STATS
|
||||
use_colors: Optional[bool] = None
|
||||
partial_parse: Optional[bool] = None
|
||||
@@ -221,7 +256,7 @@ class UserConfig(ExtensibleJsonSchemaMixin, Replaceable, UserConfigContract):
|
||||
|
||||
|
||||
@dataclass
|
||||
class ProfileConfig(HyphenatedJsonSchemaMixin, Replaceable):
|
||||
class ProfileConfig(dbtClassMixin, Replaceable):
|
||||
profile_name: str = field(metadata={'preserve_underscore': True})
|
||||
target_name: str = field(metadata={'preserve_underscore': True})
|
||||
config: UserConfig
|
||||
@@ -234,8 +269,8 @@ class ProfileConfig(HyphenatedJsonSchemaMixin, Replaceable):
|
||||
class ConfiguredQuoting(Quoting, Replaceable):
|
||||
identifier: bool
|
||||
schema: bool
|
||||
database: Optional[bool]
|
||||
project: Optional[bool]
|
||||
database: Optional[bool] = None
|
||||
project: Optional[bool] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -248,5 +283,5 @@ class Configuration(Project, ProfileConfig):
|
||||
|
||||
|
||||
@dataclass
|
||||
class ProjectList(JsonSchemaMixin):
|
||||
class ProjectList(dbtClassMixin):
|
||||
projects: Dict[str, Project]
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
from collections.abc import Mapping
|
||||
from dataclasses import dataclass, fields
|
||||
from dataclasses import dataclass, fields, Field
|
||||
from typing import (
|
||||
Optional, TypeVar, Generic, Dict,
|
||||
Optional, TypeVar, Generic, Dict, get_type_hints, List, Tuple
|
||||
)
|
||||
from typing_extensions import Protocol
|
||||
|
||||
from hologram import JsonSchemaMixin
|
||||
from hologram.helpers import StrEnum
|
||||
|
||||
from dbt import deprecations
|
||||
@@ -13,6 +12,8 @@ from dbt.contracts.util import Replaceable
|
||||
from dbt.exceptions import CompilationException
|
||||
from dbt.utils import deep_merge
|
||||
|
||||
from dbt.contracts.jsonschema import dbtClassMixin
|
||||
|
||||
|
||||
class RelationType(StrEnum):
|
||||
Table = 'table'
|
||||
@@ -32,7 +33,7 @@ class HasQuoting(Protocol):
|
||||
quoting: Dict[str, bool]
|
||||
|
||||
|
||||
class FakeAPIObject(JsonSchemaMixin, Replaceable, Mapping):
|
||||
class FakeAPIObject(dbtClassMixin, Replaceable, Mapping):
|
||||
# override the mapping truthiness, len is always >1
|
||||
def __bool__(self):
|
||||
return True
|
||||
@@ -58,16 +59,22 @@ class FakeAPIObject(JsonSchemaMixin, Replaceable, Mapping):
|
||||
return self.from_dict(value)
|
||||
|
||||
|
||||
T = TypeVar('T')
|
||||
"""
|
||||
mashumaro does not support generic types, so I just duplicated
|
||||
the generic type class that was here previously to get things
|
||||
working. We'd probably want to do this differently
|
||||
in the future
|
||||
|
||||
TODO
|
||||
"""
|
||||
|
||||
@dataclass
|
||||
class _ComponentObject(FakeAPIObject, Generic[T]):
|
||||
database: T
|
||||
schema: T
|
||||
identifier: T
|
||||
class Policy(FakeAPIObject):
|
||||
database: bool = True
|
||||
schema: bool = True
|
||||
identifier: bool = True
|
||||
|
||||
def get_part(self, key: ComponentName) -> T:
|
||||
def get_part(self, key: ComponentName) -> bool:
|
||||
if key == ComponentName.Database:
|
||||
return self.database
|
||||
elif key == ComponentName.Schema:
|
||||
@@ -80,22 +87,15 @@ class _ComponentObject(FakeAPIObject, Generic[T]):
|
||||
.format(key, list(ComponentName))
|
||||
)
|
||||
|
||||
def replace_dict(self, dct: Dict[ComponentName, T]):
|
||||
kwargs: Dict[str, T] = {}
|
||||
def replace_dict(self, dct: Dict[ComponentName, bool]):
|
||||
kwargs: Dict[str, bool] = {}
|
||||
for k, v in dct.items():
|
||||
kwargs[str(k)] = v
|
||||
return self.replace(**kwargs)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Policy(_ComponentObject[bool]):
|
||||
database: bool = True
|
||||
schema: bool = True
|
||||
identifier: bool = True
|
||||
|
||||
|
||||
@dataclass
|
||||
class Path(_ComponentObject[Optional[str]]):
|
||||
class Path(FakeAPIObject):
|
||||
database: Optional[str]
|
||||
schema: Optional[str]
|
||||
identifier: Optional[str]
|
||||
@@ -120,3 +120,22 @@ class Path(_ComponentObject[Optional[str]]):
|
||||
if part is not None:
|
||||
part = part.lower()
|
||||
return part
|
||||
|
||||
def get_part(self, key: ComponentName) -> str:
|
||||
if key == ComponentName.Database:
|
||||
return self.database
|
||||
elif key == ComponentName.Schema:
|
||||
return self.schema
|
||||
elif key == ComponentName.Identifier:
|
||||
return self.identifier
|
||||
else:
|
||||
raise ValueError(
|
||||
'Got a key of {}, expected one of {}'
|
||||
.format(key, list(ComponentName))
|
||||
)
|
||||
|
||||
def replace_dict(self, dct: Dict[ComponentName, str]):
|
||||
kwargs: Dict[str, str] = {}
|
||||
for k, v in dct.items():
|
||||
kwargs[str(k)] = v
|
||||
return self.replace(**kwargs)
|
||||
|
||||
@@ -21,6 +21,9 @@ from dbt.utils import lowercase
|
||||
from hologram.helpers import StrEnum
|
||||
from hologram import JsonSchemaMixin
|
||||
|
||||
from dbt.contracts.jsonschema import dbtClassMixin
|
||||
from mashumaro.types import SerializableType
|
||||
|
||||
import agate
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
@@ -31,7 +34,7 @@ from dbt.clients.system import write_json
|
||||
|
||||
|
||||
@dataclass
|
||||
class TimingInfo(JsonSchemaMixin):
|
||||
class TimingInfo(dbtClassMixin):
|
||||
name: str
|
||||
started_at: Optional[datetime] = None
|
||||
completed_at: Optional[datetime] = None
|
||||
@@ -88,7 +91,7 @@ class FreshnessStatus(StrEnum):
|
||||
|
||||
|
||||
@dataclass
|
||||
class BaseResult(JsonSchemaMixin):
|
||||
class BaseResult(dbtClassMixin):
|
||||
status: Union[RunStatus, TestStatus, FreshnessStatus]
|
||||
timing: List[TimingInfo]
|
||||
thread_id: str
|
||||
@@ -109,15 +112,14 @@ class PartialNodeResult(NodeResult):
|
||||
def skipped(self):
|
||||
return False
|
||||
|
||||
# TODO : Does this work? no-op agate table serialization in RunModelResult
|
||||
class SerializableAgateTable(agate.Table, SerializableType):
|
||||
def _serialize(self) -> None:
|
||||
return None
|
||||
|
||||
@dataclass
|
||||
class RunModelResult(NodeResult):
|
||||
agate_table: Optional[agate.Table] = None
|
||||
|
||||
def to_dict(self, *args, **kwargs):
|
||||
dct = super().to_dict(*args, **kwargs)
|
||||
dct.pop('agate_table', None)
|
||||
return dct
|
||||
agate_table: Optional[SerializableAgateTable] = None
|
||||
|
||||
@property
|
||||
def skipped(self):
|
||||
@@ -125,7 +127,7 @@ class RunModelResult(NodeResult):
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExecutionResult(JsonSchemaMixin):
|
||||
class ExecutionResult(dbtClassMixin):
|
||||
results: Sequence[BaseResult]
|
||||
elapsed_time: float
|
||||
|
||||
@@ -210,7 +212,8 @@ class RunResultsArtifact(ExecutionResult, ArtifactMixin):
|
||||
)
|
||||
|
||||
def write(self, path: str, omit_none=False):
|
||||
write_json(path, self.to_dict(omit_none=omit_none))
|
||||
# TODO: Implement omit_none
|
||||
write_json(path, self.to_dict())
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -268,14 +271,14 @@ class FreshnessErrorEnum(StrEnum):
|
||||
|
||||
|
||||
@dataclass
|
||||
class SourceFreshnessRuntimeError(JsonSchemaMixin):
|
||||
class SourceFreshnessRuntimeError(dbtClassMixin):
|
||||
unique_id: str
|
||||
error: Optional[Union[str, int]]
|
||||
status: FreshnessErrorEnum
|
||||
|
||||
|
||||
@dataclass
|
||||
class SourceFreshnessOutput(JsonSchemaMixin):
|
||||
class SourceFreshnessOutput(dbtClassMixin):
|
||||
unique_id: str
|
||||
max_loaded_at: datetime
|
||||
snapshotted_at: datetime
|
||||
@@ -383,7 +386,7 @@ CatalogKey = NamedTuple(
|
||||
|
||||
|
||||
@dataclass
|
||||
class StatsItem(JsonSchemaMixin):
|
||||
class StatsItem(dbtClassMixin):
|
||||
id: str
|
||||
label: str
|
||||
value: Primitive
|
||||
@@ -395,7 +398,7 @@ StatsDict = Dict[str, StatsItem]
|
||||
|
||||
|
||||
@dataclass
|
||||
class ColumnMetadata(JsonSchemaMixin):
|
||||
class ColumnMetadata(dbtClassMixin):
|
||||
type: str
|
||||
comment: Optional[str]
|
||||
index: int
|
||||
@@ -406,7 +409,7 @@ ColumnMap = Dict[str, ColumnMetadata]
|
||||
|
||||
|
||||
@dataclass
|
||||
class TableMetadata(JsonSchemaMixin):
|
||||
class TableMetadata(dbtClassMixin):
|
||||
type: str
|
||||
database: Optional[str]
|
||||
schema: str
|
||||
@@ -416,7 +419,7 @@ class TableMetadata(JsonSchemaMixin):
|
||||
|
||||
|
||||
@dataclass
|
||||
class CatalogTable(JsonSchemaMixin, Replaceable):
|
||||
class CatalogTable(dbtClassMixin, Replaceable):
|
||||
metadata: TableMetadata
|
||||
columns: ColumnMap
|
||||
stats: StatsDict
|
||||
@@ -439,7 +442,7 @@ class CatalogMetadata(BaseArtifactMetadata):
|
||||
|
||||
|
||||
@dataclass
|
||||
class CatalogResults(JsonSchemaMixin):
|
||||
class CatalogResults(dbtClassMixin):
|
||||
nodes: Dict[str, CatalogTable]
|
||||
sources: Dict[str, CatalogTable]
|
||||
errors: Optional[List[str]]
|
||||
|
||||
@@ -26,6 +26,8 @@ from dbt.exceptions import InternalException
|
||||
from dbt.logger import LogMessage
|
||||
from dbt.utils import restrict_to
|
||||
|
||||
from dbt.contracts.jsonschema import dbtClassMixin
|
||||
|
||||
|
||||
TaskTags = Optional[Dict[str, Any]]
|
||||
TaskID = uuid.UUID
|
||||
@@ -34,7 +36,7 @@ TaskID = uuid.UUID
|
||||
|
||||
|
||||
@dataclass
|
||||
class RPCParameters(JsonSchemaMixin):
|
||||
class RPCParameters(dbtClassMixin):
|
||||
timeout: Optional[float]
|
||||
task_tags: TaskTags
|
||||
|
||||
@@ -132,7 +134,7 @@ class StatusParameters(RPCParameters):
|
||||
|
||||
|
||||
@dataclass
|
||||
class GCSettings(JsonSchemaMixin):
|
||||
class GCSettings(dbtClassMixin):
|
||||
# start evicting the longest-ago-ended tasks here
|
||||
maxsize: int
|
||||
# start evicting all tasks before now - auto_reap_age when we have this
|
||||
@@ -254,7 +256,7 @@ class RemoteExecutionResult(ExecutionResult, RemoteResult):
|
||||
|
||||
|
||||
@dataclass
|
||||
class ResultTable(JsonSchemaMixin):
|
||||
class ResultTable(dbtClassMixin):
|
||||
column_names: List[str]
|
||||
rows: List[Any]
|
||||
|
||||
@@ -411,7 +413,7 @@ class TaskHandlerState(StrEnum):
|
||||
|
||||
|
||||
@dataclass
|
||||
class TaskTiming(JsonSchemaMixin):
|
||||
class TaskTiming(dbtClassMixin):
|
||||
state: TaskHandlerState
|
||||
start: Optional[datetime]
|
||||
end: Optional[datetime]
|
||||
|
||||
@@ -3,16 +3,18 @@ from hologram import JsonSchemaMixin
|
||||
|
||||
from typing import List, Dict, Any, Union
|
||||
|
||||
from dbt.contracts.jsonschema import dbtClassMixin
|
||||
|
||||
|
||||
@dataclass
|
||||
class SelectorDefinition(JsonSchemaMixin):
|
||||
class SelectorDefinition(dbtClassMixin):
|
||||
name: str
|
||||
definition: Union[str, Dict[str, Any]]
|
||||
description: str = ''
|
||||
|
||||
|
||||
@dataclass
|
||||
class SelectorFile(JsonSchemaMixin):
|
||||
class SelectorFile(dbtClassMixin):
|
||||
selectors: List[SelectorDefinition]
|
||||
version: int = 2
|
||||
|
||||
|
||||
@@ -14,6 +14,7 @@ from dbt.exceptions import (
|
||||
from dbt.version import __version__
|
||||
from dbt.tracking import get_invocation_id
|
||||
from hologram import JsonSchemaMixin
|
||||
from dbt.contracts.jsonschema import dbtClassMixin
|
||||
|
||||
MacroKey = Tuple[str, str]
|
||||
SourceKey = Tuple[str, str]
|
||||
@@ -58,7 +59,8 @@ class Mergeable(Replaceable):
|
||||
|
||||
class Writable:
|
||||
def write(self, path: str, omit_none: bool = False):
|
||||
write_json(path, self.to_dict(omit_none=omit_none)) # type: ignore
|
||||
# TODO : Do this faster?
|
||||
write_json(path, self.serialize(omit_none=omit_none)) # type: ignore
|
||||
|
||||
|
||||
class AdditionalPropertiesMixin:
|
||||
@@ -71,20 +73,20 @@ class AdditionalPropertiesMixin:
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data, validate=True):
|
||||
self = super().from_dict(data=data, validate=validate)
|
||||
keys = self.to_dict(validate=False, omit_none=False)
|
||||
self = super().deserialize(data=data, validate=validate)
|
||||
keys = self.serialize(validate=False, omit_none=False)
|
||||
for key, value in data.items():
|
||||
if key not in keys:
|
||||
self.extra[key] = value
|
||||
return self
|
||||
|
||||
def to_dict(self, omit_none=True, validate=False):
|
||||
data = super().to_dict(omit_none=omit_none, validate=validate)
|
||||
data = super().serialize(omit_none=omit_none, validate=validate)
|
||||
data.update(self.extra)
|
||||
return data
|
||||
|
||||
def replace(self, **kwargs):
|
||||
dct = self.to_dict(omit_none=False, validate=False)
|
||||
dct = self.serialize(omit_none=False, validate=False)
|
||||
dct.update(kwargs)
|
||||
return self.from_dict(dct)
|
||||
|
||||
@@ -135,7 +137,7 @@ def get_metadata_env() -> Dict[str, str]:
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class BaseArtifactMetadata(JsonSchemaMixin):
|
||||
class BaseArtifactMetadata(dbtClassMixin):
|
||||
dbt_schema_version: str
|
||||
dbt_version: str = __version__
|
||||
generated_at: datetime = dataclasses.field(
|
||||
@@ -158,7 +160,7 @@ def schema_version(name: str, version: int):
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class VersionedSchema(JsonSchemaMixin):
|
||||
class VersionedSchema(dbtClassMixin):
|
||||
dbt_schema_version: ClassVar[SchemaVersion]
|
||||
|
||||
@classmethod
|
||||
@@ -194,4 +196,4 @@ class ArtifactMixin(VersionedSchema, Writable, Readable):
|
||||
if found != expected:
|
||||
raise IncompatibleSchemaException(expected, found)
|
||||
|
||||
return super().from_dict(data=data, validate=validate)
|
||||
return super().deserialize(data=data, validate=validate)
|
||||
|
||||
@@ -4,8 +4,12 @@ from typing import (
|
||||
import networkx as nx # type: ignore
|
||||
|
||||
from dbt.exceptions import InternalException
|
||||
from dbt.contracts.jsonschema import ValidatedStringMixin
|
||||
from mashumaro.types import SerializableType
|
||||
|
||||
UniqueId = NewType('UniqueId', str)
|
||||
|
||||
class UniqueId(ValidatedStringMixin):
|
||||
ValidationRegex = '.+'
|
||||
|
||||
|
||||
class Graph:
|
||||
|
||||
@@ -2,16 +2,33 @@
|
||||
from dataclasses import dataclass
|
||||
from datetime import timedelta
|
||||
from pathlib import Path
|
||||
from typing import NewType, Tuple, AbstractSet
|
||||
from typing import NewType, Tuple, AbstractSet, Union
|
||||
|
||||
from hologram import (
|
||||
FieldEncoder, JsonSchemaMixin, JsonDict, ValidationError
|
||||
)
|
||||
from hologram.helpers import StrEnum
|
||||
from dbt.contracts.jsonschema import dbtClassMixin
|
||||
from mashumaro.types import SerializableType
|
||||
|
||||
Port = NewType('Port', int)
|
||||
|
||||
class Port(int, SerializableType):
|
||||
@classmethod
|
||||
def _deserialize(cls, value: Union[int, str]) -> 'Port':
|
||||
try:
|
||||
value = int(value)
|
||||
except ValueError:
|
||||
raise ValidationError(f'Cannot encode {value} into port numbr')
|
||||
|
||||
return Port(value)
|
||||
|
||||
def _serialize(self) -> int:
|
||||
# TODO : Validate here?
|
||||
return self
|
||||
|
||||
|
||||
# TODO
|
||||
class PortEncoder(FieldEncoder):
|
||||
@property
|
||||
def json_schema(self):
|
||||
@@ -66,16 +83,17 @@ class NVEnum(StrEnum):
|
||||
|
||||
|
||||
@dataclass
|
||||
class NoValue(JsonSchemaMixin):
|
||||
class NoValue(dbtClassMixin):
|
||||
"""Sometimes, you want a way to say none that isn't None"""
|
||||
novalue: NVEnum = NVEnum.novalue
|
||||
|
||||
|
||||
JsonSchemaMixin.register_field_encoders({
|
||||
Port: PortEncoder(),
|
||||
timedelta: TimeDeltaFieldEncoder(),
|
||||
Path: PathEncoder(),
|
||||
})
|
||||
# TODO : None of this is right lol
|
||||
# JsonSchemaMixin.register_field_encoders({
|
||||
# Port: PortEncoder(),
|
||||
# timedelta: TimeDeltaFieldEncoder(),
|
||||
# Path: PathEncoder(),
|
||||
# })
|
||||
|
||||
|
||||
FQNPath = Tuple[str, ...]
|
||||
|
||||
@@ -15,6 +15,8 @@ import colorama
|
||||
import logbook
|
||||
from hologram import JsonSchemaMixin
|
||||
|
||||
from dbt.contracts.jsonschema import dbtClassMixin
|
||||
|
||||
# Colorama needs some help on windows because we're using logger.info
|
||||
# intead of print(). If the Windows env doesn't have a TERM var set,
|
||||
# then we should override the logging stream to use the colorama
|
||||
@@ -49,7 +51,7 @@ Extras = Dict[str, Any]
|
||||
|
||||
|
||||
@dataclass
|
||||
class LogMessage(JsonSchemaMixin):
|
||||
class LogMessage(dbtClassMixin):
|
||||
timestamp: datetime
|
||||
message: str
|
||||
channel: str
|
||||
@@ -215,7 +217,7 @@ class TextOnly(logbook.Processor):
|
||||
|
||||
|
||||
class TimingProcessor(logbook.Processor):
|
||||
def __init__(self, timing_info: Optional[JsonSchemaMixin] = None):
|
||||
def __init__(self, timing_info: Optional[dbtClassMixin] = None):
|
||||
self.timing_info = timing_info
|
||||
super().__init__()
|
||||
|
||||
|
||||
@@ -137,6 +137,7 @@ def main(args=None):
|
||||
exit_code = e.code
|
||||
|
||||
except BaseException as e:
|
||||
print(traceback.format_exc())
|
||||
logger.warning("Encountered an error:")
|
||||
logger.warning(str(e))
|
||||
|
||||
|
||||
@@ -13,7 +13,7 @@ class AnalysisParser(SimpleSQLParser[ParsedAnalysisNode]):
|
||||
)
|
||||
|
||||
def parse_from_dict(self, dct, validate=True) -> ParsedAnalysisNode:
|
||||
return ParsedAnalysisNode.from_dict(dct, validate=validate)
|
||||
return ParsedAnalysisNode.deserialize(dct, validate=validate)
|
||||
|
||||
@property
|
||||
def resource_type(self) -> NodeType:
|
||||
|
||||
@@ -12,7 +12,7 @@ class DataTestParser(SimpleSQLParser[ParsedDataTestNode]):
|
||||
)
|
||||
|
||||
def parse_from_dict(self, dct, validate=True) -> ParsedDataTestNode:
|
||||
return ParsedDataTestNode.from_dict(dct, validate=validate)
|
||||
return ParsedDataTestNode.deserialize(dct, validate=validate)
|
||||
|
||||
@property
|
||||
def resource_type(self) -> NodeType:
|
||||
|
||||
@@ -79,7 +79,7 @@ class HookParser(SimpleParser[HookBlock, ParsedHookNode]):
|
||||
return [path]
|
||||
|
||||
def parse_from_dict(self, dct, validate=True) -> ParsedHookNode:
|
||||
return ParsedHookNode.from_dict(dct, validate=validate)
|
||||
return ParsedHookNode.deserialize(dct, validate=validate)
|
||||
|
||||
@classmethod
|
||||
def get_compiled_path(cls, block: HookBlock):
|
||||
|
||||
@@ -53,20 +53,22 @@ from dbt.version import __version__
|
||||
|
||||
from hologram import JsonSchemaMixin
|
||||
|
||||
from dbt.contracts.jsonschema import dbtClassMixin
|
||||
|
||||
PARTIAL_PARSE_FILE_NAME = 'partial_parse.pickle'
|
||||
PARSING_STATE = DbtProcessState('parsing')
|
||||
DEFAULT_PARTIAL_PARSE = False
|
||||
|
||||
|
||||
@dataclass
|
||||
class ParserInfo(JsonSchemaMixin):
|
||||
class ParserInfo(dbtClassMixin):
|
||||
parser: str
|
||||
elapsed: float
|
||||
path_count: int = 0
|
||||
|
||||
|
||||
@dataclass
|
||||
class ProjectLoaderInfo(JsonSchemaMixin):
|
||||
class ProjectLoaderInfo(dbtClassMixin):
|
||||
project_name: str
|
||||
elapsed: float
|
||||
parsers: List[ParserInfo]
|
||||
@@ -74,7 +76,7 @@ class ProjectLoaderInfo(JsonSchemaMixin):
|
||||
|
||||
|
||||
@dataclass
|
||||
class ManifestLoaderInfo(JsonSchemaMixin, Writable):
|
||||
class ManifestLoaderInfo(dbtClassMixin, Writable):
|
||||
path_count: int = 0
|
||||
is_partial_parse_enabled: Optional[bool] = None
|
||||
parse_project_elapsed: Optional[float] = None
|
||||
|
||||
@@ -11,7 +11,7 @@ class ModelParser(SimpleSQLParser[ParsedModelNode]):
|
||||
)
|
||||
|
||||
def parse_from_dict(self, dct, validate=True) -> ParsedModelNode:
|
||||
return ParsedModelNode.from_dict(dct, validate=validate)
|
||||
return ParsedModelNode.deserialize(dct, validate=validate)
|
||||
|
||||
@property
|
||||
def resource_type(self) -> NodeType:
|
||||
|
||||
@@ -3,6 +3,8 @@ from typing import TypeVar, MutableMapping, Mapping, Union, List
|
||||
|
||||
from hologram import JsonSchemaMixin
|
||||
|
||||
from dbt.contracts.jsonschema import dbtClassMixin
|
||||
|
||||
from dbt.contracts.files import RemoteFile, FileHash, SourceFile
|
||||
from dbt.contracts.graph.compiled import CompileResultNode
|
||||
from dbt.contracts.graph.parsed import (
|
||||
@@ -62,7 +64,7 @@ def dict_field():
|
||||
|
||||
|
||||
@dataclass
|
||||
class ParseResult(JsonSchemaMixin, Writable, Replaceable):
|
||||
class ParseResult(dbtClassMixin, Writable, Replaceable):
|
||||
vars_hash: FileHash
|
||||
profile_hash: FileHash
|
||||
project_hashes: MutableMapping[str, FileHash]
|
||||
|
||||
@@ -26,7 +26,7 @@ class RPCCallParser(SimpleSQLParser[ParsedRPCNode]):
|
||||
return []
|
||||
|
||||
def parse_from_dict(self, dct, validate=True) -> ParsedRPCNode:
|
||||
return ParsedRPCNode.from_dict(dct, validate=validate)
|
||||
return ParsedRPCNode.deserialize(dct, validate=validate)
|
||||
|
||||
@property
|
||||
def resource_type(self) -> NodeType:
|
||||
|
||||
@@ -7,6 +7,7 @@ from typing import (
|
||||
)
|
||||
|
||||
from hologram import ValidationError, JsonSchemaMixin
|
||||
from dbt.contracts.jsonschema import dbtClassMixin
|
||||
|
||||
from dbt.adapters.factory import get_adapter
|
||||
from dbt.clients.jinja import get_rendered, add_rendered_test_kwargs
|
||||
@@ -119,7 +120,8 @@ class ParserRef:
|
||||
meta=meta,
|
||||
tags=tags,
|
||||
quote=quote,
|
||||
_extra=column.extra
|
||||
# TODO
|
||||
#_extra=column.extra
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -201,7 +203,7 @@ class SchemaParser(SimpleParser[SchemaTestBlock, ParsedSchemaTestNode]):
|
||||
)
|
||||
|
||||
def parse_from_dict(self, dct, validate=True) -> ParsedSchemaTestNode:
|
||||
return ParsedSchemaTestNode.from_dict(dct, validate=validate)
|
||||
return ParsedSchemaTestNode.deserialize(dct, validate=validate)
|
||||
|
||||
def _parse_format_version(
|
||||
self, yaml: YamlBlock
|
||||
@@ -654,7 +656,7 @@ class YamlDocsReader(YamlReader):
|
||||
raise NotImplementedError('parse is abstract')
|
||||
|
||||
|
||||
T = TypeVar('T', bound=JsonSchemaMixin)
|
||||
T = TypeVar('T', bound=dbtClassMixin)
|
||||
|
||||
|
||||
class SourceParser(YamlDocsReader):
|
||||
|
||||
@@ -13,7 +13,7 @@ class SeedParser(SimpleSQLParser[ParsedSeedNode]):
|
||||
)
|
||||
|
||||
def parse_from_dict(self, dct, validate=True) -> ParsedSeedNode:
|
||||
return ParsedSeedNode.from_dict(dct, validate=validate)
|
||||
return ParsedSeedNode.deserialize(dct, validate=validate)
|
||||
|
||||
@property
|
||||
def resource_type(self) -> NodeType:
|
||||
|
||||
@@ -26,7 +26,7 @@ class SnapshotParser(
|
||||
)
|
||||
|
||||
def parse_from_dict(self, dct, validate=True) -> IntermediateSnapshotNode:
|
||||
return IntermediateSnapshotNode.from_dict(dct, validate=validate)
|
||||
return IntermediateSnapshotNode.deserialize(dct, validate=validate)
|
||||
|
||||
@property
|
||||
def resource_type(self) -> NodeType:
|
||||
@@ -66,6 +66,9 @@ class SnapshotParser(
|
||||
|
||||
def transform(self, node: IntermediateSnapshotNode) -> ParsedSnapshotNode:
|
||||
try:
|
||||
# TODO : This is not going to work at all
|
||||
# TODO : We need to accept the config and turn that into a concrete
|
||||
# Snapshot node type
|
||||
parsed_node = ParsedSnapshotNode.from_dict(node.to_dict())
|
||||
self.set_snapshot_attributes(parsed_node)
|
||||
return parsed_node
|
||||
|
||||
@@ -15,6 +15,9 @@ from dbt.contracts.rpc import (
|
||||
from dbt.exceptions import InternalException
|
||||
from dbt.utils import restrict_to
|
||||
|
||||
from dbt.contracts.jsonschema import dbtClassMixin
|
||||
from mashumaro.types import SerializableType
|
||||
|
||||
|
||||
class QueueMessageType(StrEnum):
|
||||
Error = 'error'
|
||||
@@ -26,16 +29,41 @@ class QueueMessageType(StrEnum):
|
||||
|
||||
|
||||
@dataclass
|
||||
class QueueMessage(JsonSchemaMixin):
|
||||
class QueueMessage(dbtClassMixin):
|
||||
message_type: QueueMessageType
|
||||
|
||||
|
||||
class SerializableLogRecord(logbook.LogRecord, SerializableType):
|
||||
def _serialize(self):
|
||||
# TODO
|
||||
import ipdb; ipdb.set_trace()
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def _deserialize(cls, value):
|
||||
# TODO
|
||||
import ipdb; ipdb.set_trace()
|
||||
pass
|
||||
|
||||
class SerializableJSONRPCError(JSONRPCError, SerializableType):
|
||||
def _serialize(self):
|
||||
# TODO
|
||||
import ipdb; ipdb.set_trace()
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def _deserialize(cls, value):
|
||||
# TODO
|
||||
import ipdb; ipdb.set_trace()
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class QueueLogMessage(QueueMessage):
|
||||
message_type: QueueMessageType = field(
|
||||
metadata=restrict_to(QueueMessageType.Log)
|
||||
)
|
||||
record: logbook.LogRecord
|
||||
record: SerializableLogRecord
|
||||
|
||||
@classmethod
|
||||
def from_record(cls, record: logbook.LogRecord):
|
||||
@@ -50,7 +78,7 @@ class QueueErrorMessage(QueueMessage):
|
||||
message_type: QueueMessageType = field(
|
||||
metadata=restrict_to(QueueMessageType.Error)
|
||||
)
|
||||
error: JSONRPCError
|
||||
error: SerializableJSONRPCError
|
||||
|
||||
@classmethod
|
||||
def from_error(cls, error: JSONRPCError):
|
||||
|
||||
@@ -8,6 +8,8 @@ from hologram import JsonSchemaMixin, ValidationError
|
||||
from dbt.contracts.rpc import RPCParameters, RemoteResult, RemoteMethodFlags
|
||||
from dbt.exceptions import NotImplementedException, InternalException
|
||||
|
||||
from dbt.contracts.jsonschema import dbtClassMixin
|
||||
|
||||
Parameters = TypeVar('Parameters', bound=RPCParameters)
|
||||
Result = TypeVar('Result', bound=RemoteResult)
|
||||
|
||||
@@ -109,7 +111,7 @@ class RemoteBuiltinMethod(RemoteMethod[Parameters, Result]):
|
||||
'the run() method on builtins should never be called'
|
||||
)
|
||||
|
||||
def __call__(self, **kwargs: Dict[str, Any]) -> JsonSchemaMixin:
|
||||
def __call__(self, **kwargs: Dict[str, Any]) -> dbtClassMixin:
|
||||
try:
|
||||
params = self.get_parameters().from_dict(kwargs)
|
||||
except ValidationError as exc:
|
||||
|
||||
@@ -20,6 +20,8 @@ from dbt.rpc.task_handler import RequestTaskHandler
|
||||
from dbt.rpc.method import RemoteMethod
|
||||
from dbt.rpc.task_manager import TaskManager
|
||||
|
||||
from dbt.contracts.jsonschema import dbtClassMixin
|
||||
|
||||
|
||||
def track_rpc_request(task):
|
||||
dbt.tracking.track_rpc_request({
|
||||
@@ -90,11 +92,11 @@ class ResponseManager(JSONRPCResponseManager):
|
||||
@classmethod
|
||||
def _get_responses(cls, requests, dispatcher):
|
||||
for output in super()._get_responses(requests, dispatcher):
|
||||
# if it's a result, check if it's a JsonSchemaMixin and if so call
|
||||
# if it's a result, check if it's a dbtClassMixin and if so call
|
||||
# to_dict
|
||||
if hasattr(output, 'result'):
|
||||
if isinstance(output.result, JsonSchemaMixin):
|
||||
output.result = output.result.to_dict(omit_none=False)
|
||||
if isinstance(output.result, dbtClassMixin):
|
||||
output.result = output.result.serialize(omit_none=False)
|
||||
yield output
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -43,6 +43,9 @@ from dbt.rpc.method import RemoteMethod
|
||||
# we use this in typing only...
|
||||
from queue import Queue # noqa
|
||||
|
||||
from dbt.contracts.jsonschema import dbtClassMixin
|
||||
|
||||
|
||||
|
||||
def sigterm_handler(signum, frame):
|
||||
raise dbt.exceptions.RPCKilledException(signum)
|
||||
@@ -283,7 +286,7 @@ class RequestTaskHandler(threading.Thread, TaskHandlerProtocol):
|
||||
# - The actual thread that this represents, which writes its data to
|
||||
# the result and logs. The atomicity of list.append() and item
|
||||
# assignment means we don't need a lock.
|
||||
self.result: Optional[JsonSchemaMixin] = None
|
||||
self.result: Optional[dbtClassMixin] = None
|
||||
self.error: Optional[RPCException] = None
|
||||
self.state: TaskHandlerState = TaskHandlerState.NotStarted
|
||||
self.logs: List[LogMessage] = []
|
||||
|
||||
@@ -8,6 +8,8 @@ from hologram import JsonSchemaMixin
|
||||
from hologram.helpers import StrEnum
|
||||
from typing import Optional
|
||||
|
||||
from dbt.contracts.jsonschema import dbtClassMixin
|
||||
|
||||
|
||||
class Matchers(StrEnum):
|
||||
GREATER_THAN = '>'
|
||||
@@ -18,12 +20,12 @@ class Matchers(StrEnum):
|
||||
|
||||
|
||||
@dataclass
|
||||
class VersionSpecification(JsonSchemaMixin):
|
||||
major: Optional[str]
|
||||
minor: Optional[str]
|
||||
patch: Optional[str]
|
||||
prerelease: Optional[str]
|
||||
build: Optional[str]
|
||||
class VersionSpecification(dbtClassMixin):
|
||||
major: Optional[str] = None
|
||||
minor: Optional[str] = None
|
||||
patch: Optional[str] = None
|
||||
prerelease: Optional[str] = None
|
||||
build: Optional[str] = None
|
||||
matcher: Matchers = Matchers.EXACT
|
||||
|
||||
|
||||
|
||||
@@ -110,7 +110,7 @@ class ListTask(GraphRunnableTask):
|
||||
for node in self._iterate_selected_nodes():
|
||||
yield json.dumps({
|
||||
k: v
|
||||
for k, v in node.to_dict(omit_none=False).items()
|
||||
for k, v in node.serialize(omit_none=False).items()
|
||||
if k in self.ALLOWED_KEYS
|
||||
})
|
||||
|
||||
|
||||
@@ -16,13 +16,14 @@ from typing import Optional
|
||||
class PostgresCredentials(Credentials):
|
||||
host: str
|
||||
user: str
|
||||
role: Optional[str]
|
||||
port: Port
|
||||
password: str # on postgres the password is mandatory
|
||||
role: Optional[str] = None
|
||||
search_path: Optional[str] = None
|
||||
keepalives_idle: int = 0 # 0 means to use the default value
|
||||
sslmode: Optional[str] = None
|
||||
|
||||
# TODO : Fix all instances of _ALIASES? Or include them in some base class?
|
||||
_ALIASES = {
|
||||
'dbname': 'database',
|
||||
'pass': 'password'
|
||||
|
||||
Reference in New Issue
Block a user