Compare commits

...

5 Commits

Author SHA1 Message Date
Drew Banin
b7cfd70ba3 cleanup 2020-12-20 13:08:08 -05:00
Drew Banin
1ba1fb964b working on timing debug project 2020-12-19 14:31:56 -05:00
Drew Banin
98f573ef29 Merge branch 'dev/kiyoshi-kuromiya' into experiment/use-mashumaro 2020-12-18 16:19:42 -05:00
Drew Banin
c744cca96f experiment with mashumaro 2020-11-12 10:16:04 -05:00
Drew Banin
76571819f3 use mashumaro to encode/decode Relation objects 2020-11-11 14:43:32 -05:00
41 changed files with 471 additions and 373 deletions

View File

@@ -5,10 +5,11 @@ from hologram import JsonSchemaMixin
from dbt.exceptions import RuntimeException from dbt.exceptions import RuntimeException
from typing import Dict, ClassVar, Any, Optional from typing import Dict, ClassVar, Any, Optional
from dbt.contracts.jsonschema import dbtClassMixin
@dataclass @dataclass
class Column(JsonSchemaMixin): class Column(dbtClassMixin):
TYPE_LABELS: ClassVar[Dict[str, str]] = { TYPE_LABELS: ClassVar[Dict[str, str]] = {
'STRING': 'TEXT', 'STRING': 'TEXT',
'TIMESTAMP': 'TIMESTAMP', 'TIMESTAMP': 'TIMESTAMP',

View File

@@ -47,13 +47,16 @@ class BaseRelation(FakeAPIObject, Hashable):
return False return False
return self.to_dict() == other.to_dict() return self.to_dict() == other.to_dict()
# TODO : Unclear why we're leveraging a type system to implement inheritance?
@classmethod @classmethod
def get_default_quote_policy(cls) -> Policy: def get_default_quote_policy(cls) -> Policy:
return cls._get_field_named('quote_policy').default #return cls._get_field_named('quote_policy').default
return Policy()
@classmethod @classmethod
def get_default_include_policy(cls) -> Policy: def get_default_include_policy(cls) -> Policy:
return cls._get_field_named('include_policy').default #return cls._get_field_named('include_policy').default
return Policy()
def get(self, key, default=None): def get(self, key, default=None):
"""Override `.get` to return a metadata object so we don't break """Override `.get` to return a metadata object so we don't break

View File

@@ -137,7 +137,7 @@ class Profile(HasCredentials):
def validate(self): def validate(self):
try: try:
if self.credentials: if self.credentials:
self.credentials.to_dict(validate=True) self.credentials.serialize(validate=True)
ProfileConfig.from_dict( ProfileConfig.from_dict(
self.to_profile_info(serialize_credentials=True) self.to_profile_info(serialize_credentials=True)
) )

View File

@@ -306,7 +306,7 @@ class PartialProject(RenderComponents):
) )
try: try:
cfg = ProjectContract.from_dict(rendered.project_dict) cfg = ProjectContract.deserialize(rendered.project_dict)
except ValidationError as e: except ValidationError as e:
raise DbtProjectError(validator_error_message(e)) from e raise DbtProjectError(validator_error_message(e)) from e
# name/version are required in the Project definition, so we can assume # name/version are required in the Project definition, so we can assume
@@ -586,7 +586,9 @@ class Project:
def validate(self): def validate(self):
try: try:
ProjectContract.from_dict(self.to_project_config()) # TODO : Jank; need to do this to handle aliasing between hyphens and underscores
as_dict = self.to_project_config()
ProjectContract.deserialize(as_dict)
except ValidationError as e: except ValidationError as e:
raise DbtProjectError(validator_error_message(e)) from e raise DbtProjectError(validator_error_message(e)) from e

View File

@@ -174,7 +174,7 @@ class RuntimeConfig(Project, Profile, AdapterRequiredConfig):
:raises DbtProjectError: If the configuration fails validation. :raises DbtProjectError: If the configuration fails validation.
""" """
try: try:
Configuration.from_dict(self.serialize()) Configuration.deserialize(self.serialize())
except ValidationError as e: except ValidationError as e:
raise DbtProjectError(validator_error_message(e)) from e raise DbtProjectError(validator_error_message(e)) from e

View File

@@ -165,7 +165,9 @@ class ContextConfigGenerator(BaseContextConfigGenerator[C]):
# Calculate the defaults. We don't want to validate the defaults, # Calculate the defaults. We don't want to validate the defaults,
# because it might be invalid in the case of required config members # because it might be invalid in the case of required config members
# (such as on snapshots!) # (such as on snapshots!)
result = config_cls.from_dict({}, validate=False) result = config_cls.from_dict({})
# TODO - why validate here?
# result.validate()
return result return result
def _update_from_config( def _update_from_config(

View File

@@ -17,9 +17,12 @@ from dbt.utils import translate_aliases
from dbt.logger import GLOBAL_LOGGER as logger from dbt.logger import GLOBAL_LOGGER as logger
from dbt.contracts.jsonschema import dbtClassMixin, ValidatedStringMixin
from mashumaro.types import SerializableType
Identifier = NewType('Identifier', str)
register_pattern(Identifier, r'^[A-Za-z_][A-Za-z0-9_]+$') class Identifier(ValidatedStringMixin):
ValidationRegex = r'^[A-Za-z_][A-Za-z0-9_]+$'
class ConnectionState(StrEnum): class ConnectionState(StrEnum):
@@ -28,22 +31,30 @@ class ConnectionState(StrEnum):
CLOSED = 'closed' CLOSED = 'closed'
FAIL = 'fail' FAIL = 'fail'
# I think that... this is not the right way to do this
# TODO!!
class DoNotSerializeType(SerializableType):
def _serialize(self) -> None:
return None
# TODO : Figure out ExtensibleJsonSchemaMixin?
@dataclass(init=False) @dataclass(init=False)
class Connection(ExtensibleJsonSchemaMixin, Replaceable): class Connection(dbtClassMixin, Replaceable):
type: Identifier type: Identifier
name: Optional[str] name: Optional[str]
state: ConnectionState = ConnectionState.INIT state: ConnectionState = ConnectionState.INIT
transaction_open: bool = False transaction_open: bool = False
# prevent serialization # prevent serialization
_handle: Optional[Any] = None #_handle: Optional[Any] = None
_credentials: JsonSchemaMixin = field(init=False) #_credentials: dbtClassMixin = field(init=False)
_handle: Optional[DoNotSerializeType] = None
_credentials: Optional[DoNotSerializeType] = None
def __init__( def __init__(
self, self,
type: Identifier, type: Identifier,
name: Optional[str], name: Optional[str],
credentials: JsonSchemaMixin, credentials: dbtClassMixin,
state: ConnectionState = ConnectionState.INIT, state: ConnectionState = ConnectionState.INIT,
transaction_open: bool = False, transaction_open: bool = False,
handle: Optional[Any] = None, handle: Optional[Any] = None,
@@ -102,7 +113,8 @@ class LazyHandle:
# will work. # will work.
@dataclass # type: ignore @dataclass # type: ignore
class Credentials( class Credentials(
ExtensibleJsonSchemaMixin, # ExtensibleJsonSchemaMixin,
dbtClassMixin,
Replaceable, Replaceable,
metaclass=abc.ABCMeta metaclass=abc.ABCMeta
): ):
@@ -121,7 +133,8 @@ class Credentials(
) -> Iterable[Tuple[str, Any]]: ) -> Iterable[Tuple[str, Any]]:
"""Return an ordered iterator of key/value pairs for pretty-printing. """Return an ordered iterator of key/value pairs for pretty-printing.
""" """
as_dict = self.to_dict(omit_none=False, with_aliases=with_aliases) # TODO: Does this... work?
as_dict = self.serialize(omit_none=False, with_aliases=with_aliases)
connection_keys = set(self._connection_keys()) connection_keys = set(self._connection_keys())
aliases: List[str] = [] aliases: List[str] = []
if with_aliases: if with_aliases:
@@ -136,10 +149,11 @@ class Credentials(
def _connection_keys(self) -> Tuple[str, ...]: def _connection_keys(self) -> Tuple[str, ...]:
raise NotImplementedError raise NotImplementedError
@classmethod # TODO TODO TODO
def from_dict(cls, data): # @classmethod
data = cls.translate_aliases(data) # def from_dict(cls, data):
return super().from_dict(data) # data = cls.translate_aliases(data)
# return super().from_dict(data)
@classmethod @classmethod
def translate_aliases( def translate_aliases(
@@ -147,15 +161,16 @@ class Credentials(
) -> Dict[str, Any]: ) -> Dict[str, Any]:
return translate_aliases(kwargs, cls._ALIASES, recurse) return translate_aliases(kwargs, cls._ALIASES, recurse)
def to_dict(self, omit_none=True, validate=False, *, with_aliases=False): # TODO TODO TODO
serialized = super().to_dict(omit_none=omit_none, validate=validate) # def to_dict(self, omit_none=True, validate=False, *, with_aliases=False):
if with_aliases: # serialized = super().to_dict(omit_none=omit_none, validate=validate)
serialized.update({ # if with_aliases:
new_name: serialized[canonical_name] # serialized.update({
for new_name, canonical_name in self._ALIASES.items() # new_name: serialized[canonical_name]
if canonical_name in serialized # for new_name, canonical_name in self._ALIASES.items()
}) # if canonical_name in serialized
return serialized # })
# return serialized
class UserConfigContract(Protocol): class UserConfigContract(Protocol):
@@ -205,7 +220,7 @@ DEFAULT_QUERY_COMMENT = '''
@dataclass @dataclass
class QueryComment(JsonSchemaMixin): class QueryComment(dbtClassMixin):
comment: str = DEFAULT_QUERY_COMMENT comment: str = DEFAULT_QUERY_COMMENT
append: bool = False append: bool = False

View File

@@ -9,13 +9,15 @@ from dbt.exceptions import InternalException
from .util import MacroKey, SourceKey from .util import MacroKey, SourceKey
from dbt.contracts.jsonschema import dbtClassMixin
MAXIMUM_SEED_SIZE = 1 * 1024 * 1024 MAXIMUM_SEED_SIZE = 1 * 1024 * 1024
MAXIMUM_SEED_SIZE_NAME = '1MB' MAXIMUM_SEED_SIZE_NAME = '1MB'
@dataclass @dataclass
class FilePath(JsonSchemaMixin): class FilePath(dbtClassMixin):
searched_path: str searched_path: str
relative_path: str relative_path: str
project_root: str project_root: str
@@ -51,7 +53,7 @@ class FilePath(JsonSchemaMixin):
@dataclass @dataclass
class FileHash(JsonSchemaMixin): class FileHash(dbtClassMixin):
name: str # the hash type name name: str # the hash type name
checksum: str # the hashlib.hash_type().hexdigest() of the file contents checksum: str # the hashlib.hash_type().hexdigest() of the file contents
@@ -91,7 +93,7 @@ class FileHash(JsonSchemaMixin):
@dataclass @dataclass
class RemoteFile(JsonSchemaMixin): class RemoteFile(dbtClassMixin):
@property @property
def searched_path(self) -> str: def searched_path(self) -> str:
return 'from remote system' return 'from remote system'
@@ -110,7 +112,7 @@ class RemoteFile(JsonSchemaMixin):
@dataclass @dataclass
class SourceFile(JsonSchemaMixin): class SourceFile(dbtClassMixin):
"""Define a source file in dbt""" """Define a source file in dbt"""
path: Union[FilePath, RemoteFile] # the path information path: Union[FilePath, RemoteFile] # the path information
checksum: FileHash checksum: FileHash

View File

@@ -23,15 +23,17 @@ from hologram import JsonSchemaMixin
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Optional, List, Union, Dict, Type from typing import Optional, List, Union, Dict, Type
from dbt.contracts.jsonschema import dbtClassMixin
@dataclass @dataclass
class InjectedCTE(JsonSchemaMixin, Replaceable): class InjectedCTE(dbtClassMixin, Replaceable):
id: str id: str
sql: str sql: str
@dataclass @dataclass
class CompiledNodeMixin(JsonSchemaMixin): class CompiledNodeMixin(dbtClassMixin):
# this is a special mixin class to provide a required argument. If a node # this is a special mixin class to provide a required argument. If a node
# is missing a `compiled` flag entirely, it must not be a CompiledNode. # is missing a `compiled` flag entirely, it must not be a CompiledNode.
compiled: bool compiled: bool
@@ -179,7 +181,8 @@ def parsed_instance_for(compiled: CompiledNode) -> ParsedResource:
.format(compiled.resource_type)) .format(compiled.resource_type))
# validate=False to allow extra keys from compiling # validate=False to allow extra keys from compiling
return cls.from_dict(compiled.to_dict(), validate=False) # TODO : This is probably not going to do the right thing....
return cls.deserialize(compiled.to_dict(), validate=False)
NonSourceCompiledNode = Union[ NonSourceCompiledNode = Union[

View File

@@ -508,10 +508,10 @@ class Manifest:
""" """
self.flat_graph = { self.flat_graph = {
'nodes': { 'nodes': {
k: v.to_dict(omit_none=False) for k, v in self.nodes.items() k: v.serialize(omit_none=False) for k, v in self.nodes.items()
}, },
'sources': { 'sources': {
k: v.to_dict(omit_none=False) for k, v in self.sources.items() k: v.serialize(omit_none=False) for k, v in self.sources.items()
} }
} }
@@ -764,7 +764,7 @@ class Manifest:
) )
def to_dict(self, omit_none=True, validate=False): def to_dict(self, omit_none=True, validate=False):
return self.writable_manifest().to_dict( return self.writable_manifest().serialize(
omit_none=omit_none, validate=validate omit_none=omit_none, validate=validate
) )

View File

@@ -21,6 +21,9 @@ from dbt.contracts.util import Replaceable, list_str
from dbt import hooks from dbt import hooks
from dbt.node_types import NodeType from dbt.node_types import NodeType
from dbt.contracts.jsonschema import dbtClassMixin
from mashumaro.types import SerializableType
M = TypeVar('M', bound='Metadata') M = TypeVar('M', bound='Metadata')
@@ -170,9 +173,19 @@ def insensitive_patterns(*patterns: str):
return '^({})$'.format('|'.join(lowercased)) return '^({})$'.format('|'.join(lowercased))
# TODO?
Severity = NewType('Severity', str) Severity = NewType('Severity', str)
register_pattern(Severity, insensitive_patterns('warn', 'error')) register_pattern(Severity, insensitive_patterns('warn', 'error'))
class Severity(str, SerializableType):
@classmethod
def _deserialize(cls, value: str) -> 'Severity':
# TODO : Validate here?
return Severity(value)
def _serialize(self) -> str:
# TODO : Validate here?
return self
class SnapshotStrategy(StrEnum): class SnapshotStrategy(StrEnum):
@@ -185,7 +198,7 @@ class All(StrEnum):
@dataclass @dataclass
class Hook(JsonSchemaMixin, Replaceable): class Hook(dbtClassMixin, Replaceable):
sql: str sql: str
transaction: bool = True transaction: bool = True
index: Optional[int] = None index: Optional[int] = None
@@ -196,7 +209,7 @@ T = TypeVar('T', bound='BaseConfig')
@dataclass @dataclass
class BaseConfig( class BaseConfig(
AdditionalPropertiesAllowed, Replaceable, MutableMapping[str, Any] dbtClassMixin, Replaceable, MutableMapping[str, Any]
): ):
# Implement MutableMapping so this config will behave as some macros expect # Implement MutableMapping so this config will behave as some macros expect
# during parsing (notably, syntax like `{{ node.config['schema'] }}`) # during parsing (notably, syntax like `{{ node.config['schema'] }}`)
@@ -294,23 +307,25 @@ class BaseConfig(
""" """
result = {} result = {}
for fld, target_field in cls._get_fields(): # TODO : This is not correct.... must implement without reflection
if target_field not in data:
continue
data_attr = data.pop(target_field) # for fld, target_field in cls._get_fields():
if target_field not in src: # if target_field not in data:
result[target_field] = data_attr # continue
continue
merge_behavior = MergeBehavior.from_field(fld) # data_attr = data.pop(target_field)
self_attr = src[target_field] # if target_field not in src:
# result[target_field] = data_attr
# continue
result[target_field] = _merge_field_value( # merge_behavior = MergeBehavior.from_field(fld)
merge_behavior=merge_behavior, # self_attr = src[target_field]
self_value=self_attr,
other_value=data_attr, # result[target_field] = _merge_field_value(
) # merge_behavior=merge_behavior,
# self_value=self_attr,
# other_value=data_attr,
# )
return result return result
def to_dict( def to_dict(
@@ -320,7 +335,7 @@ class BaseConfig(
*, *,
omit_hidden: bool = True, omit_hidden: bool = True,
) -> Dict[str, Any]: ) -> Dict[str, Any]:
result = super().to_dict(omit_none=omit_none, validate=validate) result = super().serialize(omit_none=omit_none, validate=validate)
if omit_hidden and not omit_none: if omit_hidden and not omit_none:
for fld, target_field in self._get_fields(): for fld, target_field in self._get_fields():
if target_field not in result: if target_field not in result:
@@ -344,7 +359,9 @@ class BaseConfig(
""" """
# sadly, this is a circular import # sadly, this is a circular import
from dbt.adapters.factory import get_config_class_by_name from dbt.adapters.factory import get_config_class_by_name
dct = self.to_dict(omit_none=False, validate=False, omit_hidden=False) # TODO : omit_hidden?
# dct = self.serialize(omit_none=False, validate=False, omit_hidden=False)
dct = self.serialize(omit_none=False, validate=False)
adapter_config_cls = get_config_class_by_name(adapter_type) adapter_config_cls = get_config_class_by_name(adapter_type)
@@ -358,11 +375,11 @@ class BaseConfig(
dct.update(data) dct.update(data)
# any validation failures must have come from the update # any validation failures must have come from the update
return self.from_dict(dct, validate=validate) return self.deserialize(dct, validate=validate)
def finalize_and_validate(self: T) -> T: def finalize_and_validate(self: T) -> T:
# from_dict will validate for us # from_dict will validate for us
dct = self.to_dict(omit_none=False, validate=False) dct = self.serialize(omit_none=False, validate=False)
return self.from_dict(dct) return self.from_dict(dct)
def replace(self, **kwargs): def replace(self, **kwargs):
@@ -372,7 +389,7 @@ class BaseConfig(
for key, value in kwargs.items(): for key, value in kwargs.items():
new_key = mapping.get(key, key) new_key = mapping.get(key, key)
dct[new_key] = value dct[new_key] = value
return self.from_dict(dct, validate=False) return self.deserialize(dct, validate=False)
@dataclass @dataclass
@@ -435,12 +452,15 @@ class NodeConfig(BaseConfig):
for key in hooks.ModelHookType: for key in hooks.ModelHookType:
if key in data: if key in data:
data[key] = [hooks.get_hook_dict(h) for h in data[key]] data[key] = [hooks.get_hook_dict(h) for h in data[key]]
return super().from_dict(data, validate=validate) return super().deserialize(data, validate=validate)
@classmethod @classmethod
def field_mapping(cls): def field_mapping(cls):
return {'post_hook': 'post-hook', 'pre_hook': 'pre-hook'} return {'post_hook': 'post-hook', 'pre_hook': 'pre-hook'}
def validate(self):
# TODO : Not implemented!
pass
@dataclass @dataclass
class SeedConfig(NodeConfig): class SeedConfig(NodeConfig):
@@ -454,63 +474,10 @@ class TestConfig(NodeConfig):
severity: Severity = Severity('ERROR') severity: Severity = Severity('ERROR')
SnapshotVariants = Union[
'TimestampSnapshotConfig',
'CheckSnapshotConfig',
'GenericSnapshotConfig',
]
def _relevance_without_strategy(error: jsonschema.ValidationError):
# calculate the 'relevance' of an error the normal jsonschema way, except
# if the validator is in the 'strategy' field and its conflicting with the
# 'enum'. This suppresses `"'timestamp' is not one of ['check']` and such
if 'strategy' in error.path and error.validator in {'enum', 'not'}:
length = 1
else:
length = -len(error.path)
validator = error.validator
return length, validator not in {'anyOf', 'oneOf'}
@dataclass
class SnapshotWrapper(JsonSchemaMixin):
"""This is a little wrapper to let us serialize/deserialize the
SnapshotVariants union.
"""
config: SnapshotVariants # mypy: ignore
@classmethod
def validate(cls, data: Any):
config = data.get('config', {})
if config.get('strategy') == 'check':
schema = _validate_schema(CheckSnapshotConfig)
to_validate = config
elif config.get('strategy') == 'timestamp':
schema = _validate_schema(TimestampSnapshotConfig)
to_validate = config
else:
h_cls = cast(Hashable, cls)
schema = _validate_schema(h_cls)
to_validate = data
validator = jsonschema.Draft7Validator(schema)
error = jsonschema.exceptions.best_match(
validator.iter_errors(to_validate),
key=_relevance_without_strategy,
)
if error is not None:
raise ValidationError.create_from(error) from error
@dataclass @dataclass
class EmptySnapshotConfig(NodeConfig): class EmptySnapshotConfig(NodeConfig):
materialized: str = 'snapshot' materialized: str = 'snapshot'
strategy: str = None
@dataclass(init=False) @dataclass(init=False)
@@ -519,117 +486,17 @@ class SnapshotConfig(EmptySnapshotConfig):
target_schema: str = field(init=False, metadata=dict(init_required=True)) target_schema: str = field(init=False, metadata=dict(init_required=True))
target_database: Optional[str] = None target_database: Optional[str] = None
def __init__(
self,
unique_key: str,
target_schema: str,
target_database: Optional[str] = None,
**kwargs
) -> None:
self.unique_key = unique_key
self.target_schema = target_schema
self.target_database = target_database
# kwargs['materialized'] = materialized
super().__init__(**kwargs)
# type hacks...
@classmethod
def _get_fields(cls) -> List[Tuple[Field, str]]: # type: ignore
fields: List[Tuple[Field, str]] = []
for old_field, name in super()._get_fields():
new_field = old_field
# tell hologram we're really an initvar
if old_field.metadata and old_field.metadata.get('init_required'):
new_field = field(init=True, metadata=old_field.metadata)
new_field.name = old_field.name
new_field.type = old_field.type
new_field._field_type = old_field._field_type # type: ignore
fields.append((new_field, name))
return fields
def finalize_and_validate(self: 'SnapshotConfig') -> SnapshotVariants:
data = self.to_dict()
return SnapshotWrapper.from_dict({'config': data}).config
@dataclass(init=False)
class GenericSnapshotConfig(SnapshotConfig):
strategy: str = field(init=False, metadata=dict(init_required=True))
def __init__(self, strategy: str, **kwargs) -> None:
self.strategy = strategy
super().__init__(**kwargs)
@classmethod
def _collect_json_schema(
cls, definitions: Dict[str, Any]
) -> Dict[str, Any]:
# this is the method you want to override in hologram if you want
# to do clever things about the json schema and have classes that
# contain instances of your JsonSchemaMixin respect the change.
schema = super()._collect_json_schema(definitions)
# Instead of just the strategy we'd calculate normally, say
# "this strategy except none of our specialization strategies".
strategies = [schema['properties']['strategy']]
for specialization in (TimestampSnapshotConfig, CheckSnapshotConfig):
strategies.append(
{'not': specialization.json_schema()['properties']['strategy']}
)
schema['properties']['strategy'] = {
'allOf': strategies
}
return schema
@dataclass(init=False)
class TimestampSnapshotConfig(SnapshotConfig): class TimestampSnapshotConfig(SnapshotConfig):
strategy: str = field( updated_at: str
init=False,
metadata=dict(
restrict=[str(SnapshotStrategy.Timestamp)],
init_required=True,
),
)
updated_at: str = field(init=False, metadata=dict(init_required=True))
def __init__(
self, strategy: str, updated_at: str, **kwargs
) -> None:
self.strategy = strategy
self.updated_at = updated_at
super().__init__(**kwargs)
@dataclass(init=False)
class CheckSnapshotConfig(SnapshotConfig): class CheckSnapshotConfig(SnapshotConfig):
strategy: str = field( check_cols: Union[All, List[str]]
init=False,
metadata=dict(
restrict=[str(SnapshotStrategy.Check)],
init_required=True,
),
)
# TODO: is there a way to get this to accept tuples of strings? Adding
# `Tuple[str, ...]` to the list of types results in this:
# ['email'] is valid under each of {'type': 'array', 'items':
# {'type': 'string'}}, {'type': 'array', 'items': {'type': 'string'}}
# but without it, parsing gets upset about values like `('email',)`
# maybe hologram itself should support this behavior? It's not like tuples
# are meaningful in json
check_cols: Union[All, List[str]] = field(
init=False,
metadata=dict(init_required=True),
)
def __init__(
self, strategy: str, check_cols: Union[All, List[str]], class CheckSnapshotConfig(SnapshotConfig):
**kwargs pass
) -> None:
self.strategy = strategy
self.check_cols = check_cols
super().__init__(**kwargs)
RESOURCE_TYPES: Dict[NodeType, Type[BaseConfig]] = { RESOURCE_TYPES: Dict[NodeType, Type[BaseConfig]] = {

View File

@@ -31,6 +31,8 @@ from dbt.logger import GLOBAL_LOGGER as logger # noqa
from dbt import flags from dbt import flags
from dbt.node_types import NodeType from dbt.node_types import NodeType
from dbt.contracts.jsonschema import dbtClassMixin
from .model_config import ( from .model_config import (
NodeConfig, NodeConfig,
@@ -38,20 +40,15 @@ from .model_config import (
TestConfig, TestConfig,
SourceConfig, SourceConfig,
EmptySnapshotConfig, EmptySnapshotConfig,
SnapshotVariants, SnapshotConfig,
)
# import these 3 so the SnapshotVariants forward ref works.
from .model_config import ( # noqa
TimestampSnapshotConfig,
CheckSnapshotConfig,
GenericSnapshotConfig,
) )
# TODO : Figure out AdditionalPropertiesMixin and ExtensibleJsonSchemaMixin
@dataclass @dataclass
class ColumnInfo( class ColumnInfo(
AdditionalPropertiesMixin, dbtClassMixin,
ExtensibleJsonSchemaMixin, #AdditionalPropertiesMixin,
#ExtensibleJsonSchemaMixin,
Replaceable Replaceable
): ):
name: str name: str
@@ -64,7 +61,7 @@ class ColumnInfo(
@dataclass @dataclass
class HasFqn(JsonSchemaMixin, Replaceable): class HasFqn(dbtClassMixin, Replaceable):
fqn: List[str] fqn: List[str]
def same_fqn(self, other: 'HasFqn') -> bool: def same_fqn(self, other: 'HasFqn') -> bool:
@@ -72,12 +69,12 @@ class HasFqn(JsonSchemaMixin, Replaceable):
@dataclass @dataclass
class HasUniqueID(JsonSchemaMixin, Replaceable): class HasUniqueID(dbtClassMixin, Replaceable):
unique_id: str unique_id: str
@dataclass @dataclass
class MacroDependsOn(JsonSchemaMixin, Replaceable): class MacroDependsOn(dbtClassMixin, Replaceable):
macros: List[str] = field(default_factory=list) macros: List[str] = field(default_factory=list)
# 'in' on lists is O(n) so this is O(n^2) for # of macros # 'in' on lists is O(n) so this is O(n^2) for # of macros
@@ -96,12 +93,12 @@ class DependsOn(MacroDependsOn):
@dataclass @dataclass
class HasRelationMetadata(JsonSchemaMixin, Replaceable): class HasRelationMetadata(dbtClassMixin, Replaceable):
database: Optional[str] database: Optional[str]
schema: str schema: str
class ParsedNodeMixins(JsonSchemaMixin): class ParsedNodeMixins(dbtClassMixin):
resource_type: NodeType resource_type: NodeType
depends_on: DependsOn depends_on: DependsOn
config: NodeConfig config: NodeConfig
@@ -132,8 +129,8 @@ class ParsedNodeMixins(JsonSchemaMixin):
self.meta = patch.meta self.meta = patch.meta
self.docs = patch.docs self.docs = patch.docs
if flags.STRICT_MODE: if flags.STRICT_MODE:
assert isinstance(self, JsonSchemaMixin) assert isinstance(self, dbtClassMixin)
self.to_dict(validate=True, omit_none=False) self.serialize(validate=True, omit_none=False)
def get_materialization(self): def get_materialization(self):
return self.config.materialized return self.config.materialized
@@ -335,14 +332,14 @@ class ParsedSeedNode(ParsedNode):
@dataclass @dataclass
class TestMetadata(JsonSchemaMixin, Replaceable): class TestMetadata(dbtClassMixin, Replaceable):
namespace: Optional[str] namespace: Optional[str]
name: str name: str
kwargs: Dict[str, Any] kwargs: Dict[str, Any]
@dataclass @dataclass
class HasTestMetadata(JsonSchemaMixin): class HasTestMetadata(dbtClassMixin):
test_metadata: TestMetadata test_metadata: TestMetadata
@@ -394,7 +391,7 @@ class IntermediateSnapshotNode(ParsedNode):
@dataclass @dataclass
class ParsedSnapshotNode(ParsedNode): class ParsedSnapshotNode(ParsedNode):
resource_type: NodeType = field(metadata={'restrict': [NodeType.Snapshot]}) resource_type: NodeType = field(metadata={'restrict': [NodeType.Snapshot]})
config: SnapshotVariants config: SnapshotConfig
@dataclass @dataclass
@@ -443,8 +440,8 @@ class ParsedMacro(UnparsedBaseNode, HasUniqueID):
self.docs = patch.docs self.docs = patch.docs
self.arguments = patch.arguments self.arguments = patch.arguments
if flags.STRICT_MODE: if flags.STRICT_MODE:
assert isinstance(self, JsonSchemaMixin) assert isinstance(self, dbtClassMixin)
self.to_dict(validate=True, omit_none=False) self.serialize(validate=True, omit_none=False)
def same_contents(self, other: Optional['ParsedMacro']) -> bool: def same_contents(self, other: Optional['ParsedMacro']) -> bool:
if other is None: if other is None:

View File

@@ -16,9 +16,11 @@ from datetime import timedelta
from pathlib import Path from pathlib import Path
from typing import Optional, List, Union, Dict, Any, Sequence from typing import Optional, List, Union, Dict, Any, Sequence
from dbt.contracts.jsonschema import dbtClassMixin
@dataclass @dataclass
class UnparsedBaseNode(JsonSchemaMixin, Replaceable): class UnparsedBaseNode(dbtClassMixin, Replaceable):
package_name: str package_name: str
root_path: str root_path: str
path: str path: str
@@ -66,18 +68,20 @@ class UnparsedRunHook(UnparsedNode):
@dataclass @dataclass
class Docs(JsonSchemaMixin, Replaceable): class Docs(dbtClassMixin, Replaceable):
show: bool = True show: bool = True
# TODO : This should have AdditionalPropertiesMixin and ExtensibleJsonSchemaMixin
@dataclass @dataclass
class HasDocs(AdditionalPropertiesMixin, ExtensibleJsonSchemaMixin, class HasDocs(dbtClassMixin, Replaceable):
Replaceable):
name: str name: str
description: str = '' description: str = ''
meta: Dict[str, Any] = field(default_factory=dict) meta: Dict[str, Any] = field(default_factory=dict)
data_type: Optional[str] = None data_type: Optional[str] = None
docs: Docs = field(default_factory=Docs) docs: Docs = field(default_factory=Docs)
# TODO : How do we handle these additional fields with mashurmaro?
_extra: Dict[str, Any] = field(default_factory=dict) _extra: Dict[str, Any] = field(default_factory=dict)
@@ -100,7 +104,7 @@ class UnparsedColumn(HasTests):
@dataclass @dataclass
class HasColumnDocs(JsonSchemaMixin, Replaceable): class HasColumnDocs(dbtClassMixin, Replaceable):
columns: Sequence[HasDocs] = field(default_factory=list) columns: Sequence[HasDocs] = field(default_factory=list)
@@ -110,7 +114,7 @@ class HasColumnTests(HasColumnDocs):
@dataclass @dataclass
class HasYamlMetadata(JsonSchemaMixin): class HasYamlMetadata(dbtClassMixin):
original_file_path: str original_file_path: str
yaml_key: str yaml_key: str
package_name: str package_name: str
@@ -127,7 +131,7 @@ class UnparsedNodeUpdate(HasColumnTests, HasTests, HasYamlMetadata):
@dataclass @dataclass
class MacroArgument(JsonSchemaMixin): class MacroArgument(dbtClassMixin):
name: str name: str
type: Optional[str] = None type: Optional[str] = None
description: str = '' description: str = ''
@@ -148,7 +152,7 @@ class TimePeriod(StrEnum):
@dataclass @dataclass
class Time(JsonSchemaMixin, Replaceable): class Time(dbtClassMixin, Replaceable):
count: int count: int
period: TimePeriod period: TimePeriod
@@ -159,7 +163,7 @@ class Time(JsonSchemaMixin, Replaceable):
@dataclass @dataclass
class FreshnessThreshold(JsonSchemaMixin, Mergeable): class FreshnessThreshold(dbtClassMixin, Mergeable):
warn_after: Optional[Time] = None warn_after: Optional[Time] = None
error_after: Optional[Time] = None error_after: Optional[Time] = None
filter: Optional[str] = None filter: Optional[str] = None
@@ -212,7 +216,7 @@ class ExternalTable(AdditionalPropertiesAllowed, Mergeable):
@dataclass @dataclass
class Quoting(JsonSchemaMixin, Mergeable): class Quoting(dbtClassMixin, Mergeable):
database: Optional[bool] = None database: Optional[bool] = None
schema: Optional[bool] = None schema: Optional[bool] = None
identifier: Optional[bool] = None identifier: Optional[bool] = None
@@ -231,14 +235,15 @@ class UnparsedSourceTableDefinition(HasColumnTests, HasTests):
tags: List[str] = field(default_factory=list) tags: List[str] = field(default_factory=list)
def to_dict(self, omit_none=True, validate=False): def to_dict(self, omit_none=True, validate=False):
result = super().to_dict(omit_none=omit_none, validate=validate) result = super().serialize(omit_none=omit_none, validate=validate)
if omit_none and self.freshness is None: if omit_none and self.freshness is None:
result['freshness'] = None result['freshness'] = None
return result return result
@dataclass @dataclass
class UnparsedSourceDefinition(JsonSchemaMixin, Replaceable): class UnparsedSourceDefinition(dbtClassMixin, Replaceable):
name: str name: str
description: str = '' description: str = ''
meta: Dict[str, Any] = field(default_factory=dict) meta: Dict[str, Any] = field(default_factory=dict)
@@ -258,14 +263,14 @@ class UnparsedSourceDefinition(JsonSchemaMixin, Replaceable):
return 'sources' return 'sources'
def to_dict(self, omit_none=True, validate=False): def to_dict(self, omit_none=True, validate=False):
result = super().to_dict(omit_none=omit_none, validate=validate) result = super().serialize(omit_none=omit_none, validate=validate)
if omit_none and self.freshness is None: if omit_none and self.freshness is None:
result['freshness'] = None result['freshness'] = None
return result return result
@dataclass @dataclass
class SourceTablePatch(JsonSchemaMixin): class SourceTablePatch(dbtClassMixin):
name: str name: str
description: Optional[str] = None description: Optional[str] = None
meta: Optional[Dict[str, Any]] = None meta: Optional[Dict[str, Any]] = None
@@ -283,7 +288,7 @@ class SourceTablePatch(JsonSchemaMixin):
columns: Optional[Sequence[UnparsedColumn]] = None columns: Optional[Sequence[UnparsedColumn]] = None
def to_patch_dict(self) -> Dict[str, Any]: def to_patch_dict(self) -> Dict[str, Any]:
dct = self.to_dict(omit_none=True) dct = self.serialize(omit_none=True)
remove_keys = ('name') remove_keys = ('name')
for key in remove_keys: for key in remove_keys:
if key in dct: if key in dct:
@@ -296,7 +301,7 @@ class SourceTablePatch(JsonSchemaMixin):
@dataclass @dataclass
class SourcePatch(JsonSchemaMixin, Replaceable): class SourcePatch(dbtClassMixin, Replaceable):
name: str = field( name: str = field(
metadata=dict(description='The name of the source to override'), metadata=dict(description='The name of the source to override'),
) )
@@ -320,7 +325,7 @@ class SourcePatch(JsonSchemaMixin, Replaceable):
tags: Optional[List[str]] = None tags: Optional[List[str]] = None
def to_patch_dict(self) -> Dict[str, Any]: def to_patch_dict(self) -> Dict[str, Any]:
dct = self.to_dict(omit_none=True) dct = self.serialize(omit_none=True)
remove_keys = ('name', 'overrides', 'tables', 'path') remove_keys = ('name', 'overrides', 'tables', 'path')
for key in remove_keys: for key in remove_keys:
if key in dct: if key in dct:
@@ -340,7 +345,7 @@ class SourcePatch(JsonSchemaMixin, Replaceable):
@dataclass @dataclass
class UnparsedDocumentation(JsonSchemaMixin, Replaceable): class UnparsedDocumentation(dbtClassMixin, Replaceable):
package_name: str package_name: str
root_path: str root_path: str
path: str path: str
@@ -400,7 +405,7 @@ class MaturityType(StrEnum):
@dataclass @dataclass
class ExposureOwner(JsonSchemaMixin, Replaceable): class ExposureOwner(dbtClassMixin, Replaceable):
email: str email: str
name: Optional[str] = None name: Optional[str] = None

View File

@@ -0,0 +1,66 @@
from dataclasses import dataclass, fields, Field
from typing import (
Optional, TypeVar, Generic, Dict, get_type_hints, List, Tuple
)
from mashumaro import DataClassDictMixin
from mashumaro.types import SerializableType
import re
"""
This is a throwaway shim to match the JsonSchemaMixin interface
that downstream consumers (ie. PostgresRelation) is expecting.
I imagine that we would try to remove code that depends on this type
reflection if we pursue an approach like the one shown here
"""
class dbtClassMixin(DataClassDictMixin):
    """Shim base class that matches the subset of the hologram
    ``JsonSchemaMixin`` interface downstream consumers expect, backed by
    mashumaro's ``DataClassDictMixin``.
    """

    @classmethod
    def field_mapping(cls) -> Dict[str, str]:
        """Return the mapping of python field names to JSON field names.

        The main use-case is to allow JSON field names which are Python
        keywords.
        """
        return {}

    def serialize(
        self,
        omit_none=False,
        validate=False,
        with_aliases: Optional[Dict[str, str]] = None,
    ):
        """Render this object as a dict.

        :param omit_none: accepted for interface compatibility; not yet
            implemented by this shim.
        :param validate: accepted for interface compatibility; not yet
            implemented by this shim.
        :param with_aliases: when truthy, keys found in the class's
            ``_ALIASES`` mapping are renamed to their canonical names.
        """
        dct = self.to_dict()
        if with_aliases:
            # `dct` is a fresh dict produced by to_dict(), so renaming keys
            # here cannot leak back to any caller-owned data. Classes without
            # an _ALIASES mapping simply get no renames.
            for aliased_name, canonical_name in getattr(
                self, '_ALIASES', {}
            ).items():
                if aliased_name in dct:
                    dct[canonical_name] = dct.pop(aliased_name)
        return dct

    @classmethod
    def deserialize(cls, data, validate=False, with_aliases=False):
        """Build an instance of ``cls`` from the mapping ``data``.

        :param validate: accepted for interface compatibility; not yet
            implemented by this shim.
        :param with_aliases: when truthy, keys found in the class's
            ``_ALIASES`` mapping are renamed to their canonical names before
            decoding.
        """
        if with_aliases:
            # Work on a shallow copy: the previous implementation renamed
            # keys in the caller's dict in place (flagged by its own TODO),
            # leaking the renames back to the caller.
            data = dict(data)
            for aliased_name, canonical_name in getattr(
                cls, '_ALIASES', {}
            ).items():
                if aliased_name in data:
                    data[canonical_name] = data.pop(aliased_name)
        return cls.from_dict(data)
class ValidatedStringMixin(str, SerializableType):
    """A str subclass whose values must match ``ValidationRegex``.

    Subclasses (e.g. ``Name``, ``UniqueId``) set ``ValidationRegex`` to the
    pattern their values must satisfy.
    """

    # Subclasses must override this with a regex string.
    ValidationRegex = None

    @classmethod
    def _deserialize(cls, value: str) -> 'ValidatedStringMixin':
        cls.validate(value)
        # Construct via `cls`, not the base class, so subclasses such as
        # Name/UniqueId round-trip as their own type instead of being
        # downgraded to ValidatedStringMixin.
        return cls(value)

    def _serialize(self) -> str:
        return str(self)

    @classmethod
    def validate(cls, value) -> None:
        """Raise ValueError if ``value`` does not match ValidationRegex."""
        res = re.match(cls.ValidationRegex, value)
        if res is None:
            # The original raised ValidationError, which is not defined in
            # this module and produced a NameError instead of a useful error.
            raise ValueError(f"Invalid value: {value}")

View File

@@ -12,12 +12,16 @@ from hologram.helpers import HyphenatedJsonSchemaMixin, register_pattern, \
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Optional, List, Dict, Union, Any, NewType from typing import Optional, List, Dict, Union, Any, NewType
from dbt.contracts.jsonschema import dbtClassMixin, ValidatedStringMixin
from mashumaro.types import SerializableType
PIN_PACKAGE_URL = 'https://docs.getdbt.com/docs/package-management#section-specifying-package-versions' # noqa PIN_PACKAGE_URL = 'https://docs.getdbt.com/docs/package-management#section-specifying-package-versions' # noqa
DEFAULT_SEND_ANONYMOUS_USAGE_STATS = True DEFAULT_SEND_ANONYMOUS_USAGE_STATS = True
Name = NewType('Name', str) class Name(ValidatedStringMixin):
register_pattern(Name, r'^[^\d\W]\w*$') ValidationRegex = r'^[^\d\W]\w*$'
# this does not support the full semver (does not allow a trailing -fooXYZ) and # this does not support the full semver (does not allow a trailing -fooXYZ) and
# is not restrictive enough for full semver, (allows '1.0'). But it's like # is not restrictive enough for full semver, (allows '1.0'). But it's like
@@ -28,17 +32,26 @@ register_pattern(
r'^(?:0|[1-9]\d*)\.(?:0|[1-9]\d*)(\.(?:0|[1-9]\d*))?$', r'^(?:0|[1-9]\d*)\.(?:0|[1-9]\d*)(\.(?:0|[1-9]\d*))?$',
) )
class SemverString(str, SerializableType):
def _serialize(self) -> str:
return self
@dataclass @classmethod
class Quoting(JsonSchemaMixin, Mergeable): def _deserialize(cls, value: str) -> 'SemverString':
identifier: Optional[bool] return SemverString(value)
schema: Optional[bool]
database: Optional[bool]
project: Optional[bool]
@dataclass @dataclass
class Package(Replaceable, HyphenatedJsonSchemaMixin): class Quoting(dbtClassMixin, Mergeable):
identifier: Optional[bool] = None
schema: Optional[bool] = None
database: Optional[bool] = None
project: Optional[bool] = None
# TODO .... hyphenation.... what.... why
@dataclass
class Package(dbtClassMixin, Replaceable):
pass pass
@@ -80,7 +93,7 @@ PackageSpec = Union[LocalPackage, GitPackage, RegistryPackage]
@dataclass @dataclass
class PackageConfig(JsonSchemaMixin, Replaceable): class PackageConfig(dbtClassMixin, Replaceable):
packages: List[PackageSpec] packages: List[PackageSpec]
@@ -96,13 +109,13 @@ class ProjectPackageMetadata:
@dataclass @dataclass
class Downloads(ExtensibleJsonSchemaMixin, Replaceable): class Downloads(dbtClassMixin, Replaceable):
tarball: str tarball: str
@dataclass @dataclass
class RegistryPackageMetadata( class RegistryPackageMetadata(
ExtensibleJsonSchemaMixin, dbtClassMixin,
ProjectPackageMetadata, ProjectPackageMetadata,
): ):
downloads: Downloads downloads: Downloads
@@ -153,7 +166,7 @@ BANNED_PROJECT_NAMES = {
@dataclass @dataclass
class Project(HyphenatedJsonSchemaMixin, Replaceable): class Project(dbtClassMixin, Replaceable):
name: Name name: Name
version: Union[SemverString, float] version: Union[SemverString, float]
config_version: int config_version: int
@@ -189,9 +202,30 @@ class Project(HyphenatedJsonSchemaMixin, Replaceable):
packages: List[PackageSpec] = field(default_factory=list) packages: List[PackageSpec] = field(default_factory=list)
query_comment: Optional[Union[QueryComment, NoValue, str]] = NoValue() query_comment: Optional[Union[QueryComment, NoValue, str]] = NoValue()
_ALIASES = {
'config-version': 'config_version',
'source-paths': 'source_paths',
'macro-paths': 'macro_paths',
'data-paths': 'data_paths',
'test-paths': 'test_paths',
'analysis-paths': 'analysis_paths',
'docs-paths': 'docs_paths',
'asset-paths': 'asset_paths',
'target-path': 'target_path',
'snapshot-paths': 'snapshot_paths',
'clean-targets': 'clean_targets',
'log-path': 'log_path',
'modules-path': 'modules_path',
'on-run-start': 'on_run_start',
'on-run-end': 'on_run_end',
'require-dbt-version': 'require_dbt_version',
'project-root': 'project_root',
}
@classmethod @classmethod
def from_dict(cls, data, validate=True) -> 'Project': def deserialize(cls, data) -> 'Project':
result = super().from_dict(data, validate=validate) # TODO : Deserialize this with aliases - right now, only implemented for serializer?
result = super().deserialize(data, with_aliases=True)
if result.name in BANNED_PROJECT_NAMES: if result.name in BANNED_PROJECT_NAMES:
raise ValidationError( raise ValidationError(
f'Invalid project name: {result.name} is a reserved word' f'Invalid project name: {result.name} is a reserved word'
@@ -200,8 +234,9 @@ class Project(HyphenatedJsonSchemaMixin, Replaceable):
return result return result
# TODO : Make extensible?
@dataclass @dataclass
class UserConfig(ExtensibleJsonSchemaMixin, Replaceable, UserConfigContract): class UserConfig(dbtClassMixin, Replaceable, UserConfigContract):
send_anonymous_usage_stats: bool = DEFAULT_SEND_ANONYMOUS_USAGE_STATS send_anonymous_usage_stats: bool = DEFAULT_SEND_ANONYMOUS_USAGE_STATS
use_colors: Optional[bool] = None use_colors: Optional[bool] = None
partial_parse: Optional[bool] = None partial_parse: Optional[bool] = None
@@ -221,7 +256,7 @@ class UserConfig(ExtensibleJsonSchemaMixin, Replaceable, UserConfigContract):
@dataclass @dataclass
class ProfileConfig(HyphenatedJsonSchemaMixin, Replaceable): class ProfileConfig(dbtClassMixin, Replaceable):
profile_name: str = field(metadata={'preserve_underscore': True}) profile_name: str = field(metadata={'preserve_underscore': True})
target_name: str = field(metadata={'preserve_underscore': True}) target_name: str = field(metadata={'preserve_underscore': True})
config: UserConfig config: UserConfig
@@ -234,8 +269,8 @@ class ProfileConfig(HyphenatedJsonSchemaMixin, Replaceable):
class ConfiguredQuoting(Quoting, Replaceable): class ConfiguredQuoting(Quoting, Replaceable):
identifier: bool identifier: bool
schema: bool schema: bool
database: Optional[bool] database: Optional[bool] = None
project: Optional[bool] project: Optional[bool] = None
@dataclass @dataclass
@@ -248,5 +283,5 @@ class Configuration(Project, ProfileConfig):
@dataclass @dataclass
class ProjectList(JsonSchemaMixin): class ProjectList(dbtClassMixin):
projects: Dict[str, Project] projects: Dict[str, Project]

View File

@@ -1,11 +1,10 @@
from collections.abc import Mapping from collections.abc import Mapping
from dataclasses import dataclass, fields from dataclasses import dataclass, fields, Field
from typing import ( from typing import (
Optional, TypeVar, Generic, Dict, Optional, TypeVar, Generic, Dict, get_type_hints, List, Tuple
) )
from typing_extensions import Protocol from typing_extensions import Protocol
from hologram import JsonSchemaMixin
from hologram.helpers import StrEnum from hologram.helpers import StrEnum
from dbt import deprecations from dbt import deprecations
@@ -13,6 +12,8 @@ from dbt.contracts.util import Replaceable
from dbt.exceptions import CompilationException from dbt.exceptions import CompilationException
from dbt.utils import deep_merge from dbt.utils import deep_merge
from dbt.contracts.jsonschema import dbtClassMixin
class RelationType(StrEnum): class RelationType(StrEnum):
Table = 'table' Table = 'table'
@@ -32,7 +33,7 @@ class HasQuoting(Protocol):
quoting: Dict[str, bool] quoting: Dict[str, bool]
class FakeAPIObject(JsonSchemaMixin, Replaceable, Mapping): class FakeAPIObject(dbtClassMixin, Replaceable, Mapping):
# override the mapping truthiness, len is always >1 # override the mapping truthiness, len is always >1
def __bool__(self): def __bool__(self):
return True return True
@@ -58,16 +59,22 @@ class FakeAPIObject(JsonSchemaMixin, Replaceable, Mapping):
return self.from_dict(value) return self.from_dict(value)
T = TypeVar('T') """
mashumaro does not support generic types, so I just duplicated
the generic type class that was here previously to get things
working. We'd probably want to do this differently
in the future
TODO
"""
@dataclass @dataclass
class _ComponentObject(FakeAPIObject, Generic[T]): class Policy(FakeAPIObject):
database: T database: bool = True
schema: T schema: bool = True
identifier: T identifier: bool = True
def get_part(self, key: ComponentName) -> T: def get_part(self, key: ComponentName) -> bool:
if key == ComponentName.Database: if key == ComponentName.Database:
return self.database return self.database
elif key == ComponentName.Schema: elif key == ComponentName.Schema:
@@ -80,22 +87,15 @@ class _ComponentObject(FakeAPIObject, Generic[T]):
.format(key, list(ComponentName)) .format(key, list(ComponentName))
) )
def replace_dict(self, dct: Dict[ComponentName, T]): def replace_dict(self, dct: Dict[ComponentName, bool]):
kwargs: Dict[str, T] = {} kwargs: Dict[str, bool] = {}
for k, v in dct.items(): for k, v in dct.items():
kwargs[str(k)] = v kwargs[str(k)] = v
return self.replace(**kwargs) return self.replace(**kwargs)
@dataclass @dataclass
class Policy(_ComponentObject[bool]): class Path(FakeAPIObject):
database: bool = True
schema: bool = True
identifier: bool = True
@dataclass
class Path(_ComponentObject[Optional[str]]):
database: Optional[str] database: Optional[str]
schema: Optional[str] schema: Optional[str]
identifier: Optional[str] identifier: Optional[str]
@@ -120,3 +120,22 @@ class Path(_ComponentObject[Optional[str]]):
if part is not None: if part is not None:
part = part.lower() part = part.lower()
return part return part
def get_part(self, key: ComponentName) -> str:
if key == ComponentName.Database:
return self.database
elif key == ComponentName.Schema:
return self.schema
elif key == ComponentName.Identifier:
return self.identifier
else:
raise ValueError(
'Got a key of {}, expected one of {}'
.format(key, list(ComponentName))
)
def replace_dict(self, dct: Dict[ComponentName, str]):
kwargs: Dict[str, str] = {}
for k, v in dct.items():
kwargs[str(k)] = v
return self.replace(**kwargs)

View File

@@ -21,6 +21,9 @@ from dbt.utils import lowercase
from hologram.helpers import StrEnum from hologram.helpers import StrEnum
from hologram import JsonSchemaMixin from hologram import JsonSchemaMixin
from dbt.contracts.jsonschema import dbtClassMixin
from mashumaro.types import SerializableType
import agate import agate
from dataclasses import dataclass, field from dataclasses import dataclass, field
@@ -31,7 +34,7 @@ from dbt.clients.system import write_json
@dataclass @dataclass
class TimingInfo(JsonSchemaMixin): class TimingInfo(dbtClassMixin):
name: str name: str
started_at: Optional[datetime] = None started_at: Optional[datetime] = None
completed_at: Optional[datetime] = None completed_at: Optional[datetime] = None
@@ -88,7 +91,7 @@ class FreshnessStatus(StrEnum):
@dataclass @dataclass
class BaseResult(JsonSchemaMixin): class BaseResult(dbtClassMixin):
status: Union[RunStatus, TestStatus, FreshnessStatus] status: Union[RunStatus, TestStatus, FreshnessStatus]
timing: List[TimingInfo] timing: List[TimingInfo]
thread_id: str thread_id: str
@@ -109,15 +112,14 @@ class PartialNodeResult(NodeResult):
def skipped(self): def skipped(self):
return False return False
# TODO : Does this work? no-op agate table serialization in RunModelResult
class SerializableAgateTable(agate.Table, SerializableType):
def _serialize(self) -> None:
return None
@dataclass @dataclass
class RunModelResult(NodeResult): class RunModelResult(NodeResult):
agate_table: Optional[agate.Table] = None agate_table: Optional[SerializableAgateTable] = None
def to_dict(self, *args, **kwargs):
dct = super().to_dict(*args, **kwargs)
dct.pop('agate_table', None)
return dct
@property @property
def skipped(self): def skipped(self):
@@ -125,7 +127,7 @@ class RunModelResult(NodeResult):
@dataclass @dataclass
class ExecutionResult(JsonSchemaMixin): class ExecutionResult(dbtClassMixin):
results: Sequence[BaseResult] results: Sequence[BaseResult]
elapsed_time: float elapsed_time: float
@@ -210,7 +212,8 @@ class RunResultsArtifact(ExecutionResult, ArtifactMixin):
) )
def write(self, path: str, omit_none=False): def write(self, path: str, omit_none=False):
write_json(path, self.to_dict(omit_none=omit_none)) # TODO: Implement omit_none
write_json(path, self.to_dict())
@dataclass @dataclass
@@ -268,14 +271,14 @@ class FreshnessErrorEnum(StrEnum):
@dataclass @dataclass
class SourceFreshnessRuntimeError(JsonSchemaMixin): class SourceFreshnessRuntimeError(dbtClassMixin):
unique_id: str unique_id: str
error: Optional[Union[str, int]] error: Optional[Union[str, int]]
status: FreshnessErrorEnum status: FreshnessErrorEnum
@dataclass @dataclass
class SourceFreshnessOutput(JsonSchemaMixin): class SourceFreshnessOutput(dbtClassMixin):
unique_id: str unique_id: str
max_loaded_at: datetime max_loaded_at: datetime
snapshotted_at: datetime snapshotted_at: datetime
@@ -383,7 +386,7 @@ CatalogKey = NamedTuple(
@dataclass @dataclass
class StatsItem(JsonSchemaMixin): class StatsItem(dbtClassMixin):
id: str id: str
label: str label: str
value: Primitive value: Primitive
@@ -395,7 +398,7 @@ StatsDict = Dict[str, StatsItem]
@dataclass @dataclass
class ColumnMetadata(JsonSchemaMixin): class ColumnMetadata(dbtClassMixin):
type: str type: str
comment: Optional[str] comment: Optional[str]
index: int index: int
@@ -406,7 +409,7 @@ ColumnMap = Dict[str, ColumnMetadata]
@dataclass @dataclass
class TableMetadata(JsonSchemaMixin): class TableMetadata(dbtClassMixin):
type: str type: str
database: Optional[str] database: Optional[str]
schema: str schema: str
@@ -416,7 +419,7 @@ class TableMetadata(JsonSchemaMixin):
@dataclass @dataclass
class CatalogTable(JsonSchemaMixin, Replaceable): class CatalogTable(dbtClassMixin, Replaceable):
metadata: TableMetadata metadata: TableMetadata
columns: ColumnMap columns: ColumnMap
stats: StatsDict stats: StatsDict
@@ -439,7 +442,7 @@ class CatalogMetadata(BaseArtifactMetadata):
@dataclass @dataclass
class CatalogResults(JsonSchemaMixin): class CatalogResults(dbtClassMixin):
nodes: Dict[str, CatalogTable] nodes: Dict[str, CatalogTable]
sources: Dict[str, CatalogTable] sources: Dict[str, CatalogTable]
errors: Optional[List[str]] errors: Optional[List[str]]

View File

@@ -26,6 +26,8 @@ from dbt.exceptions import InternalException
from dbt.logger import LogMessage from dbt.logger import LogMessage
from dbt.utils import restrict_to from dbt.utils import restrict_to
from dbt.contracts.jsonschema import dbtClassMixin
TaskTags = Optional[Dict[str, Any]] TaskTags = Optional[Dict[str, Any]]
TaskID = uuid.UUID TaskID = uuid.UUID
@@ -34,7 +36,7 @@ TaskID = uuid.UUID
@dataclass @dataclass
class RPCParameters(JsonSchemaMixin): class RPCParameters(dbtClassMixin):
timeout: Optional[float] timeout: Optional[float]
task_tags: TaskTags task_tags: TaskTags
@@ -132,7 +134,7 @@ class StatusParameters(RPCParameters):
@dataclass @dataclass
class GCSettings(JsonSchemaMixin): class GCSettings(dbtClassMixin):
# start evicting the longest-ago-ended tasks here # start evicting the longest-ago-ended tasks here
maxsize: int maxsize: int
# start evicting all tasks before now - auto_reap_age when we have this # start evicting all tasks before now - auto_reap_age when we have this
@@ -254,7 +256,7 @@ class RemoteExecutionResult(ExecutionResult, RemoteResult):
@dataclass @dataclass
class ResultTable(JsonSchemaMixin): class ResultTable(dbtClassMixin):
column_names: List[str] column_names: List[str]
rows: List[Any] rows: List[Any]
@@ -411,7 +413,7 @@ class TaskHandlerState(StrEnum):
@dataclass @dataclass
class TaskTiming(JsonSchemaMixin): class TaskTiming(dbtClassMixin):
state: TaskHandlerState state: TaskHandlerState
start: Optional[datetime] start: Optional[datetime]
end: Optional[datetime] end: Optional[datetime]

View File

@@ -3,16 +3,18 @@ from hologram import JsonSchemaMixin
from typing import List, Dict, Any, Union from typing import List, Dict, Any, Union
from dbt.contracts.jsonschema import dbtClassMixin
@dataclass @dataclass
class SelectorDefinition(JsonSchemaMixin): class SelectorDefinition(dbtClassMixin):
name: str name: str
definition: Union[str, Dict[str, Any]] definition: Union[str, Dict[str, Any]]
description: str = '' description: str = ''
@dataclass @dataclass
class SelectorFile(JsonSchemaMixin): class SelectorFile(dbtClassMixin):
selectors: List[SelectorDefinition] selectors: List[SelectorDefinition]
version: int = 2 version: int = 2

View File

@@ -14,6 +14,7 @@ from dbt.exceptions import (
from dbt.version import __version__ from dbt.version import __version__
from dbt.tracking import get_invocation_id from dbt.tracking import get_invocation_id
from hologram import JsonSchemaMixin from hologram import JsonSchemaMixin
from dbt.contracts.jsonschema import dbtClassMixin
MacroKey = Tuple[str, str] MacroKey = Tuple[str, str]
SourceKey = Tuple[str, str] SourceKey = Tuple[str, str]
@@ -58,7 +59,8 @@ class Mergeable(Replaceable):
class Writable: class Writable:
def write(self, path: str, omit_none: bool = False): def write(self, path: str, omit_none: bool = False):
write_json(path, self.to_dict(omit_none=omit_none)) # type: ignore # TODO : Do this faster?
write_json(path, self.serialize(omit_none=omit_none)) # type: ignore
class AdditionalPropertiesMixin: class AdditionalPropertiesMixin:
@@ -71,20 +73,20 @@ class AdditionalPropertiesMixin:
@classmethod @classmethod
def from_dict(cls, data, validate=True): def from_dict(cls, data, validate=True):
self = super().from_dict(data=data, validate=validate) self = super().deserialize(data=data, validate=validate)
keys = self.to_dict(validate=False, omit_none=False) keys = self.serialize(validate=False, omit_none=False)
for key, value in data.items(): for key, value in data.items():
if key not in keys: if key not in keys:
self.extra[key] = value self.extra[key] = value
return self return self
def to_dict(self, omit_none=True, validate=False): def to_dict(self, omit_none=True, validate=False):
data = super().to_dict(omit_none=omit_none, validate=validate) data = super().serialize(omit_none=omit_none, validate=validate)
data.update(self.extra) data.update(self.extra)
return data return data
def replace(self, **kwargs): def replace(self, **kwargs):
dct = self.to_dict(omit_none=False, validate=False) dct = self.serialize(omit_none=False, validate=False)
dct.update(kwargs) dct.update(kwargs)
return self.from_dict(dct) return self.from_dict(dct)
@@ -135,7 +137,7 @@ def get_metadata_env() -> Dict[str, str]:
@dataclasses.dataclass @dataclasses.dataclass
class BaseArtifactMetadata(JsonSchemaMixin): class BaseArtifactMetadata(dbtClassMixin):
dbt_schema_version: str dbt_schema_version: str
dbt_version: str = __version__ dbt_version: str = __version__
generated_at: datetime = dataclasses.field( generated_at: datetime = dataclasses.field(
@@ -158,7 +160,7 @@ def schema_version(name: str, version: int):
@dataclasses.dataclass @dataclasses.dataclass
class VersionedSchema(JsonSchemaMixin): class VersionedSchema(dbtClassMixin):
dbt_schema_version: ClassVar[SchemaVersion] dbt_schema_version: ClassVar[SchemaVersion]
@classmethod @classmethod
@@ -194,4 +196,4 @@ class ArtifactMixin(VersionedSchema, Writable, Readable):
if found != expected: if found != expected:
raise IncompatibleSchemaException(expected, found) raise IncompatibleSchemaException(expected, found)
return super().from_dict(data=data, validate=validate) return super().deserialize(data=data, validate=validate)

View File

@@ -4,8 +4,12 @@ from typing import (
import networkx as nx # type: ignore import networkx as nx # type: ignore
from dbt.exceptions import InternalException from dbt.exceptions import InternalException
from dbt.contracts.jsonschema import ValidatedStringMixin
from mashumaro.types import SerializableType
UniqueId = NewType('UniqueId', str)
class UniqueId(ValidatedStringMixin):
ValidationRegex = '.+'
class Graph: class Graph:

View File

@@ -2,16 +2,33 @@
from dataclasses import dataclass from dataclasses import dataclass
from datetime import timedelta from datetime import timedelta
from pathlib import Path from pathlib import Path
from typing import NewType, Tuple, AbstractSet from typing import NewType, Tuple, AbstractSet, Union
from hologram import ( from hologram import (
FieldEncoder, JsonSchemaMixin, JsonDict, ValidationError FieldEncoder, JsonSchemaMixin, JsonDict, ValidationError
) )
from hologram.helpers import StrEnum from hologram.helpers import StrEnum
from dbt.contracts.jsonschema import dbtClassMixin
from mashumaro.types import SerializableType
Port = NewType('Port', int) Port = NewType('Port', int)
class Port(int, SerializableType):
    """An int subclass used for connection port numbers, serializable via
    mashumaro. Accepts either an int or a numeric string on deserialization.
    """

    @classmethod
    def _deserialize(cls, value: Union[int, str]) -> 'Port':
        try:
            value = int(value)
        except (ValueError, TypeError) as exc:
            # TypeError covers non-numeric inputs like None, which int()
            # rejects with TypeError rather than ValueError.
            raise ValidationError(
                f'Cannot encode {value} into port number'
            ) from exc
        return cls(value)

    def _serialize(self) -> int:
        # TODO : Validate here?
        return self
# TODO
class PortEncoder(FieldEncoder): class PortEncoder(FieldEncoder):
@property @property
def json_schema(self): def json_schema(self):
@@ -66,16 +83,17 @@ class NVEnum(StrEnum):
@dataclass @dataclass
class NoValue(JsonSchemaMixin): class NoValue(dbtClassMixin):
"""Sometimes, you want a way to say none that isn't None""" """Sometimes, you want a way to say none that isn't None"""
novalue: NVEnum = NVEnum.novalue novalue: NVEnum = NVEnum.novalue
JsonSchemaMixin.register_field_encoders({ # TODO : None of this is right lol
Port: PortEncoder(), # JsonSchemaMixin.register_field_encoders({
timedelta: TimeDeltaFieldEncoder(), # Port: PortEncoder(),
Path: PathEncoder(), # timedelta: TimeDeltaFieldEncoder(),
}) # Path: PathEncoder(),
# })
FQNPath = Tuple[str, ...] FQNPath = Tuple[str, ...]

View File

@@ -15,6 +15,8 @@ import colorama
import logbook import logbook
from hologram import JsonSchemaMixin from hologram import JsonSchemaMixin
from dbt.contracts.jsonschema import dbtClassMixin
# Colorama needs some help on windows because we're using logger.info # Colorama needs some help on windows because we're using logger.info
# intead of print(). If the Windows env doesn't have a TERM var set, # intead of print(). If the Windows env doesn't have a TERM var set,
# then we should override the logging stream to use the colorama # then we should override the logging stream to use the colorama
@@ -49,7 +51,7 @@ Extras = Dict[str, Any]
@dataclass @dataclass
class LogMessage(JsonSchemaMixin): class LogMessage(dbtClassMixin):
timestamp: datetime timestamp: datetime
message: str message: str
channel: str channel: str
@@ -215,7 +217,7 @@ class TextOnly(logbook.Processor):
class TimingProcessor(logbook.Processor): class TimingProcessor(logbook.Processor):
def __init__(self, timing_info: Optional[JsonSchemaMixin] = None): def __init__(self, timing_info: Optional[dbtClassMixin] = None):
self.timing_info = timing_info self.timing_info = timing_info
super().__init__() super().__init__()

View File

@@ -137,6 +137,7 @@ def main(args=None):
exit_code = e.code exit_code = e.code
except BaseException as e: except BaseException as e:
print(traceback.format_exc())
logger.warning("Encountered an error:") logger.warning("Encountered an error:")
logger.warning(str(e)) logger.warning(str(e))

View File

@@ -13,7 +13,7 @@ class AnalysisParser(SimpleSQLParser[ParsedAnalysisNode]):
) )
def parse_from_dict(self, dct, validate=True) -> ParsedAnalysisNode: def parse_from_dict(self, dct, validate=True) -> ParsedAnalysisNode:
return ParsedAnalysisNode.from_dict(dct, validate=validate) return ParsedAnalysisNode.deserialize(dct, validate=validate)
@property @property
def resource_type(self) -> NodeType: def resource_type(self) -> NodeType:

View File

@@ -12,7 +12,7 @@ class DataTestParser(SimpleSQLParser[ParsedDataTestNode]):
) )
def parse_from_dict(self, dct, validate=True) -> ParsedDataTestNode: def parse_from_dict(self, dct, validate=True) -> ParsedDataTestNode:
return ParsedDataTestNode.from_dict(dct, validate=validate) return ParsedDataTestNode.deserialize(dct, validate=validate)
@property @property
def resource_type(self) -> NodeType: def resource_type(self) -> NodeType:

View File

@@ -79,7 +79,7 @@ class HookParser(SimpleParser[HookBlock, ParsedHookNode]):
return [path] return [path]
def parse_from_dict(self, dct, validate=True) -> ParsedHookNode: def parse_from_dict(self, dct, validate=True) -> ParsedHookNode:
return ParsedHookNode.from_dict(dct, validate=validate) return ParsedHookNode.deserialize(dct, validate=validate)
@classmethod @classmethod
def get_compiled_path(cls, block: HookBlock): def get_compiled_path(cls, block: HookBlock):

View File

@@ -53,20 +53,22 @@ from dbt.version import __version__
from hologram import JsonSchemaMixin from hologram import JsonSchemaMixin
from dbt.contracts.jsonschema import dbtClassMixin
PARTIAL_PARSE_FILE_NAME = 'partial_parse.pickle' PARTIAL_PARSE_FILE_NAME = 'partial_parse.pickle'
PARSING_STATE = DbtProcessState('parsing') PARSING_STATE = DbtProcessState('parsing')
DEFAULT_PARTIAL_PARSE = False DEFAULT_PARTIAL_PARSE = False
@dataclass @dataclass
class ParserInfo(JsonSchemaMixin): class ParserInfo(dbtClassMixin):
parser: str parser: str
elapsed: float elapsed: float
path_count: int = 0 path_count: int = 0
@dataclass @dataclass
class ProjectLoaderInfo(JsonSchemaMixin): class ProjectLoaderInfo(dbtClassMixin):
project_name: str project_name: str
elapsed: float elapsed: float
parsers: List[ParserInfo] parsers: List[ParserInfo]
@@ -74,7 +76,7 @@ class ProjectLoaderInfo(JsonSchemaMixin):
@dataclass @dataclass
class ManifestLoaderInfo(JsonSchemaMixin, Writable): class ManifestLoaderInfo(dbtClassMixin, Writable):
path_count: int = 0 path_count: int = 0
is_partial_parse_enabled: Optional[bool] = None is_partial_parse_enabled: Optional[bool] = None
parse_project_elapsed: Optional[float] = None parse_project_elapsed: Optional[float] = None

View File

@@ -11,7 +11,7 @@ class ModelParser(SimpleSQLParser[ParsedModelNode]):
) )
def parse_from_dict(self, dct, validate=True) -> ParsedModelNode: def parse_from_dict(self, dct, validate=True) -> ParsedModelNode:
return ParsedModelNode.from_dict(dct, validate=validate) return ParsedModelNode.deserialize(dct, validate=validate)
@property @property
def resource_type(self) -> NodeType: def resource_type(self) -> NodeType:

View File

@@ -3,6 +3,8 @@ from typing import TypeVar, MutableMapping, Mapping, Union, List
from hologram import JsonSchemaMixin from hologram import JsonSchemaMixin
from dbt.contracts.jsonschema import dbtClassMixin
from dbt.contracts.files import RemoteFile, FileHash, SourceFile from dbt.contracts.files import RemoteFile, FileHash, SourceFile
from dbt.contracts.graph.compiled import CompileResultNode from dbt.contracts.graph.compiled import CompileResultNode
from dbt.contracts.graph.parsed import ( from dbt.contracts.graph.parsed import (
@@ -62,7 +64,7 @@ def dict_field():
@dataclass @dataclass
class ParseResult(JsonSchemaMixin, Writable, Replaceable): class ParseResult(dbtClassMixin, Writable, Replaceable):
vars_hash: FileHash vars_hash: FileHash
profile_hash: FileHash profile_hash: FileHash
project_hashes: MutableMapping[str, FileHash] project_hashes: MutableMapping[str, FileHash]

View File

@@ -26,7 +26,7 @@ class RPCCallParser(SimpleSQLParser[ParsedRPCNode]):
return [] return []
def parse_from_dict(self, dct, validate=True) -> ParsedRPCNode: def parse_from_dict(self, dct, validate=True) -> ParsedRPCNode:
return ParsedRPCNode.from_dict(dct, validate=validate) return ParsedRPCNode.deserialize(dct, validate=validate)
@property @property
def resource_type(self) -> NodeType: def resource_type(self) -> NodeType:

View File

@@ -7,6 +7,7 @@ from typing import (
) )
from hologram import ValidationError, JsonSchemaMixin from hologram import ValidationError, JsonSchemaMixin
from dbt.contracts.jsonschema import dbtClassMixin
from dbt.adapters.factory import get_adapter from dbt.adapters.factory import get_adapter
from dbt.clients.jinja import get_rendered, add_rendered_test_kwargs from dbt.clients.jinja import get_rendered, add_rendered_test_kwargs
@@ -119,7 +120,8 @@ class ParserRef:
meta=meta, meta=meta,
tags=tags, tags=tags,
quote=quote, quote=quote,
_extra=column.extra # TODO
#_extra=column.extra
) )
@classmethod @classmethod
@@ -201,7 +203,7 @@ class SchemaParser(SimpleParser[SchemaTestBlock, ParsedSchemaTestNode]):
) )
def parse_from_dict(self, dct, validate=True) -> ParsedSchemaTestNode: def parse_from_dict(self, dct, validate=True) -> ParsedSchemaTestNode:
return ParsedSchemaTestNode.from_dict(dct, validate=validate) return ParsedSchemaTestNode.deserialize(dct, validate=validate)
def _parse_format_version( def _parse_format_version(
self, yaml: YamlBlock self, yaml: YamlBlock
@@ -654,7 +656,7 @@ class YamlDocsReader(YamlReader):
raise NotImplementedError('parse is abstract') raise NotImplementedError('parse is abstract')
T = TypeVar('T', bound=JsonSchemaMixin) T = TypeVar('T', bound=dbtClassMixin)
class SourceParser(YamlDocsReader): class SourceParser(YamlDocsReader):

View File

@@ -13,7 +13,7 @@ class SeedParser(SimpleSQLParser[ParsedSeedNode]):
) )
def parse_from_dict(self, dct, validate=True) -> ParsedSeedNode: def parse_from_dict(self, dct, validate=True) -> ParsedSeedNode:
return ParsedSeedNode.from_dict(dct, validate=validate) return ParsedSeedNode.deserialize(dct, validate=validate)
@property @property
def resource_type(self) -> NodeType: def resource_type(self) -> NodeType:

View File

@@ -26,7 +26,7 @@ class SnapshotParser(
) )
def parse_from_dict(self, dct, validate=True) -> IntermediateSnapshotNode: def parse_from_dict(self, dct, validate=True) -> IntermediateSnapshotNode:
return IntermediateSnapshotNode.from_dict(dct, validate=validate) return IntermediateSnapshotNode.deserialize(dct, validate=validate)
@property @property
def resource_type(self) -> NodeType: def resource_type(self) -> NodeType:
@@ -66,6 +66,9 @@ class SnapshotParser(
def transform(self, node: IntermediateSnapshotNode) -> ParsedSnapshotNode: def transform(self, node: IntermediateSnapshotNode) -> ParsedSnapshotNode:
try: try:
# TODO : This is not going to work at all
# TODO : We need to accept the config and turn that into a concrete
# Snapshot node type
parsed_node = ParsedSnapshotNode.from_dict(node.to_dict()) parsed_node = ParsedSnapshotNode.from_dict(node.to_dict())
self.set_snapshot_attributes(parsed_node) self.set_snapshot_attributes(parsed_node)
return parsed_node return parsed_node

View File

@@ -15,6 +15,9 @@ from dbt.contracts.rpc import (
from dbt.exceptions import InternalException from dbt.exceptions import InternalException
from dbt.utils import restrict_to from dbt.utils import restrict_to
from dbt.contracts.jsonschema import dbtClassMixin
from mashumaro.types import SerializableType
class QueueMessageType(StrEnum): class QueueMessageType(StrEnum):
Error = 'error' Error = 'error'
@@ -26,16 +29,41 @@ class QueueMessageType(StrEnum):
@dataclass @dataclass
class QueueMessage(JsonSchemaMixin): class QueueMessage(dbtClassMixin):
message_type: QueueMessageType message_type: QueueMessageType
class SerializableLogRecord(logbook.LogRecord, SerializableType):
def _serialize(self):
# TODO
import ipdb; ipdb.set_trace()
pass
@classmethod
def _deserialize(cls, value):
# TODO
import ipdb; ipdb.set_trace()
pass
class SerializableJSONRPCError(JSONRPCError, SerializableType):
def _serialize(self):
# TODO
import ipdb; ipdb.set_trace()
pass
@classmethod
def _deserialize(cls, value):
# TODO
import ipdb; ipdb.set_trace()
pass
@dataclass @dataclass
class QueueLogMessage(QueueMessage): class QueueLogMessage(QueueMessage):
message_type: QueueMessageType = field( message_type: QueueMessageType = field(
metadata=restrict_to(QueueMessageType.Log) metadata=restrict_to(QueueMessageType.Log)
) )
record: logbook.LogRecord record: SerializableLogRecord
@classmethod @classmethod
def from_record(cls, record: logbook.LogRecord): def from_record(cls, record: logbook.LogRecord):
@@ -50,7 +78,7 @@ class QueueErrorMessage(QueueMessage):
message_type: QueueMessageType = field( message_type: QueueMessageType = field(
metadata=restrict_to(QueueMessageType.Error) metadata=restrict_to(QueueMessageType.Error)
) )
error: JSONRPCError error: SerializableJSONRPCError
@classmethod @classmethod
def from_error(cls, error: JSONRPCError): def from_error(cls, error: JSONRPCError):

View File

@@ -8,6 +8,8 @@ from hologram import JsonSchemaMixin, ValidationError
from dbt.contracts.rpc import RPCParameters, RemoteResult, RemoteMethodFlags from dbt.contracts.rpc import RPCParameters, RemoteResult, RemoteMethodFlags
from dbt.exceptions import NotImplementedException, InternalException from dbt.exceptions import NotImplementedException, InternalException
from dbt.contracts.jsonschema import dbtClassMixin
Parameters = TypeVar('Parameters', bound=RPCParameters) Parameters = TypeVar('Parameters', bound=RPCParameters)
Result = TypeVar('Result', bound=RemoteResult) Result = TypeVar('Result', bound=RemoteResult)
@@ -109,7 +111,7 @@ class RemoteBuiltinMethod(RemoteMethod[Parameters, Result]):
'the run() method on builtins should never be called' 'the run() method on builtins should never be called'
) )
def __call__(self, **kwargs: Dict[str, Any]) -> JsonSchemaMixin: def __call__(self, **kwargs: Dict[str, Any]) -> dbtClassMixin:
try: try:
params = self.get_parameters().from_dict(kwargs) params = self.get_parameters().from_dict(kwargs)
except ValidationError as exc: except ValidationError as exc:

View File

@@ -20,6 +20,8 @@ from dbt.rpc.task_handler import RequestTaskHandler
from dbt.rpc.method import RemoteMethod from dbt.rpc.method import RemoteMethod
from dbt.rpc.task_manager import TaskManager from dbt.rpc.task_manager import TaskManager
from dbt.contracts.jsonschema import dbtClassMixin
def track_rpc_request(task): def track_rpc_request(task):
dbt.tracking.track_rpc_request({ dbt.tracking.track_rpc_request({
@@ -90,11 +92,11 @@ class ResponseManager(JSONRPCResponseManager):
@classmethod @classmethod
def _get_responses(cls, requests, dispatcher): def _get_responses(cls, requests, dispatcher):
for output in super()._get_responses(requests, dispatcher): for output in super()._get_responses(requests, dispatcher):
# if it's a result, check if it's a JsonSchemaMixin and if so call # if it's a result, check if it's a dbtClassMixin and if so call
# to_dict # to_dict
if hasattr(output, 'result'): if hasattr(output, 'result'):
if isinstance(output.result, JsonSchemaMixin): if isinstance(output.result, dbtClassMixin):
output.result = output.result.to_dict(omit_none=False) output.result = output.result.serialize(omit_none=False)
yield output yield output
@classmethod @classmethod

View File

@@ -43,6 +43,9 @@ from dbt.rpc.method import RemoteMethod
# we use this in typing only... # we use this in typing only...
from queue import Queue # noqa from queue import Queue # noqa
from dbt.contracts.jsonschema import dbtClassMixin
def sigterm_handler(signum, frame): def sigterm_handler(signum, frame):
raise dbt.exceptions.RPCKilledException(signum) raise dbt.exceptions.RPCKilledException(signum)
@@ -283,7 +286,7 @@ class RequestTaskHandler(threading.Thread, TaskHandlerProtocol):
# - The actual thread that this represents, which writes its data to # - The actual thread that this represents, which writes its data to
# the result and logs. The atomicity of list.append() and item # the result and logs. The atomicity of list.append() and item
# assignment means we don't need a lock. # assignment means we don't need a lock.
self.result: Optional[JsonSchemaMixin] = None self.result: Optional[dbtClassMixin] = None
self.error: Optional[RPCException] = None self.error: Optional[RPCException] = None
self.state: TaskHandlerState = TaskHandlerState.NotStarted self.state: TaskHandlerState = TaskHandlerState.NotStarted
self.logs: List[LogMessage] = [] self.logs: List[LogMessage] = []

View File

@@ -8,6 +8,8 @@ from hologram import JsonSchemaMixin
from hologram.helpers import StrEnum from hologram.helpers import StrEnum
from typing import Optional from typing import Optional
from dbt.contracts.jsonschema import dbtClassMixin
class Matchers(StrEnum): class Matchers(StrEnum):
GREATER_THAN = '>' GREATER_THAN = '>'
@@ -18,12 +20,12 @@ class Matchers(StrEnum):
@dataclass @dataclass
class VersionSpecification(JsonSchemaMixin): class VersionSpecification(dbtClassMixin):
major: Optional[str] major: Optional[str] = None
minor: Optional[str] minor: Optional[str] = None
patch: Optional[str] patch: Optional[str] = None
prerelease: Optional[str] prerelease: Optional[str] = None
build: Optional[str] build: Optional[str] = None
matcher: Matchers = Matchers.EXACT matcher: Matchers = Matchers.EXACT

View File

@@ -110,7 +110,7 @@ class ListTask(GraphRunnableTask):
for node in self._iterate_selected_nodes(): for node in self._iterate_selected_nodes():
yield json.dumps({ yield json.dumps({
k: v k: v
for k, v in node.to_dict(omit_none=False).items() for k, v in node.serialize(omit_none=False).items()
if k in self.ALLOWED_KEYS if k in self.ALLOWED_KEYS
}) })

View File

@@ -16,13 +16,14 @@ from typing import Optional
class PostgresCredentials(Credentials): class PostgresCredentials(Credentials):
host: str host: str
user: str user: str
role: Optional[str]
port: Port port: Port
password: str # on postgres the password is mandatory password: str # on postgres the password is mandatory
role: Optional[str] = None
search_path: Optional[str] = None search_path: Optional[str] = None
keepalives_idle: int = 0 # 0 means to use the default value keepalives_idle: int = 0 # 0 means to use the default value
sslmode: Optional[str] = None sslmode: Optional[str] = None
# TODO : Fix all instances of _ALIASES? Or include them in some base class?
_ALIASES = { _ALIASES = {
'dbname': 'database', 'dbname': 'database',
'pass': 'password' 'pass': 'password'