Compare commits

...

5 Commits

Author      SHA1        Message                                                            Date
Drew Banin  b7cfd70ba3  cleanup                                                            2020-12-20 13:08:08 -05:00
Drew Banin  1ba1fb964b  working on timing debug project                                    2020-12-19 14:31:56 -05:00
Drew Banin  98f573ef29  Merge branch 'dev/kiyoshi-kuromiya' into experiment/use-mashumaro  2020-12-18 16:19:42 -05:00
Drew Banin  c744cca96f  experiment with mashumaro                                          2020-11-12 10:16:04 -05:00
Drew Banin  76571819f3  use mashumaro to encode/decode Relation objects                    2020-11-11 14:43:32 -05:00
41 changed files with 471 additions and 373 deletions

View File

@@ -5,10 +5,11 @@ from hologram import JsonSchemaMixin
from dbt.exceptions import RuntimeException
from typing import Dict, ClassVar, Any, Optional
from dbt.contracts.jsonschema import dbtClassMixin
@dataclass
class Column(JsonSchemaMixin):
class Column(dbtClassMixin):
TYPE_LABELS: ClassVar[Dict[str, str]] = {
'STRING': 'TEXT',
'TIMESTAMP': 'TIMESTAMP',

View File

@@ -47,13 +47,16 @@ class BaseRelation(FakeAPIObject, Hashable):
return False
return self.to_dict() == other.to_dict()
# TODO : Unclear why we're leveraging a type system to implement inheritance?
@classmethod
def get_default_quote_policy(cls) -> Policy:
return cls._get_field_named('quote_policy').default
#return cls._get_field_named('quote_policy').default
return Policy()
@classmethod
def get_default_include_policy(cls) -> Policy:
return cls._get_field_named('include_policy').default
#return cls._get_field_named('include_policy').default
return Policy()
def get(self, key, default=None):
"""Override `.get` to return a metadata object so we don't break

View File

@@ -137,7 +137,7 @@ class Profile(HasCredentials):
def validate(self):
try:
if self.credentials:
self.credentials.to_dict(validate=True)
self.credentials.serialize(validate=True)
ProfileConfig.from_dict(
self.to_profile_info(serialize_credentials=True)
)

View File

@@ -306,7 +306,7 @@ class PartialProject(RenderComponents):
)
try:
cfg = ProjectContract.from_dict(rendered.project_dict)
cfg = ProjectContract.deserialize(rendered.project_dict)
except ValidationError as e:
raise DbtProjectError(validator_error_message(e)) from e
# name/version are required in the Project definition, so we can assume
@@ -586,7 +586,9 @@ class Project:
def validate(self):
try:
ProjectContract.from_dict(self.to_project_config())
# TODO : Jank; need to do this to handle aliasing between hyphens and underscores
as_dict = self.to_project_config()
ProjectContract.deserialize(as_dict)
except ValidationError as e:
raise DbtProjectError(validator_error_message(e)) from e

View File

@@ -174,7 +174,7 @@ class RuntimeConfig(Project, Profile, AdapterRequiredConfig):
:raises DbtProjectError: If the configuration fails validation.
"""
try:
Configuration.from_dict(self.serialize())
Configuration.deserialize(self.serialize())
except ValidationError as e:
raise DbtProjectError(validator_error_message(e)) from e
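The pattern in this hunk (and in `Profile.validate` above) is validation by round-trip: dump the object to a dict, then push it back through the contract class so missing keys or bad values raise. A minimal sketch with a hypothetical `Contract` dataclass, assuming mashumaro's `from_dict` raises on missing or uncoercible fields:

from dataclasses import dataclass
from mashumaro import DataClassDictMixin

@dataclass
class Contract(DataClassDictMixin):
    name: str
    threads: int

def validate(raw: dict) -> None:
    # from_dict raises (e.g. a missing-field error) if required keys are
    # absent or a value cannot be coerced to the annotated type
    Contract.from_dict(raw)

validate({'name': 'prod', 'threads': 4})  # passes silently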

View File

@@ -165,7 +165,9 @@ class ContextConfigGenerator(BaseContextConfigGenerator[C]):
# Calculate the defaults. We don't want to validate the defaults,
# because it might be invalid in the case of required config members
# (such as on snapshots!)
result = config_cls.from_dict({}, validate=False)
result = config_cls.from_dict({})
# TODO - why validate here?
# result.validate()
return result
def _update_from_config(

View File

@@ -17,9 +17,12 @@ from dbt.utils import translate_aliases
from dbt.logger import GLOBAL_LOGGER as logger
from dbt.contracts.jsonschema import dbtClassMixin, ValidatedStringMixin
from mashumaro.types import SerializableType
Identifier = NewType('Identifier', str)
register_pattern(Identifier, r'^[A-Za-z_][A-Za-z0-9_]+$')
class Identifier(ValidatedStringMixin):
ValidationRegex = r'^[A-Za-z_][A-Za-z0-9_]+$'
class ConnectionState(StrEnum):
@@ -28,22 +31,30 @@ class ConnectionState(StrEnum):
CLOSED = 'closed'
FAIL = 'fail'
# I think that... this is not the right way to do this
# TODO!!
class DoNotSerializeType(SerializableType):
def _serialize(self) -> None:
return None
@classmethod
def _deserialize(cls, value) -> None:
# never rebuild the live object on load
return None
# TODO : Figure out ExtensibleJsonSchemaMixin?
@dataclass(init=False)
class Connection(ExtensibleJsonSchemaMixin, Replaceable):
class Connection(dbtClassMixin, Replaceable):
type: Identifier
name: Optional[str]
state: ConnectionState = ConnectionState.INIT
transaction_open: bool = False
# prevent serialization
_handle: Optional[Any] = None
_credentials: JsonSchemaMixin = field(init=False)
#_handle: Optional[Any] = None
#_credentials: dbtClassMixin = field(init=False)
_handle: Optional[DoNotSerializeType] = None
_credentials: Optional[DoNotSerializeType] = None
def __init__(
self,
type: Identifier,
name: Optional[str],
credentials: JsonSchemaMixin,
credentials: dbtClassMixin,
state: ConnectionState = ConnectionState.INIT,
transaction_open: bool = False,
handle: Optional[Any] = None,
@@ -102,7 +113,8 @@ class LazyHandle:
# will work.
@dataclass # type: ignore
class Credentials(
ExtensibleJsonSchemaMixin,
# ExtensibleJsonSchemaMixin,
dbtClassMixin,
Replaceable,
metaclass=abc.ABCMeta
):
@@ -121,7 +133,8 @@ class Credentials(
) -> Iterable[Tuple[str, Any]]:
"""Return an ordered iterator of key/value pairs for pretty-printing.
"""
as_dict = self.to_dict(omit_none=False, with_aliases=with_aliases)
# TODO: Does this... work?
as_dict = self.serialize(omit_none=False, with_aliases=with_aliases)
connection_keys = set(self._connection_keys())
aliases: List[str] = []
if with_aliases:
@@ -136,10 +149,11 @@ class Credentials(
def _connection_keys(self) -> Tuple[str, ...]:
raise NotImplementedError
@classmethod
def from_dict(cls, data):
data = cls.translate_aliases(data)
return super().from_dict(data)
# TODO TODO TODO
# @classmethod
# def from_dict(cls, data):
# data = cls.translate_aliases(data)
# return super().from_dict(data)
@classmethod
def translate_aliases(
@@ -147,15 +161,16 @@ class Credentials(
) -> Dict[str, Any]:
return translate_aliases(kwargs, cls._ALIASES, recurse)
def to_dict(self, omit_none=True, validate=False, *, with_aliases=False):
serialized = super().to_dict(omit_none=omit_none, validate=validate)
if with_aliases:
serialized.update({
new_name: serialized[canonical_name]
for new_name, canonical_name in self._ALIASES.items()
if canonical_name in serialized
})
return serialized
# TODO TODO TODO
# def to_dict(self, omit_none=True, validate=False, *, with_aliases=False):
# serialized = super().to_dict(omit_none=omit_none, validate=validate)
# if with_aliases:
# serialized.update({
# new_name: serialized[canonical_name]
# for new_name, canonical_name in self._ALIASES.items()
# if canonical_name in serialized
# })
# return serialized
class UserConfigContract(Protocol):
@@ -205,7 +220,7 @@ DEFAULT_QUERY_COMMENT = '''
@dataclass
class QueryComment(JsonSchemaMixin):
class QueryComment(dbtClassMixin):
comment: str = DEFAULT_QUERY_COMMENT
append: bool = False
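For context on `DoNotSerializeType` above: mashumaro lets a field's type opt out of serialization by returning `None` from its `_serialize` hook. A self-contained sketch of the idea, assuming mashumaro routes `Optional` fields through the `SerializableType` hooks (the `Opaque`/`Conn` names are illustrative, not from this diff):

from dataclasses import dataclass
from typing import Any, Optional

from mashumaro import DataClassDictMixin
from mashumaro.types import SerializableType

class Opaque(SerializableType):
    """Wraps a live object (e.g. a DB handle) that must never be serialized."""
    def __init__(self, value: Any) -> None:
        self.value = value

    def _serialize(self) -> None:
        return None  # always dump as null

    @classmethod
    def _deserialize(cls, value: Any) -> Optional['Opaque']:
        return None  # never rebuild the live object from data

@dataclass
class Conn(DataClassDictMixin):
    name: str
    handle: Optional[Opaque] = None

print(Conn(name='db', handle=Opaque(object())).to_dict())
# {'name': 'db', 'handle': None}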

View File

@@ -9,13 +9,15 @@ from dbt.exceptions import InternalException
from .util import MacroKey, SourceKey
from dbt.contracts.jsonschema import dbtClassMixin
MAXIMUM_SEED_SIZE = 1 * 1024 * 1024
MAXIMUM_SEED_SIZE_NAME = '1MB'
@dataclass
class FilePath(JsonSchemaMixin):
class FilePath(dbtClassMixin):
searched_path: str
relative_path: str
project_root: str
@@ -51,7 +53,7 @@ class FilePath(JsonSchemaMixin):
@dataclass
class FileHash(JsonSchemaMixin):
class FileHash(dbtClassMixin):
name: str # the hash type name
checksum: str # the hashlib.hash_type().hexdigest() of the file contents
@@ -91,7 +93,7 @@ class FileHash(JsonSchemaMixin):
@dataclass
class RemoteFile(JsonSchemaMixin):
class RemoteFile(dbtClassMixin):
@property
def searched_path(self) -> str:
return 'from remote system'
@@ -110,7 +112,7 @@ class RemoteFile(JsonSchemaMixin):
@dataclass
class SourceFile(JsonSchemaMixin):
class SourceFile(dbtClassMixin):
"""Define a source file in dbt"""
path: Union[FilePath, RemoteFile] # the path information
checksum: FileHash

View File

@@ -23,15 +23,17 @@ from hologram import JsonSchemaMixin
from dataclasses import dataclass, field
from typing import Optional, List, Union, Dict, Type
from dbt.contracts.jsonschema import dbtClassMixin
@dataclass
class InjectedCTE(JsonSchemaMixin, Replaceable):
class InjectedCTE(dbtClassMixin, Replaceable):
id: str
sql: str
@dataclass
class CompiledNodeMixin(JsonSchemaMixin):
class CompiledNodeMixin(dbtClassMixin):
# this is a special mixin class to provide a required argument. If a node
# is missing a `compiled` flag entirely, it must not be a CompiledNode.
compiled: bool
@@ -179,7 +181,8 @@ def parsed_instance_for(compiled: CompiledNode) -> ParsedResource:
.format(compiled.resource_type))
# validate=False to allow extra keys from compiling
return cls.from_dict(compiled.to_dict(), validate=False)
# TODO : This is probably not going to do the right thing....
return cls.deserialize(compiled.to_dict(), validate=False)
NonSourceCompiledNode = Union[

View File

@@ -508,10 +508,10 @@ class Manifest:
"""
self.flat_graph = {
'nodes': {
k: v.to_dict(omit_none=False) for k, v in self.nodes.items()
k: v.serialize(omit_none=False) for k, v in self.nodes.items()
},
'sources': {
k: v.to_dict(omit_none=False) for k, v in self.sources.items()
k: v.serialize(omit_none=False) for k, v in self.sources.items()
}
}
@@ -764,7 +764,7 @@ class Manifest:
)
def to_dict(self, omit_none=True, validate=False):
return self.writable_manifest().to_dict(
return self.writable_manifest().serialize(
omit_none=omit_none, validate=validate
)

View File

@@ -21,6 +21,9 @@ from dbt.contracts.util import Replaceable, list_str
from dbt import hooks
from dbt.node_types import NodeType
from dbt.contracts.jsonschema import dbtClassMixin
from mashumaro.types import SerializableType
M = TypeVar('M', bound='Metadata')
@@ -170,9 +173,19 @@ def insensitive_patterns(*patterns: str):
return '^({})$'.format('|'.join(lowercased))
# TODO?
Severity = NewType('Severity', str)
register_pattern(Severity, insensitive_patterns('warn', 'error'))
class Severity(str, SerializableType):
@classmethod
def _deserialize(cls, value: str) -> 'Severity':
# TODO : Validate here?
return Severity(value)
def _serialize(self) -> str:
# TODO : Validate here?
return self
class SnapshotStrategy(StrEnum):
@@ -185,7 +198,7 @@ class All(StrEnum):
@dataclass
class Hook(JsonSchemaMixin, Replaceable):
class Hook(dbtClassMixin, Replaceable):
sql: str
transaction: bool = True
index: Optional[int] = None
@@ -196,7 +209,7 @@ T = TypeVar('T', bound='BaseConfig')
@dataclass
class BaseConfig(
AdditionalPropertiesAllowed, Replaceable, MutableMapping[str, Any]
dbtClassMixin, Replaceable, MutableMapping[str, Any]
):
# Implement MutableMapping so this config will behave as some macros expect
# during parsing (notably, syntax like `{{ node.config['schema'] }}`)
@@ -294,23 +307,25 @@ class BaseConfig(
"""
result = {}
for fld, target_field in cls._get_fields():
if target_field not in data:
continue
# TODO : This is not correct.... must implement without reflection
data_attr = data.pop(target_field)
if target_field not in src:
result[target_field] = data_attr
continue
# for fld, target_field in cls._get_fields():
# if target_field not in data:
# continue
merge_behavior = MergeBehavior.from_field(fld)
self_attr = src[target_field]
# data_attr = data.pop(target_field)
# if target_field not in src:
# result[target_field] = data_attr
# continue
result[target_field] = _merge_field_value(
merge_behavior=merge_behavior,
self_value=self_attr,
other_value=data_attr,
)
# merge_behavior = MergeBehavior.from_field(fld)
# self_attr = src[target_field]
# result[target_field] = _merge_field_value(
# merge_behavior=merge_behavior,
# self_value=self_attr,
# other_value=data_attr,
# )
return result
def to_dict(
@@ -320,7 +335,7 @@ class BaseConfig(
*,
omit_hidden: bool = True,
) -> Dict[str, Any]:
result = super().to_dict(omit_none=omit_none, validate=validate)
result = super().serialize(omit_none=omit_none, validate=validate)
if omit_hidden and not omit_none:
for fld, target_field in self._get_fields():
if target_field not in result:
@@ -344,7 +359,9 @@ class BaseConfig(
"""
# sadly, this is a circular import
from dbt.adapters.factory import get_config_class_by_name
dct = self.to_dict(omit_none=False, validate=False, omit_hidden=False)
# TODO : omit_hidden?
# dct = self.serialize(omit_none=False, validate=False, omit_hidden=False)
dct = self.serialize(omit_none=False, validate=False)
adapter_config_cls = get_config_class_by_name(adapter_type)
@@ -358,11 +375,11 @@ class BaseConfig(
dct.update(data)
# any validation failures must have come from the update
return self.from_dict(dct, validate=validate)
return self.deserialize(dct, validate=validate)
def finalize_and_validate(self: T) -> T:
# from_dict will validate for us
dct = self.to_dict(omit_none=False, validate=False)
dct = self.serialize(omit_none=False, validate=False)
return self.from_dict(dct)
def replace(self, **kwargs):
@@ -372,7 +389,7 @@ class BaseConfig(
for key, value in kwargs.items():
new_key = mapping.get(key, key)
dct[new_key] = value
return self.from_dict(dct, validate=False)
return self.deserialize(dct, validate=False)
@dataclass
@@ -435,12 +452,15 @@ class NodeConfig(BaseConfig):
for key in hooks.ModelHookType:
if key in data:
data[key] = [hooks.get_hook_dict(h) for h in data[key]]
return super().from_dict(data, validate=validate)
return super().deserialize(data, validate=validate)
@classmethod
def field_mapping(cls):
return {'post_hook': 'post-hook', 'pre_hook': 'pre-hook'}
def validate(self):
# TODO : Not implemented!
pass
@dataclass
class SeedConfig(NodeConfig):
@@ -454,63 +474,10 @@ class TestConfig(NodeConfig):
severity: Severity = Severity('ERROR')
SnapshotVariants = Union[
'TimestampSnapshotConfig',
'CheckSnapshotConfig',
'GenericSnapshotConfig',
]
def _relevance_without_strategy(error: jsonschema.ValidationError):
# calculate the 'relevance' of an error the normal jsonschema way, except
# if the validator is in the 'strategy' field and its conflicting with the
# 'enum'. This suppresses `"'timestamp' is not one of ['check']` and such
if 'strategy' in error.path and error.validator in {'enum', 'not'}:
length = 1
else:
length = -len(error.path)
validator = error.validator
return length, validator not in {'anyOf', 'oneOf'}
@dataclass
class SnapshotWrapper(JsonSchemaMixin):
"""This is a little wrapper to let us serialize/deserialize the
SnapshotVariants union.
"""
config: SnapshotVariants # mypy: ignore
@classmethod
def validate(cls, data: Any):
config = data.get('config', {})
if config.get('strategy') == 'check':
schema = _validate_schema(CheckSnapshotConfig)
to_validate = config
elif config.get('strategy') == 'timestamp':
schema = _validate_schema(TimestampSnapshotConfig)
to_validate = config
else:
h_cls = cast(Hashable, cls)
schema = _validate_schema(h_cls)
to_validate = data
validator = jsonschema.Draft7Validator(schema)
error = jsonschema.exceptions.best_match(
validator.iter_errors(to_validate),
key=_relevance_without_strategy,
)
if error is not None:
raise ValidationError.create_from(error) from error
@dataclass
class EmptySnapshotConfig(NodeConfig):
materialized: str = 'snapshot'
strategy: Optional[str] = None
@dataclass(init=False)
@@ -519,117 +486,17 @@ class SnapshotConfig(EmptySnapshotConfig):
target_schema: str = field(init=False, metadata=dict(init_required=True))
target_database: Optional[str] = None
def __init__(
self,
unique_key: str,
target_schema: str,
target_database: Optional[str] = None,
**kwargs
) -> None:
self.unique_key = unique_key
self.target_schema = target_schema
self.target_database = target_database
# kwargs['materialized'] = materialized
super().__init__(**kwargs)
# type hacks...
@classmethod
def _get_fields(cls) -> List[Tuple[Field, str]]: # type: ignore
fields: List[Tuple[Field, str]] = []
for old_field, name in super()._get_fields():
new_field = old_field
# tell hologram we're really an initvar
if old_field.metadata and old_field.metadata.get('init_required'):
new_field = field(init=True, metadata=old_field.metadata)
new_field.name = old_field.name
new_field.type = old_field.type
new_field._field_type = old_field._field_type # type: ignore
fields.append((new_field, name))
return fields
def finalize_and_validate(self: 'SnapshotConfig') -> SnapshotVariants:
data = self.to_dict()
return SnapshotWrapper.from_dict({'config': data}).config
@dataclass(init=False)
class GenericSnapshotConfig(SnapshotConfig):
strategy: str = field(init=False, metadata=dict(init_required=True))
def __init__(self, strategy: str, **kwargs) -> None:
self.strategy = strategy
super().__init__(**kwargs)
@classmethod
def _collect_json_schema(
cls, definitions: Dict[str, Any]
) -> Dict[str, Any]:
# this is the method you want to override in hologram if you want
# to do clever things about the json schema and have classes that
# contain instances of your JsonSchemaMixin respect the change.
schema = super()._collect_json_schema(definitions)
# Instead of just the strategy we'd calculate normally, say
# "this strategy except none of our specialization strategies".
strategies = [schema['properties']['strategy']]
for specialization in (TimestampSnapshotConfig, CheckSnapshotConfig):
strategies.append(
{'not': specialization.json_schema()['properties']['strategy']}
)
schema['properties']['strategy'] = {
'allOf': strategies
}
return schema
@dataclass(init=False)
class TimestampSnapshotConfig(SnapshotConfig):
strategy: str = field(
init=False,
metadata=dict(
restrict=[str(SnapshotStrategy.Timestamp)],
init_required=True,
),
)
updated_at: str = field(init=False, metadata=dict(init_required=True))
def __init__(
self, strategy: str, updated_at: str, **kwargs
) -> None:
self.strategy = strategy
self.updated_at = updated_at
super().__init__(**kwargs)
updated_at: str
@dataclass(init=False)
class CheckSnapshotConfig(SnapshotConfig):
strategy: str = field(
init=False,
metadata=dict(
restrict=[str(SnapshotStrategy.Check)],
init_required=True,
),
)
# TODO: is there a way to get this to accept tuples of strings? Adding
# `Tuple[str, ...]` to the list of types results in this:
# ['email'] is valid under each of {'type': 'array', 'items':
# {'type': 'string'}}, {'type': 'array', 'items': {'type': 'string'}}
# but without it, parsing gets upset about values like `('email',)`
# maybe hologram itself should support this behavior? It's not like tuples
# are meaningful in json
check_cols: Union[All, List[str]] = field(
init=False,
metadata=dict(init_required=True),
)
check_cols: Union[All, List[str]]
def __init__(
self, strategy: str, check_cols: Union[All, List[str]],
**kwargs
) -> None:
self.strategy = strategy
self.check_cols = check_cols
super().__init__(**kwargs)
class CheckSnapshotConfig(SnapshotConfig):
pass
RESOURCE_TYPES: Dict[NodeType, Type[BaseConfig]] = {

View File

@@ -31,6 +31,8 @@ from dbt.logger import GLOBAL_LOGGER as logger # noqa
from dbt import flags
from dbt.node_types import NodeType
from dbt.contracts.jsonschema import dbtClassMixin
from .model_config import (
NodeConfig,
@@ -38,20 +40,15 @@ from .model_config import (
TestConfig,
SourceConfig,
EmptySnapshotConfig,
SnapshotVariants,
)
# import these 3 so the SnapshotVariants forward ref works.
from .model_config import ( # noqa
TimestampSnapshotConfig,
CheckSnapshotConfig,
GenericSnapshotConfig,
SnapshotConfig,
)
# TODO : Figure out AdditionalPropertiesMixin and ExtensibleJsonSchemaMixin
@dataclass
class ColumnInfo(
AdditionalPropertiesMixin,
ExtensibleJsonSchemaMixin,
dbtClassMixin,
#AdditionalPropertiesMixin,
#ExtensibleJsonSchemaMixin,
Replaceable
):
name: str
@@ -64,7 +61,7 @@ class ColumnInfo(
@dataclass
class HasFqn(JsonSchemaMixin, Replaceable):
class HasFqn(dbtClassMixin, Replaceable):
fqn: List[str]
def same_fqn(self, other: 'HasFqn') -> bool:
@@ -72,12 +69,12 @@ class HasFqn(JsonSchemaMixin, Replaceable):
@dataclass
class HasUniqueID(JsonSchemaMixin, Replaceable):
class HasUniqueID(dbtClassMixin, Replaceable):
unique_id: str
@dataclass
class MacroDependsOn(JsonSchemaMixin, Replaceable):
class MacroDependsOn(dbtClassMixin, Replaceable):
macros: List[str] = field(default_factory=list)
# 'in' on lists is O(n) so this is O(n^2) for # of macros
@@ -96,12 +93,12 @@ class DependsOn(MacroDependsOn):
@dataclass
class HasRelationMetadata(JsonSchemaMixin, Replaceable):
class HasRelationMetadata(dbtClassMixin, Replaceable):
database: Optional[str]
schema: str
class ParsedNodeMixins(JsonSchemaMixin):
class ParsedNodeMixins(dbtClassMixin):
resource_type: NodeType
depends_on: DependsOn
config: NodeConfig
@@ -132,8 +129,8 @@ class ParsedNodeMixins(JsonSchemaMixin):
self.meta = patch.meta
self.docs = patch.docs
if flags.STRICT_MODE:
assert isinstance(self, JsonSchemaMixin)
self.to_dict(validate=True, omit_none=False)
assert isinstance(self, dbtClassMixin)
self.serialize(validate=True, omit_none=False)
def get_materialization(self):
return self.config.materialized
@@ -335,14 +332,14 @@ class ParsedSeedNode(ParsedNode):
@dataclass
class TestMetadata(JsonSchemaMixin, Replaceable):
class TestMetadata(dbtClassMixin, Replaceable):
namespace: Optional[str]
name: str
kwargs: Dict[str, Any]
@dataclass
class HasTestMetadata(JsonSchemaMixin):
class HasTestMetadata(dbtClassMixin):
test_metadata: TestMetadata
@@ -394,7 +391,7 @@ class IntermediateSnapshotNode(ParsedNode):
@dataclass
class ParsedSnapshotNode(ParsedNode):
resource_type: NodeType = field(metadata={'restrict': [NodeType.Snapshot]})
config: SnapshotVariants
config: SnapshotConfig
@dataclass
@@ -443,8 +440,8 @@ class ParsedMacro(UnparsedBaseNode, HasUniqueID):
self.docs = patch.docs
self.arguments = patch.arguments
if flags.STRICT_MODE:
assert isinstance(self, JsonSchemaMixin)
self.to_dict(validate=True, omit_none=False)
assert isinstance(self, dbtClassMixin)
self.serialize(validate=True, omit_none=False)
def same_contents(self, other: Optional['ParsedMacro']) -> bool:
if other is None:

View File

@@ -16,9 +16,11 @@ from datetime import timedelta
from pathlib import Path
from typing import Optional, List, Union, Dict, Any, Sequence
from dbt.contracts.jsonschema import dbtClassMixin
@dataclass
class UnparsedBaseNode(JsonSchemaMixin, Replaceable):
class UnparsedBaseNode(dbtClassMixin, Replaceable):
package_name: str
root_path: str
path: str
@@ -66,18 +68,20 @@ class UnparsedRunHook(UnparsedNode):
@dataclass
class Docs(JsonSchemaMixin, Replaceable):
class Docs(dbtClassMixin, Replaceable):
show: bool = True
# TODO : This should have AdditionalPropertiesMixin and ExtensibleJsonSchemaMixin
@dataclass
class HasDocs(AdditionalPropertiesMixin, ExtensibleJsonSchemaMixin,
Replaceable):
class HasDocs(dbtClassMixin, Replaceable):
name: str
description: str = ''
meta: Dict[str, Any] = field(default_factory=dict)
data_type: Optional[str] = None
docs: Docs = field(default_factory=Docs)
# TODO : How do we handle these additional fields with mashumaro?
_extra: Dict[str, Any] = field(default_factory=dict)
@@ -100,7 +104,7 @@ class UnparsedColumn(HasTests):
@dataclass
class HasColumnDocs(JsonSchemaMixin, Replaceable):
class HasColumnDocs(dbtClassMixin, Replaceable):
columns: Sequence[HasDocs] = field(default_factory=list)
@@ -110,7 +114,7 @@ class HasColumnTests(HasColumnDocs):
@dataclass
class HasYamlMetadata(JsonSchemaMixin):
class HasYamlMetadata(dbtClassMixin):
original_file_path: str
yaml_key: str
package_name: str
@@ -127,7 +131,7 @@ class UnparsedNodeUpdate(HasColumnTests, HasTests, HasYamlMetadata):
@dataclass
class MacroArgument(JsonSchemaMixin):
class MacroArgument(dbtClassMixin):
name: str
type: Optional[str] = None
description: str = ''
@@ -148,7 +152,7 @@ class TimePeriod(StrEnum):
@dataclass
class Time(JsonSchemaMixin, Replaceable):
class Time(dbtClassMixin, Replaceable):
count: int
period: TimePeriod
@@ -159,7 +163,7 @@ class Time(JsonSchemaMixin, Replaceable):
@dataclass
class FreshnessThreshold(JsonSchemaMixin, Mergeable):
class FreshnessThreshold(dbtClassMixin, Mergeable):
warn_after: Optional[Time] = None
error_after: Optional[Time] = None
filter: Optional[str] = None
@@ -212,7 +216,7 @@ class ExternalTable(AdditionalPropertiesAllowed, Mergeable):
@dataclass
class Quoting(JsonSchemaMixin, Mergeable):
class Quoting(dbtClassMixin, Mergeable):
database: Optional[bool] = None
schema: Optional[bool] = None
identifier: Optional[bool] = None
@@ -231,14 +235,15 @@ class UnparsedSourceTableDefinition(HasColumnTests, HasTests):
tags: List[str] = field(default_factory=list)
def to_dict(self, omit_none=True, validate=False):
result = super().to_dict(omit_none=omit_none, validate=validate)
result = super().serialize(omit_none=omit_none, validate=validate)
if omit_none and self.freshness is None:
result['freshness'] = None
return result
@dataclass
class UnparsedSourceDefinition(JsonSchemaMixin, Replaceable):
class UnparsedSourceDefinition(dbtClassMixin, Replaceable):
name: str
description: str = ''
meta: Dict[str, Any] = field(default_factory=dict)
@@ -258,14 +263,14 @@ class UnparsedSourceDefinition(JsonSchemaMixin, Replaceable):
return 'sources'
def to_dict(self, omit_none=True, validate=False):
result = super().to_dict(omit_none=omit_none, validate=validate)
result = super().serialize(omit_none=omit_none, validate=validate)
if omit_none and self.freshness is None:
result['freshness'] = None
return result
@dataclass
class SourceTablePatch(JsonSchemaMixin):
class SourceTablePatch(dbtClassMixin):
name: str
description: Optional[str] = None
meta: Optional[Dict[str, Any]] = None
@@ -283,7 +288,7 @@ class SourceTablePatch(JsonSchemaMixin):
columns: Optional[Sequence[UnparsedColumn]] = None
def to_patch_dict(self) -> Dict[str, Any]:
dct = self.to_dict(omit_none=True)
dct = self.serialize(omit_none=True)
remove_keys = ('name')
for key in remove_keys:
if key in dct:
@@ -296,7 +301,7 @@ class SourceTablePatch(JsonSchemaMixin):
@dataclass
class SourcePatch(JsonSchemaMixin, Replaceable):
class SourcePatch(dbtClassMixin, Replaceable):
name: str = field(
metadata=dict(description='The name of the source to override'),
)
@@ -320,7 +325,7 @@ class SourcePatch(JsonSchemaMixin, Replaceable):
tags: Optional[List[str]] = None
def to_patch_dict(self) -> Dict[str, Any]:
dct = self.to_dict(omit_none=True)
dct = self.serialize(omit_none=True)
remove_keys = ('name', 'overrides', 'tables', 'path')
for key in remove_keys:
if key in dct:
@@ -340,7 +345,7 @@ class SourcePatch(JsonSchemaMixin, Replaceable):
@dataclass
class UnparsedDocumentation(JsonSchemaMixin, Replaceable):
class UnparsedDocumentation(dbtClassMixin, Replaceable):
package_name: str
root_path: str
path: str
@@ -400,7 +405,7 @@ class MaturityType(StrEnum):
@dataclass
class ExposureOwner(JsonSchemaMixin, Replaceable):
class ExposureOwner(dbtClassMixin, Replaceable):
email: str
name: Optional[str] = None

View File

@@ -0,0 +1,66 @@
from dataclasses import dataclass, fields, Field
from typing import (
Optional, TypeVar, Generic, Dict, get_type_hints, List, Tuple
)
from mashumaro import DataClassDictMixin
from mashumaro.types import SerializableType
import re
# hologram is still a dependency of this experiment, so borrow its
# ValidationError until a local exception type replaces it
from hologram import ValidationError
"""
This is a throwaway shim to match the JsonSchemaMixin interface
that downstream consumers (i.e. PostgresRelation) are expecting.
I imagine that we would try to remove code that depends on this type
reflection if we pursue an approach like the one shown here.
"""
class dbtClassMixin(DataClassDictMixin):
@classmethod
def field_mapping(cls) -> Dict[str, str]:
"""Defines the mapping of python field names to JSON field names.
The main use-case is to allow JSON field names which are Python keywords
"""
return {}
def serialize(self, omit_none=False, validate=False, with_aliases: bool = False):
dct = self.to_dict()
if with_aliases:
# TODO : Mutating these dicts is a TERRIBLE idea - remove this
# On output, emit the aliased spelling alongside the canonical key,
# mirroring the old Credentials.to_dict behavior
for aliased_name, canonical_name in self._ALIASES.items():
if canonical_name in dct:
dct[aliased_name] = dct[canonical_name]
return dct
@classmethod
def deserialize(cls, data, validate=False, with_aliases=False):
if with_aliases:
# TODO : Mutating these dicts is a TERRIBLE idea - remove this
for aliased_name, canonical_name in cls._ALIASES.items():
if aliased_name in data:
data[canonical_name] = data.pop(aliased_name)
# TODO .... implement these?
return cls.from_dict(data)
class ValidatedStringMixin(str, SerializableType):
ValidationRegex = None
@classmethod
def _deserialize(cls, value: str) -> 'ValidatedStringMixin':
cls.validate(value)
return cls(value)  # preserve the subclass type (Identifier, Name, UniqueId, ...)
def _serialize(self) -> str:
return str(self)
@classmethod
def validate(cls, value) -> str:
res = re.match(cls.ValidationRegex, value)
if res is None:
raise ValidationError(f"Invalid value: {value}") # TODO
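A hedged usage sketch of the shim above, continuing in the same module; `ProfileStub` and its fields are hypothetical, but the `_ALIASES` handling mirrors how `Project` and `PostgresCredentials` use it elsewhere in this diff:

from dataclasses import dataclass

@dataclass
class ProfileStub(dbtClassMixin):
    config_version: int = 2
    target_path: str = 'target'

    _ALIASES = {
        'config-version': 'config_version',
        'target-path': 'target_path',
    }

# deserialize() translates hyphenated keys to python field names...
p = ProfileStub.deserialize(
    {'config-version': 2, 'target-path': 'dist'}, with_aliases=True
)
assert p.target_path == 'dist'
# ...and serialize(with_aliases=True) emits the hyphenated spellings too.
assert p.serialize(with_aliases=True)['target-path'] == 'dist'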

View File

@@ -12,12 +12,16 @@ from hologram.helpers import HyphenatedJsonSchemaMixin, register_pattern, \
from dataclasses import dataclass, field
from typing import Optional, List, Dict, Union, Any, NewType
from dbt.contracts.jsonschema import dbtClassMixin, ValidatedStringMixin
from mashumaro.types import SerializableType
PIN_PACKAGE_URL = 'https://docs.getdbt.com/docs/package-management#section-specifying-package-versions' # noqa
DEFAULT_SEND_ANONYMOUS_USAGE_STATS = True
Name = NewType('Name', str)
register_pattern(Name, r'^[^\d\W]\w*$')
class Name(ValidatedStringMixin):
ValidationRegex = r'^[^\d\W]\w*$'
# this does not support the full semver (does not allow a trailing -fooXYZ) and
# is not restrictive enough for full semver, (allows '1.0'). But it's like
@@ -28,17 +32,26 @@ register_pattern(
r'^(?:0|[1-9]\d*)\.(?:0|[1-9]\d*)(\.(?:0|[1-9]\d*))?$',
)
class SemverString(str, SerializableType):
def _serialize(self) -> str:
return self
@dataclass
class Quoting(JsonSchemaMixin, Mergeable):
identifier: Optional[bool]
schema: Optional[bool]
database: Optional[bool]
project: Optional[bool]
@classmethod
def _deserialize(cls, value: str) -> 'SemverString':
return SemverString(value)
@dataclass
class Package(Replaceable, HyphenatedJsonSchemaMixin):
class Quoting(dbtClassMixin, Mergeable):
identifier: Optional[bool] = None
schema: Optional[bool] = None
database: Optional[bool] = None
project: Optional[bool] = None
# TODO .... hyphenation.... what.... why
@dataclass
class Package(dbtClassMixin, Replaceable):
pass
@@ -80,7 +93,7 @@ PackageSpec = Union[LocalPackage, GitPackage, RegistryPackage]
@dataclass
class PackageConfig(JsonSchemaMixin, Replaceable):
class PackageConfig(dbtClassMixin, Replaceable):
packages: List[PackageSpec]
@@ -96,13 +109,13 @@ class ProjectPackageMetadata:
@dataclass
class Downloads(ExtensibleJsonSchemaMixin, Replaceable):
class Downloads(dbtClassMixin, Replaceable):
tarball: str
@dataclass
class RegistryPackageMetadata(
ExtensibleJsonSchemaMixin,
dbtClassMixin,
ProjectPackageMetadata,
):
downloads: Downloads
@@ -153,7 +166,7 @@ BANNED_PROJECT_NAMES = {
@dataclass
class Project(HyphenatedJsonSchemaMixin, Replaceable):
class Project(dbtClassMixin, Replaceable):
name: Name
version: Union[SemverString, float]
config_version: int
@@ -189,9 +202,30 @@ class Project(HyphenatedJsonSchemaMixin, Replaceable):
packages: List[PackageSpec] = field(default_factory=list)
query_comment: Optional[Union[QueryComment, NoValue, str]] = NoValue()
_ALIASES = {
'config-version': 'config_version',
'source-paths': 'source_paths',
'macro-paths': 'macro_paths',
'data-paths': 'data_paths',
'test-paths': 'test_paths',
'analysis-paths': 'analysis_paths',
'docs-paths': 'docs_paths',
'asset-paths': 'asset_paths',
'target-path': 'target_path',
'snapshot-paths': 'snapshot_paths',
'clean-targets': 'clean_targets',
'log-path': 'log_path',
'modules-path': 'modules_path',
'on-run-start': 'on_run_start',
'on-run-end': 'on_run_end',
'require-dbt-version': 'require_dbt_version',
'project-root': 'project_root',
}
@classmethod
def from_dict(cls, data, validate=True) -> 'Project':
result = super().from_dict(data, validate=validate)
def deserialize(cls, data) -> 'Project':
# TODO : Deserialize this with aliases - right now, only implemented for serializer?
result = super().deserialize(data, with_aliases=True)
if result.name in BANNED_PROJECT_NAMES:
raise ValidationError(
f'Invalid project name: {result.name} is a reserved word'
@@ -200,8 +234,9 @@ class Project(HyphenatedJsonSchemaMixin, Replaceable):
return result
# TODO : Make extensible?
@dataclass
class UserConfig(ExtensibleJsonSchemaMixin, Replaceable, UserConfigContract):
class UserConfig(dbtClassMixin, Replaceable, UserConfigContract):
send_anonymous_usage_stats: bool = DEFAULT_SEND_ANONYMOUS_USAGE_STATS
use_colors: Optional[bool] = None
partial_parse: Optional[bool] = None
@@ -221,7 +256,7 @@ class UserConfig(ExtensibleJsonSchemaMixin, Replaceable, UserConfigContract):
@dataclass
class ProfileConfig(HyphenatedJsonSchemaMixin, Replaceable):
class ProfileConfig(dbtClassMixin, Replaceable):
profile_name: str = field(metadata={'preserve_underscore': True})
target_name: str = field(metadata={'preserve_underscore': True})
config: UserConfig
@@ -234,8 +269,8 @@ class ProfileConfig(HyphenatedJsonSchemaMixin, Replaceable):
class ConfiguredQuoting(Quoting, Replaceable):
identifier: bool
schema: bool
database: Optional[bool]
project: Optional[bool]
database: Optional[bool] = None
project: Optional[bool] = None
@dataclass
@@ -248,5 +283,5 @@ class Configuration(Project, ProfileConfig):
@dataclass
class ProjectList(JsonSchemaMixin):
class ProjectList(dbtClassMixin):
projects: Dict[str, Project]

View File

@@ -1,11 +1,10 @@
from collections.abc import Mapping
from dataclasses import dataclass, fields
from dataclasses import dataclass, fields, Field
from typing import (
Optional, TypeVar, Generic, Dict,
Optional, TypeVar, Generic, Dict, get_type_hints, List, Tuple
)
from typing_extensions import Protocol
from hologram import JsonSchemaMixin
from hologram.helpers import StrEnum
from dbt import deprecations
@@ -13,6 +12,8 @@ from dbt.contracts.util import Replaceable
from dbt.exceptions import CompilationException
from dbt.utils import deep_merge
from dbt.contracts.jsonschema import dbtClassMixin
class RelationType(StrEnum):
Table = 'table'
@@ -32,7 +33,7 @@ class HasQuoting(Protocol):
quoting: Dict[str, bool]
class FakeAPIObject(JsonSchemaMixin, Replaceable, Mapping):
class FakeAPIObject(dbtClassMixin, Replaceable, Mapping):
# override the mapping truthiness, len is always >1
def __bool__(self):
return True
@@ -58,16 +59,22 @@ class FakeAPIObject(JsonSchemaMixin, Replaceable, Mapping):
return self.from_dict(value)
T = TypeVar('T')
"""
mashumaro does not support generic types, so I just duplicated
the methods of the generic _ComponentObject class that was here
previously (once for Policy, once for Path) to get things working.
We'd probably want to do this differently in the future.
TODO
"""
@dataclass
class _ComponentObject(FakeAPIObject, Generic[T]):
database: T
schema: T
identifier: T
class Policy(FakeAPIObject):
database: bool = True
schema: bool = True
identifier: bool = True
def get_part(self, key: ComponentName) -> T:
def get_part(self, key: ComponentName) -> bool:
if key == ComponentName.Database:
return self.database
elif key == ComponentName.Schema:
@@ -80,22 +87,15 @@ class _ComponentObject(FakeAPIObject, Generic[T]):
.format(key, list(ComponentName))
)
def replace_dict(self, dct: Dict[ComponentName, T]):
kwargs: Dict[str, T] = {}
def replace_dict(self, dct: Dict[ComponentName, bool]):
kwargs: Dict[str, bool] = {}
for k, v in dct.items():
kwargs[str(k)] = v
return self.replace(**kwargs)
@dataclass
class Policy(_ComponentObject[bool]):
database: bool = True
schema: bool = True
identifier: bool = True
@dataclass
class Path(_ComponentObject[Optional[str]]):
class Path(FakeAPIObject):
database: Optional[str]
schema: Optional[str]
identifier: Optional[str]
@@ -120,3 +120,22 @@ class Path(_ComponentObject[Optional[str]]):
if part is not None:
part = part.lower()
return part
def get_part(self, key: ComponentName) -> str:
if key == ComponentName.Database:
return self.database
elif key == ComponentName.Schema:
return self.schema
elif key == ComponentName.Identifier:
return self.identifier
else:
raise ValueError(
'Got a key of {}, expected one of {}'
.format(key, list(ComponentName))
)
def replace_dict(self, dct: Dict[ComponentName, str]):
kwargs: Dict[str, str] = {}
for k, v in dct.items():
kwargs[str(k)] = v
return self.replace(**kwargs)

View File

@@ -21,6 +21,9 @@ from dbt.utils import lowercase
from hologram.helpers import StrEnum
from hologram import JsonSchemaMixin
from dbt.contracts.jsonschema import dbtClassMixin
from mashumaro.types import SerializableType
import agate
from dataclasses import dataclass, field
@@ -31,7 +34,7 @@ from dbt.clients.system import write_json
@dataclass
class TimingInfo(JsonSchemaMixin):
class TimingInfo(dbtClassMixin):
name: str
started_at: Optional[datetime] = None
completed_at: Optional[datetime] = None
@@ -88,7 +91,7 @@ class FreshnessStatus(StrEnum):
@dataclass
class BaseResult(JsonSchemaMixin):
class BaseResult(dbtClassMixin):
status: Union[RunStatus, TestStatus, FreshnessStatus]
timing: List[TimingInfo]
thread_id: str
@@ -109,15 +112,14 @@ class PartialNodeResult(NodeResult):
def skipped(self):
return False
# TODO : Does this work? no-op agate table serialization in RunModelResult
class SerializableAgateTable(agate.Table, SerializableType):
def _serialize(self) -> None:
return None
@classmethod
def _deserialize(cls, value) -> None:
# agate tables are display-only in results; never deserialize one
return None
@dataclass
class RunModelResult(NodeResult):
agate_table: Optional[agate.Table] = None
def to_dict(self, *args, **kwargs):
dct = super().to_dict(*args, **kwargs)
dct.pop('agate_table', None)
return dct
agate_table: Optional[SerializableAgateTable] = None
@property
def skipped(self):
@@ -125,7 +127,7 @@ class RunModelResult(NodeResult):
@dataclass
class ExecutionResult(JsonSchemaMixin):
class ExecutionResult(dbtClassMixin):
results: Sequence[BaseResult]
elapsed_time: float
@@ -210,7 +212,8 @@ class RunResultsArtifact(ExecutionResult, ArtifactMixin):
)
def write(self, path: str, omit_none=False):
write_json(path, self.to_dict(omit_none=omit_none))
# TODO: Implement omit_none
write_json(path, self.to_dict())
@dataclass
@@ -268,14 +271,14 @@ class FreshnessErrorEnum(StrEnum):
@dataclass
class SourceFreshnessRuntimeError(JsonSchemaMixin):
class SourceFreshnessRuntimeError(dbtClassMixin):
unique_id: str
error: Optional[Union[str, int]]
status: FreshnessErrorEnum
@dataclass
class SourceFreshnessOutput(JsonSchemaMixin):
class SourceFreshnessOutput(dbtClassMixin):
unique_id: str
max_loaded_at: datetime
snapshotted_at: datetime
@@ -383,7 +386,7 @@ CatalogKey = NamedTuple(
@dataclass
class StatsItem(JsonSchemaMixin):
class StatsItem(dbtClassMixin):
id: str
label: str
value: Primitive
@@ -395,7 +398,7 @@ StatsDict = Dict[str, StatsItem]
@dataclass
class ColumnMetadata(JsonSchemaMixin):
class ColumnMetadata(dbtClassMixin):
type: str
comment: Optional[str]
index: int
@@ -406,7 +409,7 @@ ColumnMap = Dict[str, ColumnMetadata]
@dataclass
class TableMetadata(JsonSchemaMixin):
class TableMetadata(dbtClassMixin):
type: str
database: Optional[str]
schema: str
@@ -416,7 +419,7 @@ class TableMetadata(JsonSchemaMixin):
@dataclass
class CatalogTable(JsonSchemaMixin, Replaceable):
class CatalogTable(dbtClassMixin, Replaceable):
metadata: TableMetadata
columns: ColumnMap
stats: StatsDict
@@ -439,7 +442,7 @@ class CatalogMetadata(BaseArtifactMetadata):
@dataclass
class CatalogResults(JsonSchemaMixin):
class CatalogResults(dbtClassMixin):
nodes: Dict[str, CatalogTable]
sources: Dict[str, CatalogTable]
errors: Optional[List[str]]

View File

@@ -26,6 +26,8 @@ from dbt.exceptions import InternalException
from dbt.logger import LogMessage
from dbt.utils import restrict_to
from dbt.contracts.jsonschema import dbtClassMixin
TaskTags = Optional[Dict[str, Any]]
TaskID = uuid.UUID
@@ -34,7 +36,7 @@ TaskID = uuid.UUID
@dataclass
class RPCParameters(JsonSchemaMixin):
class RPCParameters(dbtClassMixin):
timeout: Optional[float]
task_tags: TaskTags
@@ -132,7 +134,7 @@ class StatusParameters(RPCParameters):
@dataclass
class GCSettings(JsonSchemaMixin):
class GCSettings(dbtClassMixin):
# start evicting the longest-ago-ended tasks here
maxsize: int
# start evicting all tasks before now - auto_reap_age when we have this
@@ -254,7 +256,7 @@ class RemoteExecutionResult(ExecutionResult, RemoteResult):
@dataclass
class ResultTable(JsonSchemaMixin):
class ResultTable(dbtClassMixin):
column_names: List[str]
rows: List[Any]
@@ -411,7 +413,7 @@ class TaskHandlerState(StrEnum):
@dataclass
class TaskTiming(JsonSchemaMixin):
class TaskTiming(dbtClassMixin):
state: TaskHandlerState
start: Optional[datetime]
end: Optional[datetime]

View File

@@ -3,16 +3,18 @@ from hologram import JsonSchemaMixin
from typing import List, Dict, Any, Union
from dbt.contracts.jsonschema import dbtClassMixin
@dataclass
class SelectorDefinition(JsonSchemaMixin):
class SelectorDefinition(dbtClassMixin):
name: str
definition: Union[str, Dict[str, Any]]
description: str = ''
@dataclass
class SelectorFile(JsonSchemaMixin):
class SelectorFile(dbtClassMixin):
selectors: List[SelectorDefinition]
version: int = 2

View File

@@ -14,6 +14,7 @@ from dbt.exceptions import (
from dbt.version import __version__
from dbt.tracking import get_invocation_id
from hologram import JsonSchemaMixin
from dbt.contracts.jsonschema import dbtClassMixin
MacroKey = Tuple[str, str]
SourceKey = Tuple[str, str]
@@ -58,7 +59,8 @@ class Mergeable(Replaceable):
class Writable:
def write(self, path: str, omit_none: bool = False):
write_json(path, self.to_dict(omit_none=omit_none)) # type: ignore
# TODO : Do this faster?
write_json(path, self.serialize(omit_none=omit_none)) # type: ignore
class AdditionalPropertiesMixin:
@@ -71,20 +73,20 @@ class AdditionalPropertiesMixin:
@classmethod
def from_dict(cls, data, validate=True):
self = super().from_dict(data=data, validate=validate)
keys = self.to_dict(validate=False, omit_none=False)
self = super().deserialize(data=data, validate=validate)
keys = self.serialize(validate=False, omit_none=False)
for key, value in data.items():
if key not in keys:
self.extra[key] = value
return self
def to_dict(self, omit_none=True, validate=False):
data = super().to_dict(omit_none=omit_none, validate=validate)
data = super().serialize(omit_none=omit_none, validate=validate)
data.update(self.extra)
return data
def replace(self, **kwargs):
dct = self.to_dict(omit_none=False, validate=False)
dct = self.serialize(omit_none=False, validate=False)
dct.update(kwargs)
return self.from_dict(dct)
@@ -135,7 +137,7 @@ def get_metadata_env() -> Dict[str, str]:
@dataclasses.dataclass
class BaseArtifactMetadata(JsonSchemaMixin):
class BaseArtifactMetadata(dbtClassMixin):
dbt_schema_version: str
dbt_version: str = __version__
generated_at: datetime = dataclasses.field(
@@ -158,7 +160,7 @@ def schema_version(name: str, version: int):
@dataclasses.dataclass
class VersionedSchema(JsonSchemaMixin):
class VersionedSchema(dbtClassMixin):
dbt_schema_version: ClassVar[SchemaVersion]
@classmethod
@@ -194,4 +196,4 @@ class ArtifactMixin(VersionedSchema, Writable, Readable):
if found != expected:
raise IncompatibleSchemaException(expected, found)
return super().from_dict(data=data, validate=validate)
return super().deserialize(data=data, validate=validate)

View File

@@ -4,8 +4,12 @@ from typing import (
import networkx as nx # type: ignore
from dbt.exceptions import InternalException
from dbt.contracts.jsonschema import ValidatedStringMixin
from mashumaro.types import SerializableType
UniqueId = NewType('UniqueId', str)
class UniqueId(ValidatedStringMixin):
ValidationRegex = '.+'
class Graph:

View File

@@ -2,16 +2,33 @@
from dataclasses import dataclass
from datetime import timedelta
from pathlib import Path
from typing import NewType, Tuple, AbstractSet
from typing import NewType, Tuple, AbstractSet, Union
from hologram import (
FieldEncoder, JsonSchemaMixin, JsonDict, ValidationError
)
from hologram.helpers import StrEnum
from dbt.contracts.jsonschema import dbtClassMixin
from mashumaro.types import SerializableType
Port = NewType('Port', int)
class Port(int, SerializableType):
@classmethod
def _deserialize(cls, value: Union[int, str]) -> 'Port':
try:
value = int(value)
except ValueError:
raise ValidationError(f'Cannot cast {value} to a port number')
return Port(value)
def _serialize(self) -> int:
# TODO : Validate here?
return self
# TODO
class PortEncoder(FieldEncoder):
@property
def json_schema(self):
@@ -66,16 +83,17 @@ class NVEnum(StrEnum):
@dataclass
class NoValue(JsonSchemaMixin):
class NoValue(dbtClassMixin):
"""Sometimes, you want a way to say none that isn't None"""
novalue: NVEnum = NVEnum.novalue
JsonSchemaMixin.register_field_encoders({
Port: PortEncoder(),
timedelta: TimeDeltaFieldEncoder(),
Path: PathEncoder(),
})
# TODO : None of this is right lol
# JsonSchemaMixin.register_field_encoders({
# Port: PortEncoder(),
# timedelta: TimeDeltaFieldEncoder(),
# Path: PathEncoder(),
# })
FQNPath = Tuple[str, ...]
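The `Port` class above replaces hologram's field-encoder registration with mashumaro's `SerializableType` hooks; a round-trip sketch (the `Creds` dataclass is illustrative, not from this diff):

from dataclasses import dataclass
from mashumaro import DataClassDictMixin

@dataclass
class Creds(DataClassDictMixin):
    port: Port

# _deserialize coerces the string form; _serialize emits a plain int
assert Creds.from_dict({'port': '5432'}).port == 5432
assert Creds(port=Port(5432)).to_dict() == {'port': 5432}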

View File

@@ -15,6 +15,8 @@ import colorama
import logbook
from hologram import JsonSchemaMixin
from dbt.contracts.jsonschema import dbtClassMixin
# Colorama needs some help on windows because we're using logger.info
instead of print(). If the Windows env doesn't have a TERM var set,
# then we should override the logging stream to use the colorama
@@ -49,7 +51,7 @@ Extras = Dict[str, Any]
@dataclass
class LogMessage(JsonSchemaMixin):
class LogMessage(dbtClassMixin):
timestamp: datetime
message: str
channel: str
@@ -215,7 +217,7 @@ class TextOnly(logbook.Processor):
class TimingProcessor(logbook.Processor):
def __init__(self, timing_info: Optional[JsonSchemaMixin] = None):
def __init__(self, timing_info: Optional[dbtClassMixin] = None):
self.timing_info = timing_info
super().__init__()

View File

@@ -137,6 +137,7 @@ def main(args=None):
exit_code = e.code
except BaseException as e:
print(traceback.format_exc())
logger.warning("Encountered an error:")
logger.warning(str(e))

View File

@@ -13,7 +13,7 @@ class AnalysisParser(SimpleSQLParser[ParsedAnalysisNode]):
)
def parse_from_dict(self, dct, validate=True) -> ParsedAnalysisNode:
return ParsedAnalysisNode.from_dict(dct, validate=validate)
return ParsedAnalysisNode.deserialize(dct, validate=validate)
@property
def resource_type(self) -> NodeType:

View File

@@ -12,7 +12,7 @@ class DataTestParser(SimpleSQLParser[ParsedDataTestNode]):
)
def parse_from_dict(self, dct, validate=True) -> ParsedDataTestNode:
return ParsedDataTestNode.from_dict(dct, validate=validate)
return ParsedDataTestNode.deserialize(dct, validate=validate)
@property
def resource_type(self) -> NodeType:

View File

@@ -79,7 +79,7 @@ class HookParser(SimpleParser[HookBlock, ParsedHookNode]):
return [path]
def parse_from_dict(self, dct, validate=True) -> ParsedHookNode:
return ParsedHookNode.from_dict(dct, validate=validate)
return ParsedHookNode.deserialize(dct, validate=validate)
@classmethod
def get_compiled_path(cls, block: HookBlock):

View File

@@ -53,20 +53,22 @@ from dbt.version import __version__
from hologram import JsonSchemaMixin
from dbt.contracts.jsonschema import dbtClassMixin
PARTIAL_PARSE_FILE_NAME = 'partial_parse.pickle'
PARSING_STATE = DbtProcessState('parsing')
DEFAULT_PARTIAL_PARSE = False
@dataclass
class ParserInfo(JsonSchemaMixin):
class ParserInfo(dbtClassMixin):
parser: str
elapsed: float
path_count: int = 0
@dataclass
class ProjectLoaderInfo(JsonSchemaMixin):
class ProjectLoaderInfo(dbtClassMixin):
project_name: str
elapsed: float
parsers: List[ParserInfo]
@@ -74,7 +76,7 @@ class ProjectLoaderInfo(JsonSchemaMixin):
@dataclass
class ManifestLoaderInfo(JsonSchemaMixin, Writable):
class ManifestLoaderInfo(dbtClassMixin, Writable):
path_count: int = 0
is_partial_parse_enabled: Optional[bool] = None
parse_project_elapsed: Optional[float] = None

View File

@@ -11,7 +11,7 @@ class ModelParser(SimpleSQLParser[ParsedModelNode]):
)
def parse_from_dict(self, dct, validate=True) -> ParsedModelNode:
return ParsedModelNode.from_dict(dct, validate=validate)
return ParsedModelNode.deserialize(dct, validate=validate)
@property
def resource_type(self) -> NodeType:

View File

@@ -3,6 +3,8 @@ from typing import TypeVar, MutableMapping, Mapping, Union, List
from hologram import JsonSchemaMixin
from dbt.contracts.jsonschema import dbtClassMixin
from dbt.contracts.files import RemoteFile, FileHash, SourceFile
from dbt.contracts.graph.compiled import CompileResultNode
from dbt.contracts.graph.parsed import (
@@ -62,7 +64,7 @@ def dict_field():
@dataclass
class ParseResult(JsonSchemaMixin, Writable, Replaceable):
class ParseResult(dbtClassMixin, Writable, Replaceable):
vars_hash: FileHash
profile_hash: FileHash
project_hashes: MutableMapping[str, FileHash]

View File

@@ -26,7 +26,7 @@ class RPCCallParser(SimpleSQLParser[ParsedRPCNode]):
return []
def parse_from_dict(self, dct, validate=True) -> ParsedRPCNode:
return ParsedRPCNode.from_dict(dct, validate=validate)
return ParsedRPCNode.deserialize(dct, validate=validate)
@property
def resource_type(self) -> NodeType:

View File

@@ -7,6 +7,7 @@ from typing import (
)
from hologram import ValidationError, JsonSchemaMixin
from dbt.contracts.jsonschema import dbtClassMixin
from dbt.adapters.factory import get_adapter
from dbt.clients.jinja import get_rendered, add_rendered_test_kwargs
@@ -119,7 +120,8 @@ class ParserRef:
meta=meta,
tags=tags,
quote=quote,
_extra=column.extra
# TODO
#_extra=column.extra
)
@classmethod
@@ -201,7 +203,7 @@ class SchemaParser(SimpleParser[SchemaTestBlock, ParsedSchemaTestNode]):
)
def parse_from_dict(self, dct, validate=True) -> ParsedSchemaTestNode:
return ParsedSchemaTestNode.from_dict(dct, validate=validate)
return ParsedSchemaTestNode.deserialize(dct, validate=validate)
def _parse_format_version(
self, yaml: YamlBlock
@@ -654,7 +656,7 @@ class YamlDocsReader(YamlReader):
raise NotImplementedError('parse is abstract')
T = TypeVar('T', bound=JsonSchemaMixin)
T = TypeVar('T', bound=dbtClassMixin)
class SourceParser(YamlDocsReader):

View File

@@ -13,7 +13,7 @@ class SeedParser(SimpleSQLParser[ParsedSeedNode]):
)
def parse_from_dict(self, dct, validate=True) -> ParsedSeedNode:
return ParsedSeedNode.from_dict(dct, validate=validate)
return ParsedSeedNode.deserialize(dct, validate=validate)
@property
def resource_type(self) -> NodeType:

View File

@@ -26,7 +26,7 @@ class SnapshotParser(
)
def parse_from_dict(self, dct, validate=True) -> IntermediateSnapshotNode:
return IntermediateSnapshotNode.from_dict(dct, validate=validate)
return IntermediateSnapshotNode.deserialize(dct, validate=validate)
@property
def resource_type(self) -> NodeType:
@@ -66,6 +66,9 @@ class SnapshotParser(
def transform(self, node: IntermediateSnapshotNode) -> ParsedSnapshotNode:
try:
# TODO : This is not going to work at all
# TODO : We need to accept the config and turn that into a concrete
# Snapshot node type
parsed_node = ParsedSnapshotNode.from_dict(node.to_dict())
self.set_snapshot_attributes(parsed_node)
return parsed_node

View File

@@ -15,6 +15,9 @@ from dbt.contracts.rpc import (
from dbt.exceptions import InternalException
from dbt.utils import restrict_to
from dbt.contracts.jsonschema import dbtClassMixin
from mashumaro.types import SerializableType
class QueueMessageType(StrEnum):
Error = 'error'
@@ -26,16 +29,41 @@ class QueueMessageType(StrEnum):
@dataclass
class QueueMessage(JsonSchemaMixin):
class QueueMessage(dbtClassMixin):
message_type: QueueMessageType
class SerializableLogRecord(logbook.LogRecord, SerializableType):
def _serialize(self):
# TODO
import ipdb; ipdb.set_trace()
pass
@classmethod
def _deserialize(cls, value):
# TODO
import ipdb; ipdb.set_trace()
pass
class SerializableJSONRPCError(JSONRPCError, SerializableType):
def _serialize(self):
# TODO
import ipdb; ipdb.set_trace()
pass
@classmethod
def _deserialize(cls, value):
# TODO
import ipdb; ipdb.set_trace()
pass
@dataclass
class QueueLogMessage(QueueMessage):
message_type: QueueMessageType = field(
metadata=restrict_to(QueueMessageType.Log)
)
record: logbook.LogRecord
record: SerializableLogRecord
@classmethod
def from_record(cls, record: logbook.LogRecord):
@@ -50,7 +78,7 @@ class QueueErrorMessage(QueueMessage):
message_type: QueueMessageType = field(
metadata=restrict_to(QueueMessageType.Error)
)
error: JSONRPCError
error: SerializableJSONRPCError
@classmethod
def from_error(cls, error: JSONRPCError):

View File

@@ -8,6 +8,8 @@ from hologram import JsonSchemaMixin, ValidationError
from dbt.contracts.rpc import RPCParameters, RemoteResult, RemoteMethodFlags
from dbt.exceptions import NotImplementedException, InternalException
from dbt.contracts.jsonschema import dbtClassMixin
Parameters = TypeVar('Parameters', bound=RPCParameters)
Result = TypeVar('Result', bound=RemoteResult)
@@ -109,7 +111,7 @@ class RemoteBuiltinMethod(RemoteMethod[Parameters, Result]):
'the run() method on builtins should never be called'
)
def __call__(self, **kwargs: Dict[str, Any]) -> JsonSchemaMixin:
def __call__(self, **kwargs: Dict[str, Any]) -> dbtClassMixin:
try:
params = self.get_parameters().from_dict(kwargs)
except ValidationError as exc:

View File

@@ -20,6 +20,8 @@ from dbt.rpc.task_handler import RequestTaskHandler
from dbt.rpc.method import RemoteMethod
from dbt.rpc.task_manager import TaskManager
from dbt.contracts.jsonschema import dbtClassMixin
def track_rpc_request(task):
dbt.tracking.track_rpc_request({
@@ -90,11 +92,11 @@ class ResponseManager(JSONRPCResponseManager):
@classmethod
def _get_responses(cls, requests, dispatcher):
for output in super()._get_responses(requests, dispatcher):
# if it's a result, check if it's a JsonSchemaMixin and if so call
# if it's a result, check if it's a dbtClassMixin and if so call
# to_dict
if hasattr(output, 'result'):
if isinstance(output.result, JsonSchemaMixin):
output.result = output.result.to_dict(omit_none=False)
if isinstance(output.result, dbtClassMixin):
output.result = output.result.serialize(omit_none=False)
yield output
@classmethod

View File

@@ -43,6 +43,9 @@ from dbt.rpc.method import RemoteMethod
# we use this in typing only...
from queue import Queue # noqa
from dbt.contracts.jsonschema import dbtClassMixin
def sigterm_handler(signum, frame):
raise dbt.exceptions.RPCKilledException(signum)
@@ -283,7 +286,7 @@ class RequestTaskHandler(threading.Thread, TaskHandlerProtocol):
# - The actual thread that this represents, which writes its data to
# the result and logs. The atomicity of list.append() and item
# assignment means we don't need a lock.
self.result: Optional[JsonSchemaMixin] = None
self.result: Optional[dbtClassMixin] = None
self.error: Optional[RPCException] = None
self.state: TaskHandlerState = TaskHandlerState.NotStarted
self.logs: List[LogMessage] = []

View File

@@ -8,6 +8,8 @@ from hologram import JsonSchemaMixin
from hologram.helpers import StrEnum
from typing import Optional
from dbt.contracts.jsonschema import dbtClassMixin
class Matchers(StrEnum):
GREATER_THAN = '>'
@@ -18,12 +20,12 @@ class Matchers(StrEnum):
@dataclass
class VersionSpecification(JsonSchemaMixin):
major: Optional[str]
minor: Optional[str]
patch: Optional[str]
prerelease: Optional[str]
build: Optional[str]
class VersionSpecification(dbtClassMixin):
major: Optional[str] = None
minor: Optional[str] = None
patch: Optional[str] = None
prerelease: Optional[str] = None
build: Optional[str] = None
matcher: Matchers = Matchers.EXACT
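The `= None` defaults added here matter once hologram's JSON-schema layer is gone: a plain dataclass field annotated `Optional[str]` is still a required constructor argument unless it carries an explicit default. A two-line illustration:

from dataclasses import dataclass
from typing import Optional

@dataclass
class Spec:
    major: Optional[str] = None  # without '= None', Spec() raises TypeError

Spec()  # constructs cleanly now that every Optional field has a default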

View File

@@ -110,7 +110,7 @@ class ListTask(GraphRunnableTask):
for node in self._iterate_selected_nodes():
yield json.dumps({
k: v
for k, v in node.to_dict(omit_none=False).items()
for k, v in node.serialize(omit_none=False).items()
if k in self.ALLOWED_KEYS
})

View File

@@ -16,13 +16,14 @@ from typing import Optional
class PostgresCredentials(Credentials):
host: str
user: str
role: Optional[str]
port: Port
password: str # on postgres the password is mandatory
role: Optional[str] = None
search_path: Optional[str] = None
keepalives_idle: int = 0 # 0 means to use the default value
sslmode: Optional[str] = None
# TODO : Fix all instances of _ALIASES? Or include them in some base class?
_ALIASES = {
'dbname': 'database',
'pass': 'password'