Compare commits

...

15 Commits

Author SHA1 Message Date
Michelle Ark  fa5aa42247  Merge branch 'main' into add-dbt-common-requirement  2024-01-11 18:01:19 -05:00
Michelle Ark  06370a636a  point to dbt-common main  2024-01-11 11:43:52 -05:00
Michelle Ark  5bcf2d41ff  Merge branch 'main' into add-dbt-common-requirement  2024-01-11 11:43:25 -05:00
Emily Rockman  43f44a9298  remove commented out code  2024-01-10 15:28:42 -06:00
Emily Rockman  51f625578f  Fully remove legacy logger (#9353)  2024-01-10 14:43:09 -06:00
    * tried blinding cutting out all logbook dependency related bits
    * cleanup
    * changelog
    * remove import
    * remove unused import
Michelle Ark  b12ed0f6a4  Merge branch 'main' into add-dbt-common-requirement  2024-01-10 12:21:51 -05:00
Michelle Ark  39420024aa  changelog entry  2024-01-10 10:57:39 -05:00
Michelle Ark  e4812c62b4  remove tests/unit/common  2024-01-10 10:41:46 -05:00
Michelle Ark  14fe6a5966  update imports from dbt.common to dbt_common  2024-01-09 16:17:06 -05:00
Emily Rockman  8c167d0bff  some cleanup  2024-01-09 09:04:14 -06:00
Emily Rockman  e58b8cbd0c  WIP  2024-01-08 16:15:54 -06:00
Emily Rockman  08f24be6cf  Merge branch 'add-dbt-common-requirement' of https://github.com/dbt-labs/dbt-core into add-dbt-common-requirement  2024-01-05 15:26:28 -06:00
Emily Rockman  30b36bd8b7  update requirements, remove colorama  2024-01-05 15:02:25 -06:00
Michelle Ark  5ffa3080c1  remove dbt-common unit tests  2024-01-05 15:11:10 -05:00
Michelle Ark  e4ddbc8114  replace dbt/common with dbt-common  2024-01-05 14:24:07 -05:00
269 changed files with 622 additions and 9289 deletions
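
Nearly all of the churn below is one mechanical rename: the dbt.common subpackage of dbt-core becomes the standalone dbt_common package (plus deleting the now-redundant dbt/common sources and tests). For orientation only, a minimal sketch of that rewrite as a regex; the helper below is hypothetical and not part of this PR:

import re

def rewrite_imports(source: str) -> str:
    # "dbt.common" (old dbt-core subpackage) -> "dbt_common" (new standalone package).
    # Word boundaries keep unrelated names untouched.
    return re.sub(r"\bdbt\.common\b", "dbt_common", source)

print(rewrite_imports("from dbt.common.exceptions import DbtRuntimeError"))
# -> from dbt_common.exceptions import DbtRuntimeError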

View File

@@ -0,0 +1,6 @@
kind: Dependencies
body: Remove logbook dependency
time: 2024-01-09T12:05:30.176656-06:00
custom:
Author: emmyoop
PR: "9353"

View File

@@ -0,0 +1,6 @@
kind: Under the Hood
body: Add dbt-common as a dependency and remove dbt/common
time: 2024-01-10T10:57:34.054908-05:00
custom:
Author: michelleark emmyoop
Issue: "9357"

View File

@@ -22,8 +22,6 @@
 ### links.py
-### logger.py
 ### main.py
 ### node_types.py

View File

@@ -2,7 +2,7 @@ from dataclasses import dataclass
 import re
 from typing import Dict, ClassVar, Any, Optional
-from dbt.common.exceptions import DbtRuntimeError
+from dbt_common.exceptions import DbtRuntimeError
 @dataclass

View File

@@ -25,7 +25,7 @@ from typing import (
 import agate
 import dbt.adapters.exceptions
-import dbt.common.exceptions.base
+import dbt_common.exceptions.base
 from dbt.adapters.contracts.connection import (
     Connection,
     Identifier,
@@ -38,7 +38,7 @@ from dbt.adapters.base.query_headers import (
     MacroQueryStringSetter,
 )
 from dbt.adapters.events.logging import AdapterLogger
-from dbt.common.events.functions import fire_event
+from dbt_common.events.functions import fire_event
 from dbt.adapters.events.types import (
     NewConnection,
     ConnectionReused,
@@ -49,8 +49,8 @@ from dbt.adapters.events.types import (
     Rollback,
     RollbackFailed,
 )
-from dbt.common.events.contextvars import get_node_info
-from dbt.common.utils import cast_to_str
+from dbt_common.events.contextvars import get_node_info
+from dbt_common.utils import cast_to_str
 SleepTime = Union[int, float]  # As taken by time.sleep.
 AdapterHandle = Any  # Adapter connection handle objects can be any class.
@@ -99,7 +99,7 @@ class BaseConnectionManager(metaclass=abc.ABCMeta):
     def set_thread_connection(self, conn: Connection) -> None:
         key = self.get_thread_identifier()
         if key in self.thread_connections:
-            raise dbt.common.exceptions.DbtInternalError(
+            raise dbt_common.exceptions.DbtInternalError(
                 "In set_thread_connection, existing connection exists for {}"
             )
         self.thread_connections[key] = conn
@@ -139,7 +139,7 @@ class BaseConnectionManager(metaclass=abc.ABCMeta):
         :return: A context manager that handles exceptions raised by the
             underlying database.
         """
-        raise dbt.common.exceptions.base.NotImplementedError(
+        raise dbt_common.exceptions.base.NotImplementedError(
             "`exception_handler` is not implemented for this adapter!"
         )
@@ -275,7 +275,7 @@ class BaseConnectionManager(metaclass=abc.ABCMeta):
     @abc.abstractmethod
     def cancel_open(self) -> Optional[List[str]]:
         """Cancel all open connections on the adapter. (passable)"""
-        raise dbt.common.exceptions.base.NotImplementedError(
+        raise dbt_common.exceptions.base.NotImplementedError(
             "`cancel_open` is not implemented for this adapter!"
         )
@@ -290,7 +290,7 @@ class BaseConnectionManager(metaclass=abc.ABCMeta):
         This should be thread-safe, or hold the lock if necessary. The given
         connection should not be in either in_use or available.
         """
-        raise dbt.common.exceptions.base.NotImplementedError(
+        raise dbt_common.exceptions.base.NotImplementedError(
             "`open` is not implemented for this adapter!"
         )
@@ -324,14 +324,14 @@ class BaseConnectionManager(metaclass=abc.ABCMeta):
     @abc.abstractmethod
     def begin(self) -> None:
         """Begin a transaction. (passable)"""
-        raise dbt.common.exceptions.base.NotImplementedError(
+        raise dbt_common.exceptions.base.NotImplementedError(
             "`begin` is not implemented for this adapter!"
         )
     @abc.abstractmethod
     def commit(self) -> None:
         """Commit a transaction. (passable)"""
-        raise dbt.common.exceptions.base.NotImplementedError(
+        raise dbt_common.exceptions.base.NotImplementedError(
             "`commit` is not implemented for this adapter!"
         )
@@ -369,7 +369,7 @@ class BaseConnectionManager(metaclass=abc.ABCMeta):
     def _rollback(cls, connection: Connection) -> None:
         """Roll back the given connection."""
         if connection.transaction_open is False:
-            raise dbt.common.exceptions.DbtInternalError(
+            raise dbt_common.exceptions.DbtInternalError(
                 f"Tried to rollback transaction on connection "
                 f'"{connection.name}", but it does not have one open!'
             )
@@ -420,7 +420,7 @@ class BaseConnectionManager(metaclass=abc.ABCMeta):
         :return: A tuple of the query status and results (empty if fetch=False).
         :rtype: Tuple[AdapterResponse, agate.Table]
         """
-        raise dbt.common.exceptions.base.NotImplementedError(
+        raise dbt_common.exceptions.base.NotImplementedError(
             "`execute` is not implemented for this adapter!"
         )
@@ -432,7 +432,7 @@ class BaseConnectionManager(metaclass=abc.ABCMeta):
         See https://github.com/dbt-labs/dbt-core/issues/8396 for more information.
         """
-        raise dbt.common.exceptions.base.NotImplementedError(
+        raise dbt_common.exceptions.base.NotImplementedError(
             "`add_select_query` is not implemented for this adapter!"
         )
@@ -440,6 +440,6 @@ class BaseConnectionManager(metaclass=abc.ABCMeta):
     def data_type_code_to_name(cls, type_code: Union[int, str]) -> str:
         """Get the string representation of the data type from the type_code."""
         # https://peps.python.org/pep-0249/#type-objects
-        raise dbt.common.exceptions.base.NotImplementedError(
+        raise dbt_common.exceptions.base.NotImplementedError(
             "`data_type_code_to_name` is not implemented for this adapter!"
         )
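
The hunk above covers both spellings the rename has to handle: from-imports (from dbt.common.… import …) and fully qualified references (dbt.common.exceptions.base.NotImplementedError). For illustration, a hypothetical adapter subclass written against the new layout; the class is a sketch, not code from this PR:

import dbt_common.exceptions.base
from dbt.adapters.base import BaseConnectionManager  # the adapters package keeps its dbt.* path

class MyConnectionManager(BaseConnectionManager):
    TYPE = "my_adapter"

    def begin(self) -> None:
        # Same pattern as the base class above: the exception now lives in dbt-common.
        raise dbt_common.exceptions.base.NotImplementedError(
            "`begin` is not implemented for this adapter!"
        )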

View File

@@ -23,7 +23,7 @@ from typing import (
 from multiprocessing.context import SpawnContext
 from dbt.adapters.capability import Capability, CapabilityDict
-from dbt.common.contracts.constraints import (
+from dbt_common.contracts.constraints import (
     ColumnLevelConstraint,
     ConstraintType,
     ModelLevelConstraint,
@@ -44,7 +44,7 @@ from dbt.adapters.exceptions import (
     QuoteConfigTypeError,
 )
-from dbt.common.exceptions import (
+from dbt_common.exceptions import (
     NotImplementedError,
     DbtInternalError,
     DbtRuntimeError,
@@ -58,15 +58,15 @@ from dbt.adapters.protocol import (
     AdapterConfig,
     MacroContextGeneratorCallable,
 )
-from dbt.common.clients.agate_helper import (
+from dbt_common.clients.agate_helper import (
     empty_table,
     get_column_value_uncased,
     merge_tables,
     table_from_rows,
     Integer,
 )
-from dbt.common.clients.jinja import CallableMacroGenerator
-from dbt.common.events.functions import fire_event, warn_or_error
+from dbt_common.clients.jinja import CallableMacroGenerator
+from dbt_common.events.functions import fire_event, warn_or_error
 from dbt.adapters.events.types import (
     CacheMiss,
     ListRelations,
@@ -76,7 +76,7 @@ from dbt.adapters.events.types import (
     ConstraintNotSupported,
     ConstraintNotEnforced,
 )
-from dbt.common.utils import filter_null_values, executor, cast_to_str, AttrDict
+from dbt_common.utils import filter_null_values, executor, cast_to_str, AttrDict
 from dbt.adapters.contracts.relation import RelationConfig
 from dbt.adapters.base.connections import (
View File

@@ -1,7 +1,7 @@
 import abc
 from functools import wraps
 from typing import Callable, Optional, Any, FrozenSet, Dict, Set
-from dbt.common.events.functions import warn_or_error
+from dbt_common.events.functions import warn_or_error
 from dbt.adapters.events.types import AdapterDeprecationWarning
 Decorator = Callable[[Any], Callable]

View File

@@ -3,7 +3,7 @@ from typing import Optional, Callable, Dict, Any
 from dbt.adapters.clients.jinja import QueryStringGenerator
 from dbt.adapters.contracts.connection import AdapterRequiredConfig, QueryComment
-from dbt.common.exceptions import DbtRuntimeError
+from dbt_common.exceptions import DbtRuntimeError
 class QueryHeaderContextWrapper:

View File

@@ -12,10 +12,10 @@ from dbt.adapters.contracts.relation import (
     Path,
 )
 from dbt.adapters.exceptions import MultipleDatabasesNotAllowedError, ApproximateMatchError
-from dbt.common.utils import filter_null_values, deep_merge
+from dbt_common.utils import filter_null_values, deep_merge
 from dbt.adapters.utils import classproperty
-import dbt.common.exceptions
+import dbt_common.exceptions
 Self = TypeVar("Self", bound="BaseRelation")
@@ -97,7 +97,7 @@ class BaseRelation(FakeAPIObject, Hashable):
         if not search:
             # nothing was passed in
-            raise dbt.common.exceptions.DbtRuntimeError(
+            raise dbt_common.exceptions.DbtRuntimeError(
                 "Tried to match relation, but no search path was passed!"
             )
@@ -360,7 +360,7 @@ class InformationSchema(BaseRelation):
     def __post_init__(self):
         if not isinstance(self.information_schema_view, (type(None), str)):
-            raise dbt.common.exceptions.CompilationError(
+            raise dbt_common.exceptions.CompilationError(
                 "Got an invalid name: {}".format(self.information_schema_view)
             )

View File

@@ -7,6 +7,7 @@ from dbt.adapters.reference_keys import (
     _make_ref_key_dict,
     _ReferenceKey,
 )
 from dbt.adapters.exceptions.cache import (
     NewNameAlreadyInCacheError,
     ReferencedLinkNotCachedError,
@@ -14,9 +15,9 @@ from dbt.adapters.exceptions.cache import (
     TruncatedModelNameCausedCollisionError,
     NoneRelationFoundError,
 )
-from dbt.common.events.functions import fire_event, fire_event_if
+from dbt_common.events.functions import fire_event, fire_event_if
 from dbt.adapters.events.types import CacheAction, CacheDumpGraph
-from dbt.common.utils.formatting import lowercase
+from dbt_common.utils.formatting import lowercase
 def dot_separated(key: _ReferenceKey) -> str:

View File

@@ -1,5 +1,5 @@
 from typing import Dict, Any
-from dbt.common.clients.jinja import BaseMacroGenerator, get_environment
+from dbt_common.clients.jinja import BaseMacroGenerator, get_environment
 class QueryStringGenerator(BaseMacroGenerator):

View File

@@ -16,21 +16,21 @@ from typing_extensions import Protocol, Annotated
 from mashumaro.jsonschema.annotations import Pattern
 from dbt.adapters.utils import translate_aliases
-from dbt.common.exceptions import DbtInternalError
-from dbt.common.dataclass_schema import (
+from dbt_common.exceptions import DbtInternalError
+from dbt_common.dataclass_schema import (
     dbtClassMixin,
     StrEnum,
     ExtensibleDbtClassMixin,
     ValidatedStringMixin,
 )
-from dbt.common.contracts.util import Replaceable
-from dbt.common.utils import md5
-from dbt.common.events.functions import fire_event
+from dbt_common.contracts.util import Replaceable
+from dbt_common.utils import md5
+from dbt_common.events.functions import fire_event
 from dbt.adapters.events.types import NewConnectionOpening
 # TODO: this is a very bad dependency - shared global state
-from dbt.common.events.contextvars import get_node_info
+from dbt_common.events.contextvars import get_node_info
 class Identifier(ValidatedStringMixin):

View File

@@ -1,7 +1,7 @@
 from typing import Optional
 from typing_extensions import Protocol
-from dbt.common.clients.jinja import MacroProtocol
+from dbt_common.clients.jinja import MacroProtocol
 class MacroResolverProtocol(Protocol):

View File

@@ -6,11 +6,11 @@ from typing import (
 )
 from typing_extensions import Protocol
-from dbt.common.dataclass_schema import dbtClassMixin, StrEnum
-from dbt.common.contracts.util import Replaceable
-from dbt.common.exceptions import CompilationError, DataclassNotDictError
-from dbt.common.utils import deep_merge
+from dbt_common.dataclass_schema import dbtClassMixin, StrEnum
+from dbt_common.contracts.util import Replaceable
+from dbt_common.exceptions import CompilationError, DataclassNotDictError
+from dbt_common.utils import deep_merge
 class RelationType(StrEnum):

View File

@@ -1,5 +1,5 @@
 # Aliasing common Level classes in order to make custom, but not overly-verbose versions that have PROTO_TYPES_MODULE set to the adapter-specific generated types_pb2 module
-from dbt.common.events.base_types import (
+from dbt_common.events.base_types import (
     BaseEvent,
     DynamicLevel as CommonDyanicLevel,
     TestLevel as CommonTestLevel,

View File

@@ -7,10 +7,10 @@ from dbt.adapters.events.types import (
     AdapterEventWarning,
     AdapterEventError,
 )
-from dbt.common.events import get_event_manager
-from dbt.common.events.contextvars import get_node_info
-from dbt.common.events.event_handler import set_package_logging
-from dbt.common.events.functions import fire_event
+from dbt_common.events import get_event_manager
+from dbt_common.events.contextvars import get_node_info
+from dbt_common.events.event_handler import set_package_logging
+from dbt_common.events.functions import fire_event
 @dataclass
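
For orientation, AdapterLogger itself keeps its dbt.adapters.events.logging home; only its event plumbing (get_event_manager, fire_event, and friends) now comes from dbt_common. A minimal usage sketch, with the adapter name and message purely illustrative:

from dbt.adapters.events.logging import AdapterLogger

logger = AdapterLogger("MyAdapter")          # tags events with the adapter name
logger.debug("connection pool initialized")  # dispatched via dbt_common's event manager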

View File

@@ -1,5 +1,5 @@
 from dbt.adapters.events.base_types import WarnLevel, InfoLevel, ErrorLevel, DebugLevel
-from dbt.common.ui import line_wrap_message, warning_tag
+from dbt_common.ui import line_wrap_message, warning_tag
 def format_adapter_message(name, base_msg, args) -> str:

View File

@@ -1,6 +1,6 @@
 from typing import Mapping, Any
-from dbt.common.exceptions import DbtValidationError
+from dbt_common.exceptions import DbtValidationError
 class AliasError(DbtValidationError):

View File

@@ -1,7 +1,7 @@
 import re
 from typing import Dict
-from dbt.common.exceptions import DbtInternalError
+from dbt_common.exceptions import DbtInternalError
 class CacheInconsistencyError(DbtInternalError):

View File

@@ -1,7 +1,7 @@
 from typing import List, Mapping, Any
-from dbt.common.exceptions import CompilationError, DbtDatabaseError
-from dbt.common.ui import line_wrap_message
+from dbt_common.exceptions import CompilationError, DbtDatabaseError
+from dbt_common.ui import line_wrap_message
 class MissingConfigError(CompilationError):

View File

@@ -1,6 +1,6 @@
 from typing import List
-from dbt.common.exceptions import DbtRuntimeError, DbtDatabaseError
+from dbt_common.exceptions import DbtRuntimeError, DbtDatabaseError
 class InvalidConnectionError(DbtRuntimeError):

View File

@@ -1,6 +1,6 @@
 from typing import Any
-from dbt.common.exceptions import NotImplementedError, CompilationError
+from dbt_common.exceptions import NotImplementedError, CompilationError
 class UnexpectedDbReferenceError(NotImplementedError):

View File

@@ -10,12 +10,12 @@ from typing import Any, Dict, List, Optional, Set, Type
 from dbt.adapters.base.plugin import AdapterPlugin
 from dbt.adapters.protocol import AdapterConfig, AdapterProtocol, RelationProtocol
 from dbt.adapters.contracts.connection import AdapterRequiredConfig, Credentials
-from dbt.common.events.functions import fire_event
+from dbt_common.events.functions import fire_event
 from dbt.adapters.events.types import AdapterImportError, PluginLoadError, AdapterRegistered
-from dbt.common.exceptions import DbtInternalError, DbtRuntimeError
+from dbt_common.exceptions import DbtInternalError, DbtRuntimeError
 from dbt.adapters.include.global_project import PACKAGE_PATH as GLOBAL_PROJECT_PATH
 from dbt.adapters.include.global_project import PROJECT_NAME as GLOBAL_PROJECT_NAME
-from dbt.common.semver import VersionSpecifier
+from dbt_common.semver import VersionSpecifier
 Adapter = AdapterProtocol

View File

@@ -18,8 +18,8 @@ import agate
 from dbt.adapters.contracts.connection import Connection, AdapterRequiredConfig, AdapterResponse
 from dbt.adapters.contracts.macros import MacroResolverProtocol
 from dbt.adapters.contracts.relation import Policy, HasQuoting, RelationConfig
-from dbt.common.contracts.config.base import BaseConfig
-from dbt.common.clients.jinja import MacroProtocol
+from dbt_common.contracts.config.base import BaseConfig
+from dbt_common.clients.jinja import MacroProtocol
 @dataclass

View File

@@ -2,7 +2,7 @@ from dataclasses import dataclass
 from typing import Union, Dict
 import agate
-from dbt.common.utils import filter_null_values
+from dbt_common.utils import filter_null_values
 """

View File

@@ -3,7 +3,7 @@ from dataclasses import dataclass
 from typing import Hashable
 from dbt.adapters.relation_configs.config_base import RelationConfigBase
-from dbt.common.dataclass_schema import StrEnum
+from dbt_common.dataclass_schema import StrEnum
 class RelationConfigChangeAction(StrEnum):

View File

@@ -1,7 +1,7 @@
 from dataclasses import dataclass
 from typing import Set, Optional
-from dbt.common.exceptions import DbtRuntimeError
+from dbt_common.exceptions import DbtRuntimeError
 @dataclass(frozen=True, eq=True, unsafe_hash=True)

View File

@@ -5,13 +5,13 @@ from typing import List, Optional, Tuple, Any, Iterable, Dict
 import agate
 from dbt.adapters.events.types import ConnectionUsed, SQLQuery, SQLCommit, SQLQueryStatus
-import dbt.common.clients.agate_helper
-import dbt.common.exceptions
+import dbt_common.clients.agate_helper
+import dbt_common.exceptions
 from dbt.adapters.base import BaseConnectionManager
 from dbt.adapters.contracts.connection import Connection, ConnectionState, AdapterResponse
-from dbt.common.events.functions import fire_event
-from dbt.common.events.contextvars import get_node_info
-from dbt.common.utils import cast_to_str
+from dbt_common.events.functions import fire_event
+from dbt_common.events.contextvars import get_node_info
+from dbt_common.utils import cast_to_str
 class SQLConnectionManager(BaseConnectionManager):
@@ -27,7 +27,7 @@ class SQLConnectionManager(BaseConnectionManager):
     @abc.abstractmethod
     def cancel(self, connection: Connection):
         """Cancel the given connection."""
-        raise dbt.common.exceptions.base.NotImplementedError(
+        raise dbt_common.exceptions.base.NotImplementedError(
             "`cancel` is not implemented for this adapter!"
         )
@@ -95,7 +95,7 @@ class SQLConnectionManager(BaseConnectionManager):
    @abc.abstractmethod
     def get_response(cls, cursor: Any) -> AdapterResponse:
         """Get the status of the cursor."""
-        raise dbt.common.exceptions.base.NotImplementedError(
+        raise dbt_common.exceptions.base.NotImplementedError(
             "`get_response` is not implemented for this adapter!"
         )
@@ -131,7 +131,7 @@ class SQLConnectionManager(BaseConnectionManager):
         rows = cursor.fetchall()
         data = cls.process_results(column_names, rows)
-        return dbt.common.clients.agate_helper.table_from_data_flat(data, column_names)
+        return dbt_common.clients.agate_helper.table_from_data_flat(data, column_names)
     def execute(
         self, sql: str, auto_begin: bool = False, fetch: bool = False, limit: Optional[int] = None
@@ -142,7 +142,7 @@ class SQLConnectionManager(BaseConnectionManager):
         if fetch:
             table = self.get_result_from_cursor(cursor, limit)
         else:
-            table = dbt.common.clients.agate_helper.empty_table()
+            table = dbt_common.clients.agate_helper.empty_table()
         return response, table
     def add_begin_query(self):
@@ -158,7 +158,7 @@ class SQLConnectionManager(BaseConnectionManager):
     def begin(self):
         connection = self.get_thread_connection()
         if connection.transaction_open is True:
-            raise dbt.common.exceptions.DbtInternalError(
+            raise dbt_common.exceptions.DbtInternalError(
                 'Tried to begin a new transaction on connection "{}", but '
                 "it already had one open!".format(connection.name)
             )
@@ -171,7 +171,7 @@ class SQLConnectionManager(BaseConnectionManager):
     def commit(self):
         connection = self.get_thread_connection()
         if connection.transaction_open is False:
-            raise dbt.common.exceptions.DbtInternalError(
+            raise dbt_common.exceptions.DbtInternalError(
                 'Tried to commit transaction on connection "{}", but '
                 "it does not have one open!".format(connection.name)
             )
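
One contract worth noting in the hunk above: execute() always returns a (response, table) pair, and when fetch=False the table is the empty one from the relocated agate helpers. A small stand-in sketch of that convention, assuming the helpers keep their API at dbt_common.clients.agate_helper:

import dbt_common.clients.agate_helper as agate_helper

def demo_execute(fetch: bool):
    response = "SUCCESS"  # a real adapter returns an AdapterResponse
    if fetch:
        table = agate_helper.table_from_data_flat([{"id": 1}], ["id"])
    else:
        table = agate_helper.empty_table()
    return response, table

response, table = demo_execute(fetch=False)
print(response, len(table.rows))  # SUCCESS 0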

View File

@@ -7,7 +7,7 @@ from dbt.adapters.exceptions import RelationTypeNullError
 from dbt.adapters.base import BaseAdapter, available
 from dbt.adapters.cache import _make_ref_key_dict
 from dbt.adapters.sql import SQLConnectionManager
-from dbt.common.events.functions import fire_event
+from dbt_common.events.functions import fire_event
 from dbt.adapters.base.relation import BaseRelation

View File

@@ -2,7 +2,7 @@ import dataclasses
 from datetime import datetime
 from typing import ClassVar, Type, TypeVar, Dict, Any, Optional
-from dbt.common.clients.system import write_json, read_json
+from dbt_common.clients.system import write_json, read_json
 from dbt.exceptions import (
     DbtInternalError,
     DbtRuntimeError,
@@ -10,9 +10,9 @@ from dbt.exceptions import (
 )
 from dbt.version import __version__
-from dbt.common.events.functions import get_metadata_vars
-from dbt.common.invocation import get_invocation_id
-from dbt.common.dataclass_schema import dbtClassMixin
+from dbt_common.events.functions import get_metadata_vars
+from dbt_common.invocation import get_invocation_id
+from dbt_common.dataclass_schema import dbtClassMixin
 from mashumaro.jsonschema import build_json_schema
 from mashumaro.jsonschema.dialects import DRAFT_2020_12

View File

@@ -2,9 +2,9 @@ from typing import Dict, Union, Optional, NamedTuple, Any, List
 from dataclasses import dataclass, field
 from datetime import datetime
-from dbt.common.dataclass_schema import dbtClassMixin
-from dbt.common.utils.formatting import lowercase
-from dbt.common.contracts.util import Replaceable
+from dbt_common.dataclass_schema import dbtClassMixin
+from dbt_common.utils.formatting import lowercase
+from dbt_common.contracts.util import Replaceable
 from dbt.artifacts.base import ArtifactMixin, BaseArtifactMetadata, schema_version
 Primitive = Union[bool, str, float, None]

View File

@@ -4,8 +4,8 @@ from datetime import datetime
 from dbt.artifacts.results import ExecutionResult, FreshnessStatus, NodeResult, TimingInfo
 from dbt.artifacts.base import ArtifactMixin, VersionedSchema, schema_version, BaseArtifactMetadata
-from dbt.common.dataclass_schema import dbtClassMixin, StrEnum
-from dbt.common.exceptions import DbtInternalError
+from dbt_common.dataclass_schema import dbtClassMixin, StrEnum
+from dbt_common.exceptions import DbtInternalError
 from dbt.contracts.graph.unparsed import FreshnessThreshold
 from dbt.contracts.graph.nodes import SourceDefinition

View File

@@ -1,11 +1,10 @@
 from dbt.contracts.graph.nodes import ResultNode
-from dbt.common.events.functions import fire_event
+from dbt_common.events.functions import fire_event
 from dbt.events.types import TimingInfoCollected
-from dbt.common.events.contextvars import get_node_info
-from dbt.common.events.helpers import datetime_to_json_string
-from dbt.logger import TimingProcessor
-from dbt.common.utils import cast_to_str, cast_to_int
-from dbt.common.dataclass_schema import dbtClassMixin, StrEnum
+from dbt_common.events.contextvars import get_node_info
+from dbt_common.events.helpers import datetime_to_json_string
+from dbt_common.utils import cast_to_str, cast_to_int
+from dbt_common.dataclass_schema import dbtClassMixin, StrEnum
 from dataclasses import dataclass
 from datetime import datetime
@@ -45,8 +44,6 @@ class collect_timing_info:
     def __exit__(self, exc_type, exc_value, traceback):
         self.timing_info.end()
         self.callback(self.timing_info)
-        # Note: when legacy logger is removed, we can remove the following line
-        with TimingProcessor(self.timing_info):
         fire_event(
             TimingInfoCollected(
                 timing_info=self.timing_info.to_msg_dict(), node_info=get_node_info()

View File

@@ -19,7 +19,7 @@ from dbt.artifacts.results import (
     ResultNode,
     ExecutionResult,
 )
-from dbt.common.clients.system import write_json
+from dbt_common.clients.system import write_json
 @dataclass

View File

@@ -13,12 +13,12 @@ from dbt.cli.resolvers import default_log_path, default_project_dir
 from dbt.cli.types import Command as CliCommand
 from dbt.config.project import read_project_flags
 from dbt.contracts.project import ProjectFlags
-from dbt.common import ui
-from dbt.common.events import functions
-from dbt.common.exceptions import DbtInternalError
-from dbt.common.clients import jinja
+from dbt_common import ui
+from dbt_common.events import functions
+from dbt_common.exceptions import DbtInternalError
+from dbt_common.clients import jinja
 from dbt.deprecations import renamed_env_var
-from dbt.common.helper_types import WarnErrorOptions
+from dbt_common.helper_types import WarnErrorOptions
 from dbt.events import ALL_EVENT_NAMES
 if os.name != "nt":

View File

@@ -19,7 +19,7 @@ from dbt.cli.exceptions import (
 from dbt.contracts.graph.manifest import Manifest
 from dbt.artifacts.catalog import CatalogArtifact
 from dbt.artifacts.run import RunExecutionResult
-from dbt.common.events.base_types import EventMsg
+from dbt_common.events.base_types import EventMsg
 from dbt.task.build import BuildTask
 from dbt.task.clean import CleanTask
 from dbt.task.clone import CloneTask
@@ -122,7 +122,6 @@ def global_flags(func):
     @p.cache_selected_only
     @p.debug
     @p.deprecated_print
-    @p.enable_legacy_logger
     @p.fail_fast
     @p.log_cache_events
     @p.log_file_max_bytes

View File

@@ -3,9 +3,9 @@ from click import ParamType, Choice
 from dbt.config.utils import parse_cli_yaml_string
 from dbt.events import ALL_EVENT_NAMES
 from dbt.exceptions import ValidationError, OptionNotYamlDictError
-from dbt.common.exceptions import DbtValidationError
-from dbt.common.helper_types import WarnErrorOptions
+from dbt_common.exceptions import DbtValidationError
+from dbt_common.helper_types import WarnErrorOptions
 class YAML(ParamType):

View File

@@ -90,12 +90,6 @@ empty = click.option(
     is_flag=True,
 )
-enable_legacy_logger = click.option(
-    "--enable-legacy-logger/--no-enable-legacy-logger",
-    envvar="DBT_ENABLE_LEGACY_LOGGER",
-    hidden=True,
-)
 exclude = click.option(
     "--exclude",
     envvar=None,

View File

@@ -1,5 +1,5 @@
 import dbt.tracking
-from dbt.common.invocation import reset_invocation_id
+from dbt_common.invocation import reset_invocation_id
 from dbt.version import installed as installed_version
 from dbt.adapters.factory import adapter_management
 from dbt.flags import set_flags, get_flag_dict
@@ -11,8 +11,8 @@ from dbt.cli.flags import Flags
 from dbt.config import RuntimeConfig
 from dbt.config.runtime import load_project, load_profile, UnsetProfile
-from dbt.common.events.base_types import EventLevel
-from dbt.common.events.functions import (
+from dbt_common.events.base_types import EventLevel
+from dbt_common.events.functions import (
     fire_event,
     LOG_VERSION,
 )
@@ -22,14 +22,14 @@ from dbt.events.types import (
     MainReportArgs,
     MainTrackingUserState,
 )
-from dbt.common.events.helpers import get_json_string_utcnow
+from dbt_common.events.helpers import get_json_string_utcnow
 from dbt.events.types import CommandCompleted, MainEncounteredError, MainStackTrace, ResourceReport
-from dbt.common.exceptions import DbtBaseException as DbtException
+from dbt_common.exceptions import DbtBaseException as DbtException
 from dbt.exceptions import DbtProjectError, FailFastError
 from dbt.parser.manifest import parse_manifest
 from dbt.profiler import profiler
 from dbt.tracking import active_user, initialize_from_flags, track_run
-from dbt.common.utils import cast_dict_to_dict_of_strings
+from dbt_common.utils import cast_dict_to_dict_of_strings
 from dbt.plugins import set_up_plugin_manager
 from click import Context

View File

@@ -1,7 +1,7 @@
 from enum import Enum
 from typing import List
-from dbt.common.exceptions import DbtInternalError
+from dbt_common.exceptions import DbtInternalError
 class Command(Enum):

View File

@@ -1,8 +1,8 @@
 import re
 import os.path
-from dbt.common.clients.system import run_cmd, rmdir
-from dbt.common.events.functions import fire_event
+from dbt_common.clients.system import run_cmd, rmdir
+from dbt_common.events.functions import fire_event
 from dbt.events.types import (
     GitSparseCheckoutSubdirectory,
     GitProgressCheckoutRevision,

View File

@@ -10,13 +10,13 @@ import jinja2.nodes
 import jinja2.parser
 import jinja2.sandbox
-from dbt.common.clients.jinja import (
+from dbt_common.clients.jinja import (
     render_template,
     get_template,
     CallableMacroGenerator,
     MacroProtocol,
 )
-from dbt.common.utils import deep_map_render
+from dbt_common.utils import deep_map_render
 from dbt.contracts.graph.nodes import GenericTestNode
 from dbt.exceptions import (
View File

@@ -1,7 +1,7 @@
 import jinja2
-from dbt.common.clients.jinja import get_environment
+from dbt_common.clients.jinja import get_environment
 from dbt.exceptions import MacroNamespaceNotStringError
-from dbt.common.exceptions.macros import MacroNameNotStringError
+from dbt_common.exceptions.macros import MacroNameNotStringError
 def statically_extract_macro_calls(string, ctx, db_wrapper=None):

View File

@@ -1,7 +1,7 @@
 import functools
 from typing import Any, Dict, List
 import requests
-from dbt.common.events.functions import fire_event
+from dbt_common.events.functions import fire_event
 from dbt.events.types import (
     RegistryProgressGETRequest,
     RegistryProgressGETResponse,
@@ -13,9 +13,9 @@ from dbt.events.types import (
     RegistryResponseExtraNestedKeys,
 )
 from dbt.utils import memoized
-from dbt.common.utils.connection import connection_exception_retry
+from dbt_common.utils.connection import connection_exception_retry
 from dbt import deprecations
-from dbt.common import semver
+from dbt_common import semver
 import os
 if os.getenv("DBT_PACKAGE_HUB_URL"):

View File

@@ -1,5 +1,5 @@
-import dbt.common.exceptions.base
-import dbt.exceptions
+import dbt_common.exceptions.base
+import dbt_common.exceptions
 from typing import Any, Dict, Optional
 import yaml
@@ -61,4 +61,4 @@ def load_yaml_text(contents, path=None):
     else:
         error = str(e)
-    raise dbt.common.exceptions.base.DbtValidationError(error)
+    raise dbt_common.exceptions.base.DbtValidationError(error)
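
Pieced together, the function this hunk ends now raises the relocated validation error on bad YAML. A simplified sketch of load_yaml_text, under the assumption that only the exception's import path changed:

import yaml
import dbt_common.exceptions.base

def load_yaml_text(contents, path=None):
    # Simplified relative to the original, which formats richer error detail.
    try:
        return yaml.safe_load(contents)
    except yaml.YAMLError as e:
        raise dbt_common.exceptions.base.DbtValidationError(str(e))

load_yaml_text("key: value")   # -> {'key': 'value'}
# load_yaml_text("key: [oops") # -> raises DbtValidationError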

View File

@@ -1,360 +0,0 @@
import re
from collections import namedtuple
from dbt.common.exceptions import (
BlockDefinitionNotAtTopError,
DbtInternalError,
MissingCloseTagError,
MissingControlFlowStartTagError,
NestedTagsError,
UnexpectedControlFlowEndTagError,
UnexpectedMacroEOFError,
)
def regex(pat):
return re.compile(pat, re.DOTALL | re.MULTILINE)
class BlockData:
"""raw plaintext data from the top level of the file."""
def __init__(self, contents):
self.block_type_name = "__dbt__data"
self.contents = contents
self.full_block = contents
class BlockTag:
def __init__(self, block_type_name, block_name, contents=None, full_block=None, **kw):
self.block_type_name = block_type_name
self.block_name = block_name
self.contents = contents
self.full_block = full_block
def __str__(self):
return "BlockTag({!r}, {!r})".format(self.block_type_name, self.block_name)
def __repr__(self):
return str(self)
@property
def end_block_type_name(self):
return "end{}".format(self.block_type_name)
def end_pat(self):
# we don't want to use string formatting here because jinja uses most
# of the string formatting operators in its syntax...
pattern = "".join(
(
r"(?P<endblock>((?:\s*\{\%\-|\{\%)\s*",
self.end_block_type_name,
r"\s*(?:\-\%\}\s*|\%\})))",
)
)
return regex(pattern)
Tag = namedtuple("Tag", "block_type_name block_name start end")
_NAME_PATTERN = r"[A-Za-z_][A-Za-z_0-9]*"
COMMENT_START_PATTERN = regex(r"(?:(?P<comment_start>(\s*\{\#)))")
COMMENT_END_PATTERN = regex(r"(.*?)(\s*\#\})")
RAW_START_PATTERN = regex(r"(?:\s*\{\%\-|\{\%)\s*(?P<raw_start>(raw))\s*(?:\-\%\}\s*|\%\})")
EXPR_START_PATTERN = regex(r"(?P<expr_start>(\{\{\s*))")
EXPR_END_PATTERN = regex(r"(?P<expr_end>(\s*\}\}))")
BLOCK_START_PATTERN = regex(
"".join(
(
r"(?:\s*\{\%\-|\{\%)\s*",
r"(?P<block_type_name>({}))".format(_NAME_PATTERN),
# some blocks have a 'block name'.
r"(?:\s+(?P<block_name>({})))?".format(_NAME_PATTERN),
)
)
)
RAW_BLOCK_PATTERN = regex(
"".join(
(
r"(?:\s*\{\%\-|\{\%)\s*raw\s*(?:\-\%\}\s*|\%\})",
r"(?:.*?)",
r"(?:\s*\{\%\-|\{\%)\s*endraw\s*(?:\-\%\}\s*|\%\})",
)
)
)
TAG_CLOSE_PATTERN = regex(r"(?:(?P<tag_close>(\-\%\}\s*|\%\})))")
# stolen from jinja's lexer. Note that we've consumed all prefix whitespace by
# the time we want to use this.
STRING_PATTERN = regex(r"(?P<string>('([^'\\]*(?:\\.[^'\\]*)*)'|" r'"([^"\\]*(?:\\.[^"\\]*)*)"))')
QUOTE_START_PATTERN = regex(r"""(?P<quote>(['"]))""")
class TagIterator:
def __init__(self, data):
self.data = data
self.blocks = []
self._parenthesis_stack = []
self.pos = 0
def linepos(self, end=None) -> str:
"""Given an absolute position in the input data, return a pair of
line number + relative position to the start of the line.
"""
end_val: int = self.pos if end is None else end
data = self.data[:end_val]
# if not found, rfind returns -1, and -1+1=0, which is perfect!
last_line_start = data.rfind("\n") + 1
# it's easy to forget this, but line numbers are 1-indexed
line_number = data.count("\n") + 1
return f"{line_number}:{end_val - last_line_start}"
def advance(self, new_position):
self.pos = new_position
def rewind(self, amount=1):
self.pos -= amount
def _search(self, pattern):
return pattern.search(self.data, self.pos)
def _match(self, pattern):
return pattern.match(self.data, self.pos)
def _first_match(self, *patterns, **kwargs):
matches = []
for pattern in patterns:
# default to 'search', but sometimes we want to 'match'.
if kwargs.get("method", "search") == "search":
match = self._search(pattern)
else:
match = self._match(pattern)
if match:
matches.append(match)
if not matches:
return None
# if there are multiple matches, pick the least greedy match
# TODO: do I need to account for m.start(), or is this ok?
return min(matches, key=lambda m: m.end())
def _expect_match(self, expected_name, *patterns, **kwargs):
match = self._first_match(*patterns, **kwargs)
if match is None:
raise UnexpectedMacroEOFError(expected_name, self.data[self.pos :])
return match
def handle_expr(self, match):
"""Handle an expression. At this point we're at a string like:
{{ 1 + 2 }}
^ right here
And the match contains "{{ "
We expect to find a `}}`, but we might find one in a string before
that. Imagine the case of `{{ 2 * "}}" }}`...
You're not allowed to have blocks or comments inside an expr so it is
pretty straightforward, I hope: only strings can get in the way.
"""
self.advance(match.end())
while True:
match = self._expect_match("}}", EXPR_END_PATTERN, QUOTE_START_PATTERN)
if match.groupdict().get("expr_end") is not None:
break
else:
# it's a quote. we haven't advanced for this match yet, so
# just slurp up the whole string, no need to rewind.
match = self._expect_match("string", STRING_PATTERN)
self.advance(match.end())
self.advance(match.end())
def handle_comment(self, match):
self.advance(match.end())
match = self._expect_match("#}", COMMENT_END_PATTERN)
self.advance(match.end())
def _expect_block_close(self):
"""Search for the tag close marker.
To the right of the type name, there are a few possiblities:
- a name (handled by the regex's 'block_name')
- any number of: `=`, `(`, `)`, strings, etc (arguments)
- nothing
followed eventually by a %}
So the only characters we actually have to worry about in this context
are quote and `%}` - nothing else can hide the %} and be valid jinja.
"""
while True:
end_match = self._expect_match(
'tag close ("%}")', QUOTE_START_PATTERN, TAG_CLOSE_PATTERN
)
self.advance(end_match.end())
if end_match.groupdict().get("tag_close") is not None:
return
# must be a string. Rewind to its start and advance past it.
self.rewind()
string_match = self._expect_match("string", STRING_PATTERN)
self.advance(string_match.end())
def handle_raw(self):
# raw blocks are super special, they are a single complete regex
match = self._expect_match("{% raw %}...{% endraw %}", RAW_BLOCK_PATTERN)
self.advance(match.end())
return match.end()
def handle_tag(self, match):
"""The tag could be one of a few things:
{% mytag %}
{% mytag x = y %}
{% mytag x = "y" %}
{% mytag x.y() %}
{% mytag foo("a", "b", c="d") %}
But the key here is that it's always going to be `{% mytag`!
"""
groups = match.groupdict()
# always a value
block_type_name = groups["block_type_name"]
# might be None
block_name = groups.get("block_name")
start_pos = self.pos
if block_type_name == "raw":
match = self._expect_match("{% raw %}...{% endraw %}", RAW_BLOCK_PATTERN)
self.advance(match.end())
else:
self.advance(match.end())
self._expect_block_close()
return Tag(
block_type_name=block_type_name, block_name=block_name, start=start_pos, end=self.pos
)
def find_tags(self):
while True:
match = self._first_match(
BLOCK_START_PATTERN, COMMENT_START_PATTERN, EXPR_START_PATTERN
)
if match is None:
break
self.advance(match.start())
# start = self.pos
groups = match.groupdict()
comment_start = groups.get("comment_start")
expr_start = groups.get("expr_start")
block_type_name = groups.get("block_type_name")
if comment_start is not None:
self.handle_comment(match)
elif expr_start is not None:
self.handle_expr(match)
elif block_type_name is not None:
yield self.handle_tag(match)
else:
raise DbtInternalError(
"Invalid regex match in next_block, expected block start, "
"expr start, or comment start"
)
def __iter__(self):
return self.find_tags()
_CONTROL_FLOW_TAGS = {
"if": "endif",
"for": "endfor",
}
_CONTROL_FLOW_END_TAGS = {v: k for k, v in _CONTROL_FLOW_TAGS.items()}
class BlockIterator:
def __init__(self, data):
self.tag_parser = TagIterator(data)
self.current = None
self.stack = []
self.last_position = 0
@property
def current_end(self):
if self.current is None:
return 0
else:
return self.current.end
@property
def data(self):
return self.tag_parser.data
def is_current_end(self, tag):
return (
tag.block_type_name.startswith("end")
and self.current is not None
and tag.block_type_name[3:] == self.current.block_type_name
)
def find_blocks(self, allowed_blocks=None, collect_raw_data=True):
"""Find all top-level blocks in the data."""
if allowed_blocks is None:
allowed_blocks = {"snapshot", "macro", "materialization", "docs"}
for tag in self.tag_parser.find_tags():
if tag.block_type_name in _CONTROL_FLOW_TAGS:
self.stack.append(tag.block_type_name)
elif tag.block_type_name in _CONTROL_FLOW_END_TAGS:
found = None
if self.stack:
found = self.stack.pop()
else:
expected = _CONTROL_FLOW_END_TAGS[tag.block_type_name]
raise UnexpectedControlFlowEndTagError(tag, expected, self.tag_parser)
expected = _CONTROL_FLOW_TAGS[found]
if expected != tag.block_type_name:
raise MissingControlFlowStartTagError(tag, expected, self.tag_parser)
if tag.block_type_name in allowed_blocks:
if self.stack:
raise BlockDefinitionNotAtTopError(self.tag_parser, tag.start)
if self.current is not None:
raise NestedTagsError(outer=self.current, inner=tag)
if collect_raw_data:
raw_data = self.data[self.last_position : tag.start]
self.last_position = tag.start
if raw_data:
yield BlockData(raw_data)
self.current = tag
elif self.is_current_end(tag):
self.last_position = tag.end
assert self.current is not None
yield BlockTag(
block_type_name=self.current.block_type_name,
block_name=self.current.block_name,
contents=self.data[self.current.end : tag.start],
full_block=self.data[self.current.start : tag.end],
)
self.current = None
if self.current:
linecount = self.data[: self.current.end].count("\n") + 1
raise MissingCloseTagError(self.current.block_type_name, linecount)
if collect_raw_data:
raw_data = self.data[self.last_position :]
if raw_data:
yield BlockData(raw_data)
def lex_for_blocks(self, allowed_blocks=None, collect_raw_data=True):
return list(
self.find_blocks(allowed_blocks=allowed_blocks, collect_raw_data=collect_raw_data)
)
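
This lexer is deleted from dbt-core because it moves to the dbt-common package, not because it is retired; the deleted dbt/common/clients/jinja.py below imported it from dbt.common.clients._jinja_blocks. A usage sketch, assuming the module keeps the same path and API under dbt_common:

from dbt_common.clients._jinja_blocks import BlockIterator

source = "{% macro greet(name) %} select '{{ name }}' {% endmacro %}"
for block in BlockIterator(source).lex_for_blocks(allowed_blocks={"macro"}):
    print(type(block).__name__, block.block_type_name, getattr(block, "block_name", None))
# BlockTag macro greet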

View File

@@ -1,251 +0,0 @@
from codecs import BOM_UTF8
import agate
import datetime
import isodate
import json
from typing import Iterable, List, Dict, Union, Optional, Any
from dbt.common.exceptions import DbtRuntimeError
from dbt.common.utils import ForgivingJSONEncoder
BOM = BOM_UTF8.decode("utf-8") # '\ufeff'
class Integer(agate.data_types.DataType):
def cast(self, d):
# by default agate will cast none as a Number
# but we need to cast it as an Integer to preserve
# the type when merging and unioning tables
if type(d) == int or d is None:
return d
else:
raise agate.exceptions.CastError('Can not parse value "%s" as Integer.' % d)
def jsonify(self, d):
return d
class Number(agate.data_types.Number):
# undo the change in https://github.com/wireservice/agate/pull/733
# i.e. do not cast True and False to numeric 1 and 0
def cast(self, d):
if type(d) == bool:
raise agate.exceptions.CastError("Do not cast True to 1 or False to 0.")
else:
return super().cast(d)
class ISODateTime(agate.data_types.DateTime):
def cast(self, d):
# this is agate.data_types.DateTime.cast with the "clever" bits removed
# so we only handle ISO8601 stuff
if isinstance(d, datetime.datetime) or d is None:
return d
elif isinstance(d, datetime.date):
return datetime.datetime.combine(d, datetime.time(0, 0, 0))
elif isinstance(d, str):
d = d.strip()
if d.lower() in self.null_values:
return None
try:
return isodate.parse_datetime(d)
except: # noqa
pass
raise agate.exceptions.CastError('Can not parse value "%s" as datetime.' % d)
def build_type_tester(
text_columns: Iterable[str], string_null_values: Optional[Iterable[str]] = ("null", "")
) -> agate.TypeTester:
types = [
Integer(null_values=("null", "")),
Number(null_values=("null", "")),
agate.data_types.Date(null_values=("null", ""), date_format="%Y-%m-%d"),
agate.data_types.DateTime(null_values=("null", ""), datetime_format="%Y-%m-%d %H:%M:%S"),
ISODateTime(null_values=("null", "")),
agate.data_types.Boolean(
true_values=("true",), false_values=("false",), null_values=("null", "")
),
agate.data_types.Text(null_values=string_null_values),
]
force = {k: agate.data_types.Text(null_values=string_null_values) for k in text_columns}
return agate.TypeTester(force=force, types=types)
DEFAULT_TYPE_TESTER = build_type_tester(())
def table_from_rows(
rows: List[Any],
column_names: Iterable[str],
text_only_columns: Optional[Iterable[str]] = None,
) -> agate.Table:
if text_only_columns is None:
column_types = DEFAULT_TYPE_TESTER
else:
# If text_only_columns are present, prevent coercing empty string or
# literal 'null' strings to a None representation.
column_types = build_type_tester(text_only_columns, string_null_values=())
return agate.Table(rows, column_names, column_types=column_types)
def table_from_data(data, column_names: Iterable[str]) -> agate.Table:
"Convert a list of dictionaries into an Agate table"
# The agate table is generated from a list of dicts, so the column order
# from `data` is not preserved. We can use `select` to reorder the columns
#
# If there is no data, create an empty table with the specified columns
if len(data) == 0:
return agate.Table([], column_names=column_names)
else:
table = agate.Table.from_object(data, column_types=DEFAULT_TYPE_TESTER)
return table.select(column_names)
def table_from_data_flat(data, column_names: Iterable[str]) -> agate.Table:
"""
Convert a list of dictionaries into an Agate table. This method does not
coerce string values into more specific types (eg. '005' will not be
coerced to '5'). Additionally, this method does not coerce values to
None (eg. '' or 'null' will retain their string literal representations).
"""
rows = []
text_only_columns = set()
for _row in data:
row = []
for col_name in column_names:
value = _row[col_name]
if isinstance(value, (dict, list, tuple)):
# Represent container types as json strings
value = json.dumps(value, cls=ForgivingJSONEncoder)
text_only_columns.add(col_name)
elif isinstance(value, str):
text_only_columns.add(col_name)
row.append(value)
rows.append(row)
return table_from_rows(
rows=rows, column_names=column_names, text_only_columns=text_only_columns
)
def empty_table():
"Returns an empty Agate table. To be used in place of None"
return agate.Table(rows=[])
def as_matrix(table):
"Return an agate table as a matrix of data sans columns"
return [r.values() for r in table.rows.values()]
def from_csv(abspath, text_columns, delimiter=","):
type_tester = build_type_tester(text_columns=text_columns)
with open(abspath, encoding="utf-8") as fp:
if fp.read(1) != BOM:
fp.seek(0)
return agate.Table.from_csv(fp, column_types=type_tester, delimiter=delimiter)
class _NullMarker:
pass
NullableAgateType = Union[agate.data_types.DataType, _NullMarker]
class ColumnTypeBuilder(Dict[str, NullableAgateType]):
def __init__(self) -> None:
super().__init__()
def __setitem__(self, key, value):
if key not in self:
super().__setitem__(key, value)
return
existing_type = self[key]
if isinstance(existing_type, _NullMarker):
# overwrite
super().__setitem__(key, value)
elif isinstance(value, _NullMarker):
# use the existing value
return
# when one table column is Number while another is Integer, force the column to Number on merge
elif isinstance(value, Integer) and isinstance(existing_type, agate.data_types.Number):
# use the existing value
return
elif isinstance(existing_type, Integer) and isinstance(value, agate.data_types.Number):
# overwrite
super().__setitem__(key, value)
elif not isinstance(value, type(existing_type)):
# actual type mismatch!
raise DbtRuntimeError(
f"Tables contain columns with the same names ({key}), "
f"but different types ({value} vs {existing_type})"
)
def finalize(self) -> Dict[str, agate.data_types.DataType]:
result: Dict[str, agate.data_types.DataType] = {}
for key, value in self.items():
if isinstance(value, _NullMarker):
# agate would make it a Number but we'll make it Integer so that if this column
# gets merged with another Integer column, it won't get forced to a Number
result[key] = Integer()
else:
result[key] = value
return result
def _merged_column_types(tables: List[agate.Table]) -> Dict[str, agate.data_types.DataType]:
# this is a lot like agate.Table.merge, but with handling for all-null
# rows being "any type".
new_columns: ColumnTypeBuilder = ColumnTypeBuilder()
for table in tables:
for i in range(len(table.columns)):
column_name: str = table.column_names[i]
column_type: NullableAgateType = table.column_types[i]
# avoid over-sensitive type inference
if all(x is None for x in table.columns[column_name]):
column_type = _NullMarker()
new_columns[column_name] = column_type
return new_columns.finalize()
def merge_tables(tables: List[agate.Table]) -> agate.Table:
"""This is similar to agate.Table.merge, but it handles rows of all 'null'
values more gracefully during merges.
"""
new_columns = _merged_column_types(tables)
column_names = tuple(new_columns.keys())
column_types = tuple(new_columns.values())
rows: List[agate.Row] = []
for table in tables:
if table.column_names == column_names and table.column_types == column_types:
rows.extend(table.rows)
else:
for row in table.rows:
data = [row.get(name, None) for name in column_names]
rows.append(agate.Row(data, column_names))
# _is_fork to tell agate that we already made things into `Row`s.
return agate.Table(rows, column_names, column_types, _is_fork=True)
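# Example (hedged sketch): merging a table whose "num" column is entirely null
# with one that has values does not raise a column-type mismatch, because
# all-null columns are treated as mergeable placeholders:
#
#     >>> t1 = table_from_rows([(1, None)], ["id", "num"])
#     >>> t2 = table_from_rows([(2, 3)], ["id", "num"])
#     >>> len(merge_tables([t1, t2]).rows)
#     2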
def get_column_value_uncased(column_name: str, row: agate.Row) -> Any:
"""Get the value of a column in this row, ignoring the casing of the column name."""
for key, value in row.items():
if key.casefold() == column_name.casefold():
return value
raise KeyError

View File

@@ -1,505 +0,0 @@
import codecs
import linecache
import os
import tempfile
from ast import literal_eval
from contextlib import contextmanager
from itertools import chain, islice
from typing import List, Union, Set, Optional, Dict, Any, Iterator, Type, Callable
from typing_extensions import Protocol
import jinja2
import jinja2.ext
import jinja2.nativetypes # type: ignore
import jinja2.nodes
import jinja2.parser
import jinja2.sandbox
from dbt.common.utils import (
get_dbt_macro_name,
get_docs_macro_name,
get_materialization_macro_name,
get_test_macro_name,
)
from dbt.common.clients._jinja_blocks import BlockIterator, BlockData, BlockTag
from dbt.common.exceptions import (
CompilationError,
DbtInternalError,
CaughtMacroErrorWithNodeError,
MaterializationArgError,
JinjaRenderingError,
UndefinedCompilationError,
)
from dbt.common.exceptions.macros import MacroReturn, UndefinedMacroError, CaughtMacroError
SUPPORTED_LANG_ARG = jinja2.nodes.Name("supported_languages", "param")
# Global which can be set by dependents of dbt-common (e.g. core via flag parsing)
MACRO_DEBUGGING = False
def _linecache_inject(source, write):
if write:
# this is the only reliable way to accomplish this. Obviously, it's
# really darn noisy and will fill your temporary directory
tmp_file = tempfile.NamedTemporaryFile(
prefix="dbt-macro-compiled-",
suffix=".py",
delete=False,
mode="w+",
encoding="utf-8",
)
tmp_file.write(source)
filename = tmp_file.name
else:
# `codecs.encode` actually takes a `bytes` as the first argument if
# the second argument is 'hex' - mypy does not know this.
rnd = codecs.encode(os.urandom(12), "hex") # type: ignore
filename = rnd.decode("ascii")
# put ourselves in the cache
cache_entry = (len(source), None, [line + "\n" for line in source.splitlines()], filename)
# linecache does in fact have an attribute `cache`, thanks
linecache.cache[filename] = cache_entry # type: ignore
return filename
class MacroFuzzParser(jinja2.parser.Parser):
def parse_macro(self):
node = jinja2.nodes.Macro(lineno=next(self.stream).lineno)
# modified to fuzz macros defined in the same file. this way
# dbt can understand the stack of macros being called.
# - @cmcarthur
node.name = get_dbt_macro_name(self.parse_assign_target(name_only=True).name)
self.parse_signature(node)
node.body = self.parse_statements(("name:endmacro",), drop_needle=True)
return node
class MacroFuzzEnvironment(jinja2.sandbox.SandboxedEnvironment):
def _parse(self, source, name, filename):
return MacroFuzzParser(self, source, name, filename).parse()
def _compile(self, source, filename):
"""Override jinja's compilation to stash the rendered source inside
the python linecache for debugging when the appropriate environment
variable is set.
If the value is 'write', also write the files to disk.
WARNING: This can write a ton of data if you aren't careful.
"""
if filename == "<template>" and MACRO_DEBUGGING:
write = MACRO_DEBUGGING == "write"
filename = _linecache_inject(source, write)
return super()._compile(source, filename) # type: ignore
class NativeSandboxEnvironment(MacroFuzzEnvironment):
code_generator_class = jinja2.nativetypes.NativeCodeGenerator
class TextMarker(str):
"""A special native-env marker that indicates a value is text and is
not to be evaluated. Use this to prevent your numbery-strings from becoming
numbers!
"""
class NativeMarker(str):
"""A special native-env marker that indicates the field should be passed to
literal_eval.
"""
class BoolMarker(NativeMarker):
pass
class NumberMarker(NativeMarker):
pass
def _is_number(value) -> bool:
return isinstance(value, (int, float)) and not isinstance(value, bool)
def quoted_native_concat(nodes):
"""This is almost native_concat from the NativeTemplate, except in the
special case of a single argument that is a quoted string and returns a
string, the quotes are re-inserted.
"""
head = list(islice(nodes, 2))
if not head:
return ""
if len(head) == 1:
raw = head[0]
if isinstance(raw, TextMarker):
return str(raw)
elif not isinstance(raw, NativeMarker):
# return non-strings as-is
return raw
else:
# multiple nodes become a string.
return "".join([str(v) for v in chain(head, nodes)])
try:
result = literal_eval(raw)
except (ValueError, SyntaxError, MemoryError):
result = raw
if isinstance(raw, BoolMarker) and not isinstance(result, bool):
raise JinjaRenderingError(f"Could not convert value '{raw!s}' into type 'bool'")
if isinstance(raw, NumberMarker) and not _is_number(result):
raise JinjaRenderingError(f"Could not convert value '{raw!s}' into type 'number'")
return result
class NativeSandboxTemplate(jinja2.nativetypes.NativeTemplate): # mypy: ignore
environment_class = NativeSandboxEnvironment # type: ignore
def render(self, *args, **kwargs):
"""Render the template to produce a native Python type. If the
result is a single node, its value is returned. Otherwise, the
nodes are concatenated as strings. If the result can be parsed
with :func:`ast.literal_eval`, the parsed value is returned.
Otherwise, the string is returned.
"""
vars = dict(*args, **kwargs)
try:
return quoted_native_concat(self.root_render_func(self.new_context(vars)))
except Exception:
return self.environment.handle_exception()
NativeSandboxEnvironment.template_class = NativeSandboxTemplate # type: ignore
class TemplateCache:
def __init__(self) -> None:
self.file_cache: Dict[str, jinja2.Template] = {}
def get_node_template(self, node) -> jinja2.Template:
key = node.macro_sql
if key in self.file_cache:
return self.file_cache[key]
template = get_template(
string=node.macro_sql,
ctx={},
node=node,
)
self.file_cache[key] = template
return template
def clear(self):
self.file_cache.clear()
template_cache = TemplateCache()
class BaseMacroGenerator:
def __init__(self, context: Optional[Dict[str, Any]] = None) -> None:
self.context: Optional[Dict[str, Any]] = context
def get_template(self):
raise NotImplementedError("get_template not implemented!")
def get_name(self) -> str:
raise NotImplementedError("get_name not implemented!")
def get_macro(self):
name = self.get_name()
template = self.get_template()
# make the module. previously we set both vars and local, but that's
# redundant: They both end up in the same place
# make_module is in jinja2.environment. It returns a TemplateModule
module = template.make_module(vars=self.context, shared=False)
macro = module.__dict__[get_dbt_macro_name(name)]
module.__dict__.update(self.context)
return macro
@contextmanager
def exception_handler(self) -> Iterator[None]:
try:
yield
except (TypeError, jinja2.exceptions.TemplateRuntimeError) as e:
raise CaughtMacroError(e)
def call_macro(self, *args, **kwargs):
# called from __call__ methods
if self.context is None:
raise DbtInternalError("Context is still None in call_macro!")
assert self.context is not None
macro = self.get_macro()
with self.exception_handler():
try:
return macro(*args, **kwargs)
except MacroReturn as e:
return e.value
class MacroProtocol(Protocol):
name: str
macro_sql: str
class CallableMacroGenerator(BaseMacroGenerator):
def __init__(
self,
macro: MacroProtocol,
context: Optional[Dict[str, Any]] = None,
) -> None:
super().__init__(context)
self.macro = macro
def get_template(self):
return template_cache.get_node_template(self.macro)
def get_name(self) -> str:
return self.macro.name
@contextmanager
def exception_handler(self) -> Iterator[None]:
try:
yield
except (TypeError, jinja2.exceptions.TemplateRuntimeError) as e:
raise CaughtMacroErrorWithNodeError(exc=e, node=self.macro)
except CompilationError as e:
e.stack.append(self.macro)
raise e
# this makes MacroGenerator objects callable like functions
def __call__(self, *args, **kwargs):
return self.call_macro(*args, **kwargs)
class MaterializationExtension(jinja2.ext.Extension):
tags = ["materialization"]
def parse(self, parser):
node = jinja2.nodes.Macro(lineno=next(parser.stream).lineno)
materialization_name = parser.parse_assign_target(name_only=True).name
adapter_name = "default"
node.args = []
node.defaults = []
while parser.stream.skip_if("comma"):
target = parser.parse_assign_target(name_only=True)
if target.name == "default":
pass
elif target.name == "adapter":
parser.stream.expect("assign")
value = parser.parse_expression()
adapter_name = value.value
elif target.name == "supported_languages":
target.set_ctx("param")
node.args.append(target)
parser.stream.expect("assign")
languages = parser.parse_expression()
node.defaults.append(languages)
else:
raise MaterializationArgError(materialization_name, target.name)
if SUPPORTED_LANG_ARG not in node.args:
node.args.append(SUPPORTED_LANG_ARG)
node.defaults.append(jinja2.nodes.List([jinja2.nodes.Const("sql")]))
node.name = get_materialization_macro_name(materialization_name, adapter_name)
node.body = parser.parse_statements(("name:endmaterialization",), drop_needle=True)
return node
class DocumentationExtension(jinja2.ext.Extension):
tags = ["docs"]
def parse(self, parser):
node = jinja2.nodes.Macro(lineno=next(parser.stream).lineno)
docs_name = parser.parse_assign_target(name_only=True).name
node.args = []
node.defaults = []
node.name = get_docs_macro_name(docs_name)
node.body = parser.parse_statements(("name:enddocs",), drop_needle=True)
return node
class TestExtension(jinja2.ext.Extension):
tags = ["test"]
def parse(self, parser):
node = jinja2.nodes.Macro(lineno=next(parser.stream).lineno)
test_name = parser.parse_assign_target(name_only=True).name
parser.parse_signature(node)
node.name = get_test_macro_name(test_name)
node.body = parser.parse_statements(("name:endtest",), drop_needle=True)
return node
def _is_dunder_name(name):
return name.startswith("__") and name.endswith("__")
def create_undefined(node=None):
class Undefined(jinja2.Undefined):
def __init__(self, hint=None, obj=None, name=None, exc=None):
super().__init__(hint=hint, name=name)
self.node = node
self.name = name
self.hint = hint
# jinja uses these for safety, so we have to override them.
# see https://github.com/pallets/jinja/blob/master/jinja2/sandbox.py#L332-L339 # noqa
self.unsafe_callable = False
self.alters_data = False
def __getitem__(self, name):
# Propagate the undefined value if a caller accesses this as if it
# were a dictionary
return self
def __getattr__(self, name):
if name == "name" or _is_dunder_name(name):
raise AttributeError(
"'{}' object has no attribute '{}'".format(type(self).__name__, name)
)
self.name = name
return self.__class__(hint=self.hint, name=self.name)
def __call__(self, *args, **kwargs):
return self
def __reduce__(self):
raise UndefinedCompilationError(name=self.name, node=node)
return Undefined
NATIVE_FILTERS: Dict[str, Callable[[Any], Any]] = {
"as_text": TextMarker,
"as_bool": BoolMarker,
"as_native": NativeMarker,
"as_number": NumberMarker,
}
TEXT_FILTERS: Dict[str, Callable[[Any], Any]] = {
"as_text": lambda x: x,
"as_bool": lambda x: x,
"as_native": lambda x: x,
"as_number": lambda x: x,
}
def get_environment(
node=None,
capture_macros: bool = False,
native: bool = False,
) -> jinja2.Environment:
args: Dict[str, List[Union[str, Type[jinja2.ext.Extension]]]] = {
"extensions": ["jinja2.ext.do", "jinja2.ext.loopcontrols"]
}
if capture_macros:
args["undefined"] = create_undefined(node)
args["extensions"].append(MaterializationExtension)
args["extensions"].append(DocumentationExtension)
args["extensions"].append(TestExtension)
env_cls: Type[jinja2.Environment]
text_filter: Type
if native:
env_cls = NativeSandboxEnvironment
filters = NATIVE_FILTERS
else:
env_cls = MacroFuzzEnvironment
filters = TEXT_FILTERS
env = env_cls(**args)
env.filters.update(filters)
return env
@contextmanager
def catch_jinja(node=None) -> Iterator[None]:
try:
yield
except jinja2.exceptions.TemplateSyntaxError as e:
e.translated = False
raise CompilationError(str(e), node) from e
except jinja2.exceptions.UndefinedError as e:
raise UndefinedMacroError(str(e), node) from e
except CompilationError as exc:
exc.add_node(node)
raise
def parse(string):
with catch_jinja():
return get_environment().parse(str(string))
def get_template(
string: str,
ctx: Dict[str, Any],
node=None,
capture_macros: bool = False,
native: bool = False,
):
with catch_jinja(node):
env = get_environment(node, capture_macros, native=native)
template_source = str(string)
return env.from_string(template_source, globals=ctx)
def render_template(template, ctx: Dict[str, Any], node=None) -> str:
with catch_jinja(node):
return template.render(ctx)
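# Example (hedged sketch): with native=True, rendering produces Python values,
# and the as_text filter keeps numeric-looking strings from being literal_eval'd:
#
#     >>> render_template(get_template("{{ 1 + 2 }}", ctx={}, native=True), {})
#     3
#     >>> render_template(get_template("{{ '005' | as_text }}", ctx={}, native=True), {})
#     '005'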
def extract_toplevel_blocks(
data: str,
allowed_blocks: Optional[Set[str]] = None,
collect_raw_data: bool = True,
) -> List[Union[BlockData, BlockTag]]:
"""Extract the top-level blocks with matching block types from a jinja
file, with some special handling for block nesting.
:param data: The data to extract blocks from.
:param allowed_blocks: The names of the blocks to extract from the file.
They may not be nested within if/for blocks. If None, use the default
values.
:param collect_raw_data: If set, raw data between matched blocks will also
be part of the results, as `BlockData` objects. They have a
`block_type_name` field of `'__dbt_data'` and will never have a
`block_name`.
:return: A list of `BlockTag`s matching the allowed block types and (if
`collect_raw_data` is `True`) `BlockData` objects.
"""
return BlockIterator(data).lex_for_blocks(
allowed_blocks=allowed_blocks, collect_raw_data=collect_raw_data
)
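# Example (hedged sketch, assuming BlockTag exposes block_type_name and
# block_name as imported above): extracting top-level macro blocks from raw
# jinja text:
#
#     >>> blocks = extract_toplevel_blocks(
#     ...     "{% macro my_macro() %}select 1{% endmacro %}",
#     ...     allowed_blocks={"macro"},
#     ...     collect_raw_data=False,
#     ... )
#     >>> [(b.block_type_name, b.block_name) for b in blocks]
#     [('macro', 'my_macro')]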

View File

@@ -1,571 +0,0 @@
import dbt.common.exceptions.base
import errno
import fnmatch
import functools
import json
import os
import os.path
import re
import shutil
import stat
import subprocess
import sys
import tarfile
from pathlib import Path
from typing import Any, Callable, Dict, List, NoReturn, Optional, Tuple, Type, Union
import dbt.common.exceptions
import requests
from dbt.common.events.functions import fire_event
from dbt.common.events.types import (
SystemCouldNotWrite,
SystemExecutingCmd,
SystemStdOut,
SystemStdErr,
SystemReportReturnCode,
)
from dbt.common.exceptions import DbtInternalError
from dbt.common.utils.connection import connection_exception_retry
from pathspec import PathSpec # type: ignore
if sys.platform == "win32":
from ctypes import WinDLL, c_bool
else:
WinDLL = None
c_bool = None
def find_matching(
root_path: str,
relative_paths_to_search: List[str],
file_pattern: str,
ignore_spec: Optional[PathSpec] = None,
) -> List[Dict[str, Any]]:
"""
Given an absolute `root_path`, a list of relative paths to that
absolute root path (`relative_paths_to_search`), and a `file_pattern`
like '*.sql', returns information about the files. For example:
> find_matching('/root/path', ['models'], '*.sql')
[ { 'absolute_path': '/root/path/models/model_one.sql',
'relative_path': 'model_one.sql',
'searched_path': 'models' },
{ 'absolute_path': '/root/path/models/subdirectory/model_two.sql',
'relative_path': 'subdirectory/model_two.sql',
'searched_path': 'models' } ]
"""
matching = []
root_path = os.path.normpath(root_path)
regex = fnmatch.translate(file_pattern)
reobj = re.compile(regex, re.IGNORECASE)
for relative_path_to_search in relative_paths_to_search:
# potential speedup for ignore_spec
# if ignore_spec.matches(relative_path_to_search):
# continue
absolute_path_to_search = os.path.join(root_path, relative_path_to_search)
walk_results = os.walk(absolute_path_to_search)
for current_path, subdirectories, local_files in walk_results:
# potential speedup for ignore_spec
# relative_dir = os.path.relpath(current_path, root_path) + os.sep
# if ignore_spec.match(relative_dir):
# continue
for local_file in local_files:
absolute_path = os.path.join(current_path, local_file)
relative_path = os.path.relpath(absolute_path, absolute_path_to_search)
relative_path_to_root = os.path.join(relative_path_to_search, relative_path)
modification_time = os.path.getmtime(absolute_path)
if reobj.match(local_file) and (
not ignore_spec or not ignore_spec.match_file(relative_path_to_root)
):
matching.append(
{
"searched_path": relative_path_to_search,
"absolute_path": absolute_path,
"relative_path": relative_path,
"modification_time": modification_time,
}
)
return matching
def load_file_contents(path: str, strip: bool = True) -> str:
path = convert_path(path)
with open(path, "rb") as handle:
to_return = handle.read().decode("utf-8")
if strip:
to_return = to_return.strip()
return to_return
@functools.singledispatch
def make_directory(path=None) -> None:
"""
Make a directory and any intermediate directories that don't already
exist. This function handles the case where two threads try to create
a directory at once.
"""
raise DbtInternalError(f"Can not create directory from {type(path)} ")
@make_directory.register
def _(path: str) -> None:
path = convert_path(path)
if not os.path.exists(path):
# concurrent writes that try to create the same dir can fail
try:
os.makedirs(path)
except OSError as e:
if e.errno == errno.EEXIST:
pass
else:
raise e
@make_directory.register
def _(path: Path) -> None:
path.mkdir(parents=True, exist_ok=True)
def make_file(path: str, contents: str = "", overwrite: bool = False) -> bool:
"""
Make a file at `path` assuming that the directory it resides in already
exists. The file is saved with contents `contents`
"""
if overwrite or not os.path.exists(path):
path = convert_path(path)
with open(path, "w") as fh:
fh.write(contents)
return True
return False
def make_symlink(source: str, link_path: str) -> None:
"""
Create a symlink at `link_path` referring to `source`.
"""
if not supports_symlinks():
# TODO: why not import these at top?
raise dbt.common.exceptions.SymbolicLinkError()
os.symlink(source, link_path)
def supports_symlinks() -> bool:
return getattr(os, "symlink", None) is not None
def write_file(path: str, contents: str = "") -> bool:
path = convert_path(path)
try:
make_directory(os.path.dirname(path))
with open(path, "w", encoding="utf-8") as f:
f.write(str(contents))
except Exception as exc:
# note that you can't just catch FileNotFound, because sometimes
# windows apparently raises something else.
# It's also not sufficient to look at the path length, because
# sometimes windows fails to write paths that are less than the length
# limit. So on windows, suppress all errors that happen from writing
# to disk.
if os.name == "nt":
# sometimes we get a winerror of 3 which means the path was
# definitely too long, but other times we don't and it means the
# path was just probably too long. This is probably based on the
# windows/python version.
if getattr(exc, "winerror", 0) == 3:
reason = "Path was too long"
else:
reason = "Path was possibly too long"
# all our hard work and the path was still too long. Log and
# continue.
fire_event(SystemCouldNotWrite(path=path, reason=reason, exc=str(exc)))
else:
raise
return True
def read_json(path: str) -> Dict[str, Any]:
return json.loads(load_file_contents(path))
def write_json(path: str, data: Dict[str, Any]) -> bool:
return write_file(path, json.dumps(data, cls=dbt.common.utils.encoding.JSONEncoder))
def _windows_rmdir_readonly(func: Callable[[str], Any], path: str, exc: Tuple[Any, OSError, Any]):
exception_val = exc[1]
if exception_val.errno == errno.EACCES:
os.chmod(path, stat.S_IWUSR)
func(path)
else:
raise
def resolve_path_from_base(path_to_resolve: str, base_path: str) -> str:
"""
If path_to_resolve is a relative path, create an absolute path
with base_path as the base.
If path_to_resolve is an absolute path or a user path (~), just
resolve it to an absolute path and return.
"""
return os.path.abspath(os.path.join(base_path, os.path.expanduser(path_to_resolve)))
def rmdir(path: str) -> None:
"""
Recursively deletes a directory. Includes an error handler to retry with
different permissions on Windows. Otherwise, removing directories (e.g.
cloned via git) can cause rmtree to throw a PermissionError exception
"""
path = convert_path(path)
if sys.platform == "win32":
onerror = _windows_rmdir_readonly
else:
onerror = None
shutil.rmtree(path, onerror=onerror)
def _win_prepare_path(path: str) -> str:
"""Given a windows path, prepare it for use by making sure it is absolute
and normalized.
"""
path = os.path.normpath(path)
# if a path starts with '\', splitdrive() on it will return '' for the
# drive, but the prefix requires a drive letter. So let's add the drive
# letter back in.
# Unless it starts with '\\'. In that case, the path is a UNC mount point
# and splitdrive will be fine.
if not path.startswith("\\\\") and path.startswith("\\"):
curdrive = os.path.splitdrive(os.getcwd())[0]
path = curdrive + path
# now our path is either an absolute UNC path or relative to the current
# directory. If it's relative, we need to make it absolute or the prefix
# won't work. `ntpath.abspath` allegedly doesn't always play nice with long
# paths, so do this instead.
if not os.path.splitdrive(path)[0]:
path = os.path.join(os.getcwd(), path)
return path
def _supports_long_paths() -> bool:
if sys.platform != "win32":
return True
# Eryk Sun says to use `WinDLL('ntdll')` instead of `windll.ntdll` because
# of pointer caching in a comment here:
# https://stackoverflow.com/a/35097999/11262881
# I don't know exactly what he means, but I am inclined to believe him as
# he's pretty active on Python windows bugs!
else:
try:
dll = WinDLL("ntdll")
except OSError: # I don't think this happens? you need ntdll to run python
return False
# not all windows versions have it at all
if not hasattr(dll, "RtlAreLongPathsEnabled"):
return False
# tell windows we want to get back a single unsigned byte (a bool).
dll.RtlAreLongPathsEnabled.restype = c_bool
return dll.RtlAreLongPathsEnabled()
def convert_path(path: str) -> str:
"""Convert a path that dbt has, which might be >260 characters long, to one
that will be writable/readable on Windows.
On other platforms, this is a no-op.
"""
# some parts of python seem to append '\*.*' to strings, better safe than
# sorry.
if len(path) < 250:
return path
if _supports_long_paths():
return path
prefix = "\\\\?\\"
# Nothing to do
if path.startswith(prefix):
return path
path = _win_prepare_path(path)
# add the prefix. The check is just in case os.getcwd() does something
# unexpected - I believe this if-statement should always be True though!
if not path.startswith(prefix):
path = prefix + path
return path
def remove_file(path: str) -> None:
path = convert_path(path)
os.remove(path)
def path_exists(path: str) -> bool:
path = convert_path(path)
return os.path.lexists(path)
def path_is_symlink(path: str) -> bool:
path = convert_path(path)
return os.path.islink(path)
def open_dir_cmd() -> str:
# https://docs.python.org/2/library/sys.html#sys.platform
if sys.platform == "win32":
return "start"
elif sys.platform == "darwin":
return "open"
else:
return "xdg-open"
def _handle_posix_cwd_error(exc: OSError, cwd: str, cmd: List[str]) -> NoReturn:
if exc.errno == errno.ENOENT:
message = "Directory does not exist"
elif exc.errno == errno.EACCES:
message = "Current user cannot access directory, check permissions"
elif exc.errno == errno.ENOTDIR:
message = "Not a directory"
else:
message = "Unknown OSError: {} - cwd".format(str(exc))
raise dbt.common.exceptions.WorkingDirectoryError(cwd, cmd, message)
def _handle_posix_cmd_error(exc: OSError, cwd: str, cmd: List[str]) -> NoReturn:
if exc.errno == errno.ENOENT:
message = "Could not find command, ensure it is in the user's PATH"
elif exc.errno == errno.EACCES:
message = "User does not have permissions for this command"
else:
message = "Unknown OSError: {} - cmd".format(str(exc))
raise dbt.common.exceptions.ExecutableError(cwd, cmd, message)
def _handle_posix_error(exc: OSError, cwd: str, cmd: List[str]) -> NoReturn:
"""OSError handling for POSIX systems.
Some things that could happen to trigger an OSError:
- cwd could not exist
- exc.errno == ENOENT
- exc.filename == cwd
- cwd could have permissions that prevent the current user moving to it
- exc.errno == EACCES
- exc.filename == cwd
- cwd could exist but not be a directory
- exc.errno == ENOTDIR
- exc.filename == cwd
- cmd[0] could not exist
- exc.errno == ENOENT
- exc.filename == None(?)
- cmd[0] could exist but have permissions that prevents the current
user from executing it (executable bit not set for the user)
- exc.errno == EACCES
- exc.filename == None(?)
"""
if getattr(exc, "filename", None) == cwd:
_handle_posix_cwd_error(exc, cwd, cmd)
else:
_handle_posix_cmd_error(exc, cwd, cmd)
def _handle_windows_error(exc: OSError, cwd: str, cmd: List[str]) -> NoReturn:
cls: Type[dbt.common.exceptions.DbtBaseException] = dbt.common.exceptions.base.CommandError
if exc.errno == errno.ENOENT:
message = (
"Could not find command, ensure it is in the user's PATH "
"and that the user has permissions to run it"
)
cls = dbt.common.exceptions.ExecutableError
elif exc.errno == errno.ENOEXEC:
message = "Command was not executable, ensure it is valid"
cls = dbt.common.exceptions.ExecutableError
elif exc.errno == errno.ENOTDIR:
message = (
"Unable to cd: path does not exist, user does not have"
" permissions, or not a directory"
)
cls = dbt.common.exceptions.WorkingDirectoryError
else:
message = 'Unknown error: {} (errno={}: "{}")'.format(
str(exc), exc.errno, errno.errorcode.get(exc.errno, "<Unknown!>")
)
raise cls(cwd, cmd, message)
def _interpret_oserror(exc: OSError, cwd: str, cmd: List[str]) -> NoReturn:
"""Interpret an OSError exception and raise the appropriate dbt exception."""
if len(cmd) == 0:
raise dbt.common.exceptions.base.CommandError(cwd, cmd)
# all of these functions raise unconditionally
if os.name == "nt":
_handle_windows_error(exc, cwd, cmd)
else:
_handle_posix_error(exc, cwd, cmd)
# this should not be reachable, raise _something_ at least!
raise dbt.common.exceptions.DbtInternalError(
"Unhandled exception in _interpret_oserror: {}".format(exc)
)
def run_cmd(cwd: str, cmd: List[str], env: Optional[Dict[str, Any]] = None) -> Tuple[bytes, bytes]:
fire_event(SystemExecutingCmd(cmd=cmd))
if len(cmd) == 0:
raise dbt.common.exceptions.base.CommandError(cwd, cmd)
# the env argument replaces the environment entirely, which has exciting
# consequences on Windows! Do an update instead.
full_env = env
if env is not None:
full_env = os.environ.copy()
full_env.update(env)
try:
exe_pth = shutil.which(cmd[0])
if exe_pth:
cmd = [os.path.abspath(exe_pth)] + list(cmd[1:])
proc = subprocess.Popen(
cmd, cwd=cwd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=full_env
)
out, err = proc.communicate()
except OSError as exc:
_interpret_oserror(exc, cwd, cmd)
fire_event(SystemStdOut(bmsg=str(out)))
fire_event(SystemStdErr(bmsg=str(err)))
if proc.returncode != 0:
fire_event(SystemReportReturnCode(returncode=proc.returncode))
raise dbt.common.exceptions.CommandResultError(cwd, cmd, proc.returncode, out, err)
return out, err
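# Example (hedged sketch, assumes git is installed and on PATH): run a command
# and capture its output; stdout and stderr come back as raw bytes, and a
# nonzero return code raises CommandResultError:
#
#     >>> out, err = run_cmd("/tmp", ["git", "--version"])
#     >>> out.startswith(b"git version")
#     True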
def download_with_retries(
url: str, path: str, timeout: Optional[Union[float, tuple]] = None
) -> None:
download_fn = functools.partial(download, url, path, timeout)
connection_exception_retry(download_fn, 5)
def download(
url: str,
path: str,
timeout: Optional[Union[float, Tuple[float, float], Tuple[float, None]]] = None,
) -> None:
path = convert_path(path)
connection_timeout = timeout or float(os.getenv("DBT_HTTP_TIMEOUT", 10))
response = requests.get(url, timeout=connection_timeout)
with open(path, "wb") as handle:
for block in response.iter_content(1024 * 64):
handle.write(block)
def rename(from_path: str, to_path: str, force: bool = False) -> None:
from_path = convert_path(from_path)
to_path = convert_path(to_path)
is_symlink = path_is_symlink(to_path)
if os.path.exists(to_path) and force:
if is_symlink:
remove_file(to_path)
else:
rmdir(to_path)
shutil.move(from_path, to_path)
def untar_package(tar_path: str, dest_dir: str, rename_to: Optional[str] = None) -> None:
tar_path = convert_path(tar_path)
tar_dir_name = None
with tarfile.open(tar_path, "r:gz") as tarball:
tarball.extractall(dest_dir)
tar_dir_name = os.path.commonprefix(tarball.getnames())
if rename_to:
downloaded_path = os.path.join(dest_dir, tar_dir_name)
desired_path = os.path.join(dest_dir, rename_to)
dbt.common.clients.system.rename(downloaded_path, desired_path, force=True)
def chmod_and_retry(func, path, exc_info):
"""Define an error handler to pass to shutil.rmtree.
On Windows, when a file is marked read-only as git likes to do, rmtree will
fail. To handle that, on errors try to make the file writable.
We want to retry most operations here, but listdir is one that we know will
be useless.
"""
if func is os.listdir or os.name != "nt":
raise
os.chmod(path, stat.S_IREAD | stat.S_IWRITE)
# on error, this will raise.
func(path)
def _absnorm(path):
return os.path.normcase(os.path.abspath(path))
def move(src, dst):
"""A re-implementation of shutil.move that properly removes the source
directory on windows when it has read-only files in it and the move is
between two drives.
This is almost identical to the real shutil.move, except that it uses our rmtree
and skips handling non-windows OSes since the existing one works ok there.
"""
src = convert_path(src)
dst = convert_path(dst)
if os.name != "nt":
return shutil.move(src, dst)
if os.path.isdir(dst):
if _absnorm(src) == _absnorm(dst):
os.rename(src, dst)
return
dst = os.path.join(dst, os.path.basename(src.rstrip("/\\")))
if os.path.exists(dst):
raise EnvironmentError("Path '{}' already exists".format(dst))
try:
os.rename(src, dst)
except OSError:
# probably different drives
if os.path.isdir(src):
if _absnorm(dst + "\\").startswith(_absnorm(src + "\\")):
# dst is inside src
raise EnvironmentError(
"Cannot move a directory '{}' into itself '{}'".format(src, dst)
)
shutil.copytree(src, dst, symlinks=True)
rmtree(src)
else:
shutil.copy2(src, dst)
os.unlink(src)
def rmtree(path):
"""Recursively remove the path. On permissions errors on windows, try to remove
the read-only flag and try again.
"""
path = convert_path(path)
return shutil.rmtree(path, onerror=chmod_and_retry)

View File

@@ -1 +0,0 @@
SECRET_ENV_PREFIX = "DBT_ENV_SECRET_"

View File

@@ -1,259 +0,0 @@
# necessary for annotating constructors
from __future__ import annotations
from dataclasses import dataclass, Field
from itertools import chain
from typing import Callable, Dict, Any, List, TypeVar, Type
from dbt.common.contracts.config.metadata import Metadata
from dbt.common.exceptions import CompilationError, DbtInternalError
from dbt.common.contracts.config.properties import AdditionalPropertiesAllowed
from dbt.common.contracts.util import Replaceable
T = TypeVar("T", bound="BaseConfig")
@dataclass
class BaseConfig(AdditionalPropertiesAllowed, Replaceable):
# enable syntax like: config['key']
def __getitem__(self, key):
return self.get(key)
# like doing 'get' on a dictionary
def get(self, key, default=None):
if hasattr(self, key):
return getattr(self, key)
elif key in self._extra:
return self._extra[key]
else:
return default
# enable syntax like: config['key'] = value
def __setitem__(self, key, value):
if hasattr(self, key):
setattr(self, key, value)
else:
self._extra[key] = value
def __delitem__(self, key):
if hasattr(self, key):
msg = (
'Error, tried to delete config key "{}": Cannot delete ' "built-in keys"
).format(key)
raise CompilationError(msg)
else:
del self._extra[key]
def _content_iterator(self, include_condition: Callable[[Field], bool]):
seen = set()
for fld, _ in self._get_fields():
seen.add(fld.name)
if include_condition(fld):
yield fld.name
for key in self._extra:
if key not in seen:
seen.add(key)
yield key
def __iter__(self):
yield from self._content_iterator(include_condition=lambda f: True)
def __len__(self):
return len(self._get_fields()) + len(self._extra)
@staticmethod
def compare_key(
unrendered: Dict[str, Any],
other: Dict[str, Any],
key: str,
) -> bool:
if key not in unrendered and key not in other:
return True
elif key not in unrendered and key in other:
return False
elif key in unrendered and key not in other:
return False
else:
return unrendered[key] == other[key]
@classmethod
def same_contents(cls, unrendered: Dict[str, Any], other: Dict[str, Any]) -> bool:
"""This is like __eq__, except it ignores some fields."""
seen = set()
for fld, target_name in cls._get_fields():
key = target_name
seen.add(key)
if CompareBehavior.should_include(fld):
if not cls.compare_key(unrendered, other, key):
return False
for key in chain(unrendered, other):
if key not in seen:
seen.add(key)
if not cls.compare_key(unrendered, other, key):
return False
return True
# This is used in 'add_config_call' to create the combined config_call_dict.
# 'meta' moved here from node
mergebehavior = {
"append": ["pre-hook", "pre_hook", "post-hook", "post_hook", "tags"],
"update": [
"quoting",
"column_types",
"meta",
"docs",
"contract",
],
"dict_key_append": ["grants"],
}
@classmethod
def _merge_dicts(cls, src: Dict[str, Any], data: Dict[str, Any]) -> Dict[str, Any]:
"""Find all the items in data that match a target_field on this class,
and merge them with the data found in `src` for target_field, using the
field's specified merge behavior. Matching items will be removed from
`data` (but _not_ `src`!).
That means this method mutates its input! Any remaining values in data
were not merged.
Returns a dict with the merge results.
"""
result = {}
for fld, target_field in cls._get_fields():
if target_field not in data:
continue
data_attr = data.pop(target_field)
if target_field not in src:
result[target_field] = data_attr
continue
merge_behavior = MergeBehavior.from_field(fld)
self_attr = src[target_field]
result[target_field] = _merge_field_value(
merge_behavior=merge_behavior,
self_value=self_attr,
other_value=data_attr,
)
return result
def update_from(
self: T, data: Dict[str, Any], config_cls: Type[BaseConfig], validate: bool = True
) -> T:
"""Given a dict of keys, update the current config from them, validate
it, and return a new config with the updated values
"""
dct = self.to_dict(omit_none=False)
self_merged = self._merge_dicts(dct, data)
dct.update(self_merged)
adapter_merged = config_cls._merge_dicts(dct, data)
dct.update(adapter_merged)
# any remaining fields must be "clobber"
dct.update(data)
# any validation failures must have come from the update
if validate:
self.validate(dct)
return self.from_dict(dct)
def finalize_and_validate(self: T) -> T:
dct = self.to_dict(omit_none=False)
self.validate(dct)
return self.from_dict(dct)
class MergeBehavior(Metadata):
Append = 1
Update = 2
Clobber = 3
DictKeyAppend = 4
@classmethod
def default_field(cls) -> "MergeBehavior":
return cls.Clobber
@classmethod
def metadata_key(cls) -> str:
return "merge"
class CompareBehavior(Metadata):
Include = 1
Exclude = 2
@classmethod
def default_field(cls) -> "CompareBehavior":
return cls.Include
@classmethod
def metadata_key(cls) -> str:
return "compare"
@classmethod
def should_include(cls, fld: Field) -> bool:
return cls.from_field(fld) == cls.Include
def _listify(value: Any) -> List:
if isinstance(value, list):
return value[:]
else:
return [value]
# There are two versions of this code. The one here is for config
# objects, the one in _add_config_call in core context_config.py is for
# config_call_dict dictionaries.
def _merge_field_value(
merge_behavior: MergeBehavior,
self_value: Any,
other_value: Any,
):
if merge_behavior == MergeBehavior.Clobber:
return other_value
elif merge_behavior == MergeBehavior.Append:
return _listify(self_value) + _listify(other_value)
elif merge_behavior == MergeBehavior.Update:
if not isinstance(self_value, dict):
raise DbtInternalError(f"expected dict, got {self_value}")
if not isinstance(other_value, dict):
raise DbtInternalError(f"expected dict, got {other_value}")
value = self_value.copy()
value.update(other_value)
return value
elif merge_behavior == MergeBehavior.DictKeyAppend:
if not isinstance(self_value, dict):
raise DbtInternalError(f"expected dict, got {self_value}")
if not isinstance(other_value, dict):
raise DbtInternalError(f"expected dict, got {other_value}")
new_dict = {}
for key in self_value.keys():
new_dict[key] = _listify(self_value[key])
for key in other_value.keys():
extend = False
new_key = key
# This might start with a +, to indicate we should extend the list
# instead of just clobbering it
if new_key.startswith("+"):
new_key = key.lstrip("+")
extend = True
if new_key in new_dict and extend:
# extend the list
value = other_value[key]
new_dict[new_key].extend(_listify(value))
else:
# clobber the list
new_dict[new_key] = _listify(other_value[key])
return new_dict
else:
raise DbtInternalError(f"Got an invalid merge_behavior: {merge_behavior}")
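# Example (hedged sketch): DictKeyAppend merges dict values per key, and a
# leading "+" on the incoming key extends the existing list instead of
# clobbering it:
#
#     >>> _merge_field_value(
#     ...     MergeBehavior.DictKeyAppend,
#     ...     {"select": ["role_a"]},
#     ...     {"+select": ["role_b"]},
#     ... )
#     {'select': ['role_a', 'role_b']}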

View File

@@ -1,11 +0,0 @@
from dbt.common.dataclass_schema import StrEnum
class OnConfigurationChangeOption(StrEnum):
Apply = "apply"
Continue = "continue"
Fail = "fail"
@classmethod
def default(cls) -> "OnConfigurationChangeOption":
return cls.Apply

View File

@@ -1,69 +0,0 @@
from dataclasses import Field
from enum import Enum
from typing import TypeVar, Type, Optional, Dict, Any
from dbt.common.exceptions import DbtInternalError
M = TypeVar("M", bound="Metadata")
class Metadata(Enum):
@classmethod
def from_field(cls: Type[M], fld: Field) -> M:
default = cls.default_field()
key = cls.metadata_key()
return _get_meta_value(cls, fld, key, default)
def meta(self, existing: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
key = self.metadata_key()
return _set_meta_value(self, key, existing)
@classmethod
def default_field(cls) -> "Metadata":
raise NotImplementedError("Not implemented")
@classmethod
def metadata_key(cls) -> str:
raise NotImplementedError("Not implemented")
def _get_meta_value(cls: Type[M], fld: Field, key: str, default: Any) -> M:
# a metadata field might exist. If it does, it might have a matching key.
# If it has both, make sure the value is valid and return it. If it
# doesn't, return the default.
if fld.metadata:
value = fld.metadata.get(key, default)
else:
value = default
try:
return cls(value)
except ValueError as exc:
raise DbtInternalError(f"Invalid {cls} value: {value}") from exc
def _set_meta_value(obj: M, key: str, existing: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
if existing is None:
result = {}
else:
result = existing.copy()
result.update({key: obj})
return result
class ShowBehavior(Metadata):
Show = 1
Hide = 2
@classmethod
def default_field(cls) -> "ShowBehavior":
return cls.Show
@classmethod
def metadata_key(cls) -> str:
return "show_hide"
@classmethod
def should_show(cls, fld: Field) -> bool:
return cls.from_field(fld) == cls.Show
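# Example (hedged sketch, assumes dataclass, field, and fields are imported
# from dataclasses): attaching Metadata to a dataclass field and reading it
# back:
#
#     >>> @dataclass
#     ... class Cfg:
#     ...     internal: str = field(default="", metadata=ShowBehavior.Hide.meta())
#     >>> ShowBehavior.should_show(fields(Cfg)[0])
#     False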

View File

@@ -1,63 +0,0 @@
from dataclasses import dataclass, field
from typing import Dict, Any
from dbt.common.dataclass_schema import ExtensibleDbtClassMixin
class AdditionalPropertiesMixin:
"""Make this class an extensible property.
The underlying class definition must include a type definition for a field
named '_extra' that is of type `Dict[str, Any]`.
"""
ADDITIONAL_PROPERTIES = True
# This takes attributes in the dictionary that are
# not in the class definitions and puts them in an
# _extra dict in the class
@classmethod
def __pre_deserialize__(cls, data):
# dir() did not work because fields with
# metadata settings are not found
# The original version of this would create the
# object first and then update extra with the
# extra keys, but that won't work here, so
# we're copying the dict so we don't insert the
# _extra in the original data. This also requires
# that Mashumaro actually build the '_extra' field
cls_keys = cls._get_field_names()
new_dict = {}
for key, value in data.items():
# The pre-hook/post-hook mess hasn't been converted yet... That happens in
# the super().__pre_deserialize__ below...
if key not in cls_keys and key not in ["_extra", "pre-hook", "post-hook"]:
if "_extra" not in new_dict:
new_dict["_extra"] = {}
new_dict["_extra"][key] = value
else:
new_dict[key] = value
data = new_dict
data = super().__pre_deserialize__(data)
return data
def __post_serialize__(self, dct):
data = super().__post_serialize__(dct)
data.update(self.extra)
if "_extra" in data:
del data["_extra"]
return data
def replace(self, **kwargs):
dct = self.to_dict(omit_none=False)
dct.update(kwargs)
return self.from_dict(dct)
@property
def extra(self):
return self._extra
@dataclass
class AdditionalPropertiesAllowed(AdditionalPropertiesMixin, ExtensibleDbtClassMixin):
_extra: Dict[str, Any] = field(default_factory=dict)
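# Example (hedged sketch): unknown keys survive a from_dict/to_dict round trip
# by landing in the "_extra" dict:
#
#     >>> @dataclass
#     ... class MyProps(AdditionalPropertiesAllowed):
#     ...     enabled: bool = True
#     >>> props = MyProps.from_dict({"enabled": False, "custom_key": "x"})
#     >>> props.extra
#     {'custom_key': 'x'}
#     >>> props.to_dict()["custom_key"]
#     'x'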

View File

@@ -1,43 +0,0 @@
from dataclasses import dataclass, field
from enum import Enum
from typing import Optional, List
from dbt.common.dataclass_schema import dbtClassMixin
class ConstraintType(str, Enum):
check = "check"
not_null = "not_null"
unique = "unique"
primary_key = "primary_key"
foreign_key = "foreign_key"
custom = "custom"
@classmethod
def is_valid(cls, item) -> bool:
try:
cls(item)
except ValueError:
return False
return True
@dataclass
class ColumnLevelConstraint(dbtClassMixin):
type: ConstraintType
name: Optional[str] = None
# expression is a user-provided field that will depend on the constraint type.
# It could be a predicate (check type), or a sequence of sql keywords (e.g. unique type),
# so the vague naming of 'expression' is intended to capture this range.
expression: Optional[str] = None
warn_unenforced: bool = (
True # Warn if constraint cannot be enforced by platform but will be in DDL
)
warn_unsupported: bool = (
True # Warn if constraint is not supported by the platform and won't be in DDL
)
@dataclass
class ModelLevelConstraint(ColumnLevelConstraint):
columns: List[str] = field(default_factory=list)
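# Example (hedged sketch): constraints are typically parsed from user-supplied
# dicts, and unknown constraint types can be screened with ConstraintType:
#
#     >>> c = ColumnLevelConstraint.from_dict(
#     ...     {"type": "check", "name": "positive_id", "expression": "id > 0"}
#     ... )
#     >>> c.type, c.expression
#     (<ConstraintType.check: 'check'>, 'id > 0')
#     >>> ConstraintType.is_valid("nonsense")
#     False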

View File

@@ -1,7 +0,0 @@
import dataclasses
# TODO: remove from dbt.contracts.util:: Replaceable + references
class Replaceable:
def replace(self, **kwargs):
return dataclasses.replace(self, **kwargs)

View File

@@ -1,165 +0,0 @@
from typing import ClassVar, cast, get_type_hints, List, Tuple, Dict, Any, Optional
import re
import jsonschema
from dataclasses import fields, Field
from enum import Enum
from datetime import datetime
from dateutil.parser import parse
# type: ignore
from mashumaro import DataClassDictMixin
from mashumaro.config import TO_DICT_ADD_OMIT_NONE_FLAG, BaseConfig as MashBaseConfig
from mashumaro.types import SerializableType, SerializationStrategy
from mashumaro.jsonschema import build_json_schema
import functools
class ValidationError(jsonschema.ValidationError):
pass
class DateTimeSerialization(SerializationStrategy):
def serialize(self, value) -> str:
out = value.isoformat()
# Assume UTC if timezone is missing
if value.tzinfo is None:
out += "Z"
return out
def deserialize(self, value) -> datetime:
return value if isinstance(value, datetime) else parse(cast(str, value))
class dbtMashConfig(MashBaseConfig):
code_generation_options = [
TO_DICT_ADD_OMIT_NONE_FLAG,
]
serialization_strategy = {
datetime: DateTimeSerialization(),
}
json_schema = {
"additionalProperties": False,
}
serialize_by_alias = True
# This class pulls in DataClassDictMixin from Mashumaro. The 'to_dict'
# and 'from_dict' methods come from Mashumaro.
class dbtClassMixin(DataClassDictMixin):
"""The Mixin adds methods to generate a JSON schema and
convert to and from JSON encodable dicts with validation
against the schema
"""
_mapped_fields: ClassVar[Optional[Dict[Any, List[Tuple[Field, str]]]]] = None
# Config class used by Mashumaro
class Config(dbtMashConfig):
pass
ADDITIONAL_PROPERTIES: ClassVar[bool] = False
# This is called by the mashumaro from_dict in order to handle
# nested classes. We no longer do any munging here, but leaving here
# so that subclasses can leave super() in place for possible future needs.
@classmethod
def __pre_deserialize__(cls, data):
return data
# This is called by the mashumaro to_dict in order to handle
# nested classes. We no longer do any munging here, but leaving here
# so that subclasses can leave super() in place for possible future needs.
def __post_serialize__(self, data):
return data
@classmethod
@functools.lru_cache
def json_schema(cls):
json_schema_obj = build_json_schema(cls)
json_schema = json_schema_obj.to_dict()
return json_schema
@classmethod
def validate(cls, data):
json_schema = cls.json_schema()
validator = jsonschema.Draft7Validator(json_schema)
error = next(iter(validator.iter_errors(data)), None)
if error is not None:
raise ValidationError.create_from(error) from error
# This method was copied from hologram. Used in model_config.py and relation.py
@classmethod
def _get_fields(cls) -> List[Tuple[Field, str]]:
if cls._mapped_fields is None:
cls._mapped_fields = {}
if cls.__name__ not in cls._mapped_fields:
mapped_fields = []
type_hints = get_type_hints(cls)
for f in fields(cls): # type: ignore
# Skip internal fields
if f.name.startswith("_"):
continue
# Note fields() doesn't resolve forward refs
f.type = type_hints[f.name]
# hologram used the "field_mapping" here, but we use the
# the field's metadata "alias". Since this method is mainly
# just used in merging config dicts, it mostly applies to
# pre-hook and post-hook.
field_name = f.metadata.get("alias", f.name)
mapped_fields.append((f, field_name))
cls._mapped_fields[cls.__name__] = mapped_fields
return cls._mapped_fields[cls.__name__]
# copied from hologram. Used in tests
@classmethod
def _get_field_names(cls):
return [element[1] for element in cls._get_fields()]
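# Example (hedged sketch, assuming dataclasses.dataclass is imported): any
# dataclass mixing in dbtClassMixin gets schema-validated (de)serialization:
#
#     >>> @dataclass
#     ... class Point(dbtClassMixin):
#     ...     x: int
#     ...     y: int
#     >>> Point.from_dict({"x": 1, "y": 2}).to_dict()
#     {'x': 1, 'y': 2}
#     >>> Point.validate({"x": 1})  # missing "y" -> raises ValidationError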
class ValidatedStringMixin(str, SerializableType):
ValidationRegex = ""
@classmethod
def _deserialize(cls, value: str) -> "ValidatedStringMixin":
cls.validate(value)
return ValidatedStringMixin(value)
def _serialize(self) -> str:
return str(self)
@classmethod
def validate(cls, value):
res = re.match(cls.ValidationRegex, value)
if res is None:
raise ValidationError(f"Invalid value: {value}") # TODO
# These classes must be in this order or it doesn't work
class StrEnum(str, SerializableType, Enum):
def __str__(self):
return self.value
# https://docs.python.org/3.6/library/enum.html#using-automatic-values
def _generate_next_value_(name, *_):
return name
def _serialize(self) -> str:
return self.value
@classmethod
def _deserialize(cls, value: str):
return cls(value)
class ExtensibleDbtClassMixin(dbtClassMixin):
ADDITIONAL_PROPERTIES = True
class Config(dbtMashConfig):
json_schema = {
"additionalProperties": True,
}

View File

@@ -1,41 +0,0 @@
# Events Module
The Events module is responsible for communicating internal dbt structures into a consumable interface. Because the "event" classes are based entirely on protobuf definitions, the interface is clearly defined whether or not protobufs are used to consume it. We use google protobuf for compiling the protobuf message definitions into Python classes.
# Using the Events Module
The event module provides types that represent what is happening in dbt in `events.types`. These types are intended to represent an exhaustive list of all things happening within dbt that will need to be logged, streamed, or printed. To fire an event, `events.functions::fire_event` is the entry point to the module from everywhere in dbt.
# Logging
When events are processed via `fire_event`, nearly everything is logged. Whether or not the user has enabled the debug flag, all debug messages are still logged to the file. However, some events are particularly time-consuming to construct because they return a huge amount of data. Today, the only messages in this category are cache events, and they are only logged if the `--log-cache-events` flag is on. This is important because these messages should not be constructed unless they are actually going to be logged, since building them causes a noticeable performance degradation. These events use the `fire_event_if` function.
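A hedged sketch of both entry points (`Formatting` stands in for any concrete event type, and `log_cache_events` for a boolean flag; both are assumptions for illustration, not a prescribed API):
```
from dbt.common.events.functions import fire_event, fire_event_if
from dbt.common.events.types import Formatting

fire_event(Formatting(msg="plain message"))

# the lambda defers construction until we know the event will be logged
fire_event_if(log_cache_events, lambda: Formatting(msg="expensive cache dump"))
```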
# Adding a New Event
* Add a new message in types.proto, and a second message with the same name + "Msg". The "Msg" message should have two fields: an "info" field of EventInfo, and a "data" field referring to the message name without "Msg"
* Run the protoc compiler to update types_pb2.py: `make proto_types`
* Add a wrapping class in core/dbt/events/types.py with a Level superclass plus `code` and `message` methods
* Add the class to tests/unit/test_events.py
We have switched from using betterproto to using google protobuf, because of a lack of support for Struct fields in betterproto.
The google protobuf interface is janky and very much non-Pythonic. The "generated" classes in types_pb2.py do not resemble regular Python classes. They do not have normal constructors; they can only be constructed empty. They can be "filled" by setting fields individually or using a json_format method like ParseDict. We have wrapped the logging events with a class (in types.py) which allows using a constructor -- keywords only, no positional parameters.
## Required for Every Event
- a `code` method that returns an identifier unique across events
- a log level, assigned by using a Level mixin: `DebugLevel`, `InfoLevel`, `WarnLevel`, or `ErrorLevel`
- a `message()` method
Example
```
class PartialParsingDeletedExposure(DebugLevel):
def code(self):
return "I049"
def message(self) -> str:
return f"Partial parsing: deleted exposure {self.unique_id}"
```
## Compiling types.proto
After adding a new message in `types.proto`, either:
- In the repository root directory: `make proto_types`
- In the `core/dbt/common/events` directory: `protoc -I=. --python_out=. types.proto`

View File

@@ -1,9 +0,0 @@
from dbt.common.events.base_types import EventLevel
from dbt.common.events.event_manager_client import get_event_manager
from dbt.common.events.functions import get_stdout_config
from dbt.common.events.logger import LineFormat
# make sure event manager starts with a logger
get_event_manager().add_logger(
get_stdout_config(LineFormat.PlainText, True, EventLevel.INFO, False)
)

View File

@@ -1,185 +0,0 @@
from enum import Enum
import os
import threading
from dbt.common.events import types_pb2
import sys
from google.protobuf.json_format import ParseDict, MessageToDict, MessageToJson
from google.protobuf.message import Message
from dbt.common.events.helpers import get_json_string_utcnow
from typing import Optional
from dbt.common.invocation import get_invocation_id
if sys.version_info >= (3, 8):
from typing import Protocol
else:
from typing_extensions import Protocol
# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
# These base types define the _required structure_ for the concrete event #
# types defined in types.py #
# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
def get_global_metadata_vars() -> dict:
from dbt.common.events.functions import get_metadata_vars
return get_metadata_vars()
# exactly one pid per concrete event
def get_pid() -> int:
return os.getpid()
# in theory threads can change, so we don't cache them.
def get_thread_name() -> str:
return threading.current_thread().name
# EventLevel is an Enum, but mixing in the 'str' type is suggested in the Python
# documentation, and provides support for json conversion, which fails otherwise.
class EventLevel(str, Enum):
DEBUG = "debug"
TEST = "test"
INFO = "info"
WARN = "warn"
ERROR = "error"
class BaseEvent:
"""BaseEvent for proto message generated python events"""
PROTO_TYPES_MODULE = types_pb2
def __init__(self, *args, **kwargs) -> None:
class_name = type(self).__name__
msg_cls = getattr(self.PROTO_TYPES_MODULE, class_name)
if class_name == "Formatting" and len(args) > 0:
kwargs["msg"] = args[0]
args = ()
assert (
len(args) == 0
), f"[{class_name}] Don't use positional arguments when constructing logging events"
if "base_msg" in kwargs:
kwargs["base_msg"] = str(kwargs["base_msg"])
if "msg" in kwargs:
kwargs["msg"] = str(kwargs["msg"])
try:
self.pb_msg = ParseDict(kwargs, msg_cls())
except Exception:
# Imports need to be here to avoid circular imports
from dbt.common.events.types import Note
from dbt.common.events.functions import fire_event
error_msg = f"[{class_name}]: Unable to parse dict {kwargs}"
# If we're testing throw an error so that we notice failures
if "pytest" in sys.modules:
raise Exception(error_msg)
else:
fire_event(Note(msg=error_msg), level=EventLevel.WARN)
self.pb_msg = msg_cls()
def __setattr__(self, key, value):
if key == "pb_msg":
super().__setattr__(key, value)
else:
super().__getattribute__("pb_msg").__setattr__(key, value)
def __getattr__(self, key):
if key == "pb_msg":
return super().__getattribute__(key)
else:
return super().__getattribute__("pb_msg").__getattribute__(key)
def to_dict(self):
return MessageToDict(
self.pb_msg, preserving_proto_field_name=True, including_default_value_fields=True
)
def to_json(self) -> str:
return MessageToJson(
self.pb_msg,
preserving_proto_field_name=True,
including_default_value_fields=True,
indent=None,
)
def level_tag(self) -> EventLevel:
return EventLevel.DEBUG
def message(self) -> str:
raise Exception("message() not implemented for event")
def code(self) -> str:
raise Exception("code() not implemented for event")
class EventInfo(Protocol):
level: str
name: str
ts: str
code: str
class EventMsg(Protocol):
info: EventInfo
data: Message
def msg_from_base_event(event: BaseEvent, level: Optional[EventLevel] = None):
msg_class_name = f"{type(event).__name__}Msg"
msg_cls = getattr(event.PROTO_TYPES_MODULE, msg_class_name)
# level in EventInfo must be a string, not an EventLevel
msg_level: str = level.value if level else event.level_tag().value
assert msg_level is not None
event_info = {
"level": msg_level,
"msg": event.message(),
"invocation_id": get_invocation_id(),
"extra": get_global_metadata_vars(),
"ts": get_json_string_utcnow(),
"pid": get_pid(),
"thread": get_thread_name(),
"code": event.code(),
"name": type(event).__name__,
}
new_event = ParseDict({"info": event_info}, msg_cls())
new_event.data.CopyFrom(event.pb_msg)
return new_event
# DynamicLevel requires that the level be supplied on the
# event construction call using the "info" function from functions.py
class DynamicLevel(BaseEvent):
pass
class TestLevel(BaseEvent):
__test__ = False
def level_tag(self) -> EventLevel:
return EventLevel.TEST
class DebugLevel(BaseEvent):
def level_tag(self) -> EventLevel:
return EventLevel.DEBUG
class InfoLevel(BaseEvent):
def level_tag(self) -> EventLevel:
return EventLevel.INFO
class WarnLevel(BaseEvent):
def level_tag(self) -> EventLevel:
return EventLevel.WARN
class ErrorLevel(BaseEvent):
def level_tag(self) -> EventLevel:
return EventLevel.ERROR

View File

@@ -1,114 +0,0 @@
import contextlib
import contextvars
from typing import Any, Generator, Mapping, Dict
LOG_PREFIX = "log_"
TASK_PREFIX = "task_"
_context_vars: Dict[str, contextvars.ContextVar] = {}
def get_contextvars(prefix: str) -> Dict[str, Any]:
rv = {}
ctx = contextvars.copy_context()
prefix_len = len(prefix)
for k in ctx:
if k.name.startswith(prefix) and ctx[k] is not Ellipsis:
rv[k.name[prefix_len:]] = ctx[k]
return rv
def get_node_info():
cvars = get_contextvars(LOG_PREFIX)
if "node_info" in cvars:
return cvars["node_info"]
else:
return {}
def get_project_root():
cvars = get_contextvars(TASK_PREFIX)
if "project_root" in cvars:
return cvars["project_root"]
else:
return None
def clear_contextvars(prefix: str) -> None:
ctx = contextvars.copy_context()
for k in ctx:
if k.name.startswith(prefix):
k.set(Ellipsis)
def set_log_contextvars(**kwargs: Any) -> Mapping[str, contextvars.Token]:
return set_contextvars(LOG_PREFIX, **kwargs)
def set_task_contextvars(**kwargs: Any) -> Mapping[str, contextvars.Token]:
return set_contextvars(TASK_PREFIX, **kwargs)
# put keys and values into context. Returns the contextvar.Token mapping
# Save and pass to reset_contextvars
def set_contextvars(prefix: str, **kwargs: Any) -> Mapping[str, contextvars.Token]:
cvar_tokens = {}
for k, v in kwargs.items():
log_key = f"{prefix}{k}"
try:
var = _context_vars[log_key]
except KeyError:
var = contextvars.ContextVar(log_key, default=Ellipsis)
_context_vars[log_key] = var
cvar_tokens[k] = var.set(v)
return cvar_tokens
# reset by Tokens
def reset_contextvars(prefix: str, **kwargs: contextvars.Token) -> None:
for k, v in kwargs.items():
log_key = f"{prefix}{k}"
var = _context_vars[log_key]
var.reset(v)
# remove from contextvars
def unset_contextvars(prefix: str, *keys: str) -> None:
for k in keys:
if k in _context_vars:
log_key = f"{prefix}{k}"
_context_vars[log_key].set(Ellipsis)
# Context manager or decorator to set and unset the context vars
@contextlib.contextmanager
def log_contextvars(**kwargs: Any) -> Generator[None, None, None]:
context = get_contextvars(LOG_PREFIX)
saved = {k: context[k] for k in context.keys() & kwargs.keys()}
set_contextvars(LOG_PREFIX, **kwargs)
try:
yield
finally:
unset_contextvars(LOG_PREFIX, *kwargs.keys())
set_contextvars(LOG_PREFIX, **saved)
# Context manager for earlier in task.run
@contextlib.contextmanager
def task_contextvars(**kwargs: Any) -> Generator[None, None, None]:
context = get_contextvars(TASK_PREFIX)
saved = {k: context[k] for k in context.keys() & kwargs.keys()}
set_contextvars(TASK_PREFIX, **kwargs)
try:
yield
finally:
unset_contextvars(TASK_PREFIX, *kwargs.keys())
set_contextvars(TASK_PREFIX, **saved)
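
The log_contextvars context manager is the usual entry point: values set inside the block are visible through get_node_info(), and the prior state is restored on exit. A short sketch, assuming this module's import path is dbt.common.events.contextvars:

from dbt.common.events.contextvars import get_node_info, log_contextvars

assert get_node_info() == {}
with log_contextvars(node_info={"node_name": "my_model"}):
    # inside the block, every log call on this context sees the node info
    assert get_node_info() == {"node_name": "my_model"}
# on exit the variable is unset and the previous value restored
assert get_node_info() == {}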

View File

@@ -1,40 +0,0 @@
import logging
from typing import Union
from dbt.common.events.base_types import EventLevel
from dbt.common.events.types import Note
from dbt.common.events.event_manager import IEventManager
_log_level_to_event_level_map = {
logging.DEBUG: EventLevel.DEBUG,
logging.INFO: EventLevel.INFO,
logging.WARN: EventLevel.WARN,
logging.WARNING: EventLevel.WARN,
logging.ERROR: EventLevel.ERROR,
logging.CRITICAL: EventLevel.ERROR,
}
class DbtEventLoggingHandler(logging.Handler):
"""A logging handler that wraps the EventManager
This allows non-dbt packages to log to the dbt event stream.
All logs are generated as "Note" events.
"""
def __init__(self, event_manager: IEventManager, level):
super().__init__(level)
self.event_manager = event_manager
def emit(self, record: logging.LogRecord):
note = Note(msg=record.getMessage())
level = _log_level_to_event_level_map[record.levelno]
self.event_manager.fire_event(e=note, level=level)
def set_package_logging(package_name: str, log_level: Union[str, int], event_mgr: IEventManager):
"""Attach dbt's custom logging handler to the package's logger."""
log = logging.getLogger(package_name)
log.setLevel(log_level)
event_handler = DbtEventLoggingHandler(event_manager=event_mgr, level=log_level)
log.addHandler(event_handler)
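
A hedged sketch of wiring a third-party library's stdlib logging into the dbt event stream; this module's import path is not shown in the diff, so assume set_package_logging from above is in scope:

import logging
from dbt.common.events.event_manager import EventManager

mgr = EventManager()
# every urllib3 log record at INFO or above now arrives as a Note event
set_package_logging("urllib3", logging.INFO, mgr)
logging.getLogger("urllib3").info("pool started")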

View File

@@ -1,66 +0,0 @@
import os
import traceback
from typing import Callable, List, Optional, Protocol, Tuple
from dbt.common.events.base_types import BaseEvent, EventLevel, msg_from_base_event, EventMsg
from dbt.common.events.logger import LoggerConfig, _Logger, _TextLogger, _JsonLogger, LineFormat
class EventManager:
def __init__(self) -> None:
self.loggers: List[_Logger] = []
self.callbacks: List[Callable[[EventMsg], None]] = []
def fire_event(self, e: BaseEvent, level: Optional[EventLevel] = None) -> None:
msg = msg_from_base_event(e, level=level)
if os.environ.get("DBT_TEST_BINARY_SERIALIZATION"):
print(f"--- {msg.info.name}")
try:
msg.SerializeToString()
except Exception as exc:
raise Exception(
f"{msg.info.name} is not serializable to binary. Originating exception: {exc}, {traceback.format_exc()}"
)
for logger in self.loggers:
if logger.filter(msg): # type: ignore
logger.write_line(msg)
for callback in self.callbacks:
callback(msg)
def add_logger(self, config: LoggerConfig) -> None:
logger = (
_JsonLogger(config) if config.line_format == LineFormat.Json else _TextLogger(config)
)
self.loggers.append(logger)
def flush(self) -> None:
for logger in self.loggers:
logger.flush()
class IEventManager(Protocol):
callbacks: List[Callable[[EventMsg], None]]
loggers: List[_Logger]
def fire_event(self, e: BaseEvent, level: Optional[EventLevel] = None) -> None:
...
def add_logger(self, config: LoggerConfig) -> None:
...
class TestEventManager(IEventManager):
__test__ = False
def __init__(self) -> None:
self.event_history: List[Tuple[BaseEvent, Optional[EventLevel]]] = []
self.loggers = []
def fire_event(self, e: BaseEvent, level: Optional[EventLevel] = None) -> None:
self.event_history.append((e, level))
def add_logger(self, config: LoggerConfig) -> None:
raise NotImplementedError()
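
TestEventManager gives tests a way to assert on fired events without configuring any loggers; a brief sketch, with the classes above in scope:

from dbt.common.events.base_types import EventLevel
from dbt.common.events.types import Note

mgr = TestEventManager()
mgr.fire_event(Note(msg="first"))
mgr.fire_event(Note(msg="second"), level=EventLevel.WARN)
assert len(mgr.event_history) == 2
assert mgr.event_history[1][1] == EventLevel.WARN  # the level override is recorded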

View File

@@ -1,29 +0,0 @@
# Since dbt-rpc does not do its own log setup, and since some events can
# currently fire before logs can be configured by setup_event_logger(), we
# create a default configuration with default settings and no file output.
from dbt.common.events.event_manager import IEventManager, EventManager
_EVENT_MANAGER: IEventManager = EventManager()
def get_event_manager() -> IEventManager:
global _EVENT_MANAGER
return _EVENT_MANAGER
def add_logger_to_manager(logger) -> None:
global _EVENT_MANAGER
_EVENT_MANAGER.add_logger(logger)
def ctx_set_event_manager(event_manager: IEventManager) -> None:
global _EVENT_MANAGER
_EVENT_MANAGER = event_manager
def cleanup_event_logger() -> None:
# Reset to a no-op manager to release streams associated with logs. This is
# especially important for tests, since pytest replaces the stdout stream
# during test runs, and closes the stream after the test is over.
_EVENT_MANAGER.loggers.clear()
_EVENT_MANAGER.callbacks.clear()

View File

@@ -1,56 +0,0 @@
from dbt.common import ui
from typing import Optional, Union
from datetime import datetime
from dbt.common.events.interfaces import LoggableDbtObject
def format_fancy_output_line(
msg: str,
status: str,
index: Optional[int],
total: Optional[int],
execution_time: Optional[float] = None,
truncate: bool = False,
) -> str:
if index is None or total is None:
progress = ""
else:
progress = "{} of {} ".format(index, total)
prefix = "{progress}{message} ".format(progress=progress, message=msg)
truncate_width = ui.printer_width() - 3
justified = prefix.ljust(ui.printer_width(), ".")
if truncate and len(justified) > truncate_width:
justified = justified[:truncate_width] + "..."
if execution_time is None:
status_time = ""
else:
status_time = " in {execution_time:0.2f}s".format(execution_time=execution_time)
output = "{justified} [{status}{status_time}]".format(
justified=justified, status=status, status_time=status_time
)
return output
def _pluralize(string: Union[str, LoggableDbtObject]) -> str:
if isinstance(string, LoggableDbtObject):
return string.pluralize()
else:
return f"{string}s"
def pluralize(count, string: Union[str, LoggableDbtObject]) -> str:
pluralized: str = str(string)
if count != 1:
pluralized = _pluralize(string)
return f"{count} {pluralized}"
def timestamp_to_datetime_string(ts) -> str:
timestamp_dt = datetime.fromtimestamp(ts.seconds + ts.nanos / 1e9)
return timestamp_dt.strftime("%H:%M:%S.%f")
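
What these helpers produce, sketched under the assumption that this module is importable as dbt.common.events.format:

from dbt.common.events.format import format_fancy_output_line, pluralize

pluralize(1, "model")  # -> "1 model"
pluralize(3, "model")  # -> "3 models"
format_fancy_output_line("OK created view model my_model", "SUCCESS", 1, 5, 0.03)
# -> "1 of 5 OK created view model my_model ........ [SUCCESS in 0.03s]"
# (the dots pad the line out to ui.printer_width(), 80 columns by default)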

View File

@@ -1,162 +0,0 @@
from pathlib import Path
from dbt.common.events.event_manager_client import get_event_manager
from dbt.common.invocation import get_invocation_id
from dbt.common.helper_types import WarnErrorOptions
from dbt.common.utils import ForgivingJSONEncoder
from dbt.common.events.base_types import BaseEvent, EventLevel, EventMsg
from dbt.common.events.logger import LoggerConfig, LineFormat
from dbt.common.exceptions import scrub_secrets, env_secrets
from dbt.common.events.types import Note
from functools import partial
import json
import os
import sys
from typing import Callable, Dict, Optional, TextIO, Union
from google.protobuf.json_format import MessageToDict
LOG_VERSION = 3
metadata_vars: Optional[Dict[str, str]] = None
_METADATA_ENV_PREFIX = "DBT_ENV_CUSTOM_ENV_"
WARN_ERROR_OPTIONS = WarnErrorOptions(include=[], exclude=[])
WARN_ERROR = False
# This global, and the following two functions for capturing stdout logs are
# an unpleasant hack we intend to remove as part of API-ification. The GitHub
# issue #6350 was opened for that work.
CAPTURE_STREAM: Optional[TextIO] = None
def stdout_filter(
log_cache_events: bool,
line_format: LineFormat,
msg: EventMsg,
) -> bool:
return msg.info.name not in ["CacheAction", "CacheDumpGraph"] or log_cache_events
def get_stdout_config(
line_format: LineFormat,
use_colors: bool,
level: EventLevel,
log_cache_events: bool,
) -> LoggerConfig:
return LoggerConfig(
name="stdout_log",
level=level,
use_colors=use_colors,
line_format=line_format,
scrubber=env_scrubber,
filter=partial(
stdout_filter,
log_cache_events,
line_format,
),
invocation_id=get_invocation_id(),
output_stream=sys.stdout,
)
def make_log_dir_if_missing(log_path: Union[Path, str]) -> None:
if isinstance(log_path, str):
log_path = Path(log_path)
log_path.mkdir(parents=True, exist_ok=True)
def env_scrubber(msg: str) -> str:
return scrub_secrets(msg, env_secrets())
# used for integration tests
def capture_stdout_logs(stream: TextIO) -> None:
global CAPTURE_STREAM
CAPTURE_STREAM = stream
def stop_capture_stdout_logs() -> None:
global CAPTURE_STREAM
CAPTURE_STREAM = None
def get_capture_stream() -> Optional[TextIO]:
return CAPTURE_STREAM
# serializes the event fields to a JSON string.
# the message may contain secrets which must be scrubbed at the usage site.
def msg_to_json(msg: EventMsg) -> str:
msg_dict = msg_to_dict(msg)
raw_log_line = json.dumps(msg_dict, sort_keys=True, cls=ForgivingJSONEncoder)
return raw_log_line
def msg_to_dict(msg: EventMsg) -> dict:
msg_dict = dict()
try:
msg_dict = MessageToDict(
msg, preserving_proto_field_name=True, including_default_value_fields=True # type: ignore
)
except Exception as exc:
event_type = type(msg).__name__
fire_event(
Note(msg=f"type {event_type} is not serializable. {str(exc)}"), level=EventLevel.WARN
)
# We don't want an empty NodeInfo in output
if (
"data" in msg_dict
and "node_info" in msg_dict["data"]
and msg_dict["data"]["node_info"]["node_name"] == ""
):
del msg_dict["data"]["node_info"]
return msg_dict
def warn_or_error(event, node=None) -> None:
if WARN_ERROR or WARN_ERROR_OPTIONS.includes(type(event).__name__):
# TODO: resolve this circular import when at top
from dbt.common.exceptions import EventCompilationError
raise EventCompilationError(event.message(), node)
else:
fire_event(event)
# an alternative to fire_event which only creates and logs the event value
# if the condition is met. Does nothing otherwise.
def fire_event_if(
conditional: bool, lazy_e: Callable[[], BaseEvent], level: Optional[EventLevel] = None
) -> None:
if conditional:
fire_event(lazy_e(), level=level)
# a special case of fire_event_if, to only fire events in our unit/functional tests
def fire_event_if_test(
lazy_e: Callable[[], BaseEvent], level: Optional[EventLevel] = None
) -> None:
fire_event_if(conditional=("pytest" in sys.modules), lazy_e=lazy_e, level=level)
# top-level method for accessing the new eventing system
# this is where all the side effects happen branched by event type
# (i.e. - mutating the event history, printing to stdout, logging
# to files, etc.)
def fire_event(e: BaseEvent, level: Optional[EventLevel] = None) -> None:
get_event_manager().fire_event(e, level=level)
def get_metadata_vars() -> Dict[str, str]:
global metadata_vars
if not metadata_vars:
metadata_vars = {
k[len(_METADATA_ENV_PREFIX) :]: v
for k, v in os.environ.items()
if k.startswith(_METADATA_ENV_PREFIX)
}
return metadata_vars
def reset_metadata_vars() -> None:
global metadata_vars
metadata_vars = None
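
Metadata vars are read from the environment once and cached; reset_metadata_vars() clears the cache. A small sketch, assuming no other DBT_ENV_CUSTOM_ENV_* variables are set:

import os
from dbt.common.events.functions import get_metadata_vars, reset_metadata_vars

# DBT_ENV_CUSTOM_ENV_* variables are attached to every event's "extra" field
os.environ["DBT_ENV_CUSTOM_ENV_run_id"] = "abc123"
reset_metadata_vars()  # drop the cached value so the new variable is picked up
assert get_metadata_vars() == {"run_id": "abc123"}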

View File

@@ -1,14 +0,0 @@
from datetime import datetime
# This converts a datetime to a json format datetime string which
# is used in constructing protobuf message timestamps.
def datetime_to_json_string(dt: datetime) -> str:
return dt.strftime("%Y-%m-%dT%H:%M:%S.%fZ")
# preformatted time stamp
def get_json_string_utcnow() -> str:
ts = datetime.utcnow()
ts_rfc3339 = datetime_to_json_string(ts)
return ts_rfc3339

View File

@@ -1,7 +0,0 @@
from typing import Protocol, runtime_checkable
@runtime_checkable
class LoggableDbtObject(Protocol):
def pluralize(self) -> str:
...

View File

@@ -1,180 +0,0 @@
import json
import logging
import threading
from dataclasses import dataclass
from datetime import datetime
from enum import Enum
from logging.handlers import RotatingFileHandler
from typing import Optional, TextIO, Any, Callable
from colorama import Style
from dbt.common.events.base_types import EventLevel, EventMsg
from dbt.common.events.format import timestamp_to_datetime_string
from dbt.common.utils import ForgivingJSONEncoder
# A Filter is a function which takes a BaseEvent and returns True if the event
# should be logged, False otherwise.
Filter = Callable[[EventMsg], bool]
# Default filter which logs every event
def NoFilter(_: EventMsg) -> bool:
return True
# A Scrubber removes secrets from an input string, returning a sanitized string.
Scrubber = Callable[[str], str]
# Provide a pass-through scrubber implementation, also used as a default
def NoScrubber(s: str) -> str:
return s
class LineFormat(Enum):
PlainText = 1
DebugText = 2
Json = 3
# Map from dbt event levels to python log levels
_log_level_map = {
EventLevel.DEBUG: 10,
EventLevel.TEST: 10,
EventLevel.INFO: 20,
EventLevel.WARN: 30,
EventLevel.ERROR: 40,
}
# We need this function for now because the numeric log severity levels in
# Python do not match those for logbook, so we have to explicitly call the
# correct function by name.
def send_to_logger(l, level: str, log_line: str):
if level == "test":
l.debug(log_line)
elif level == "debug":
l.debug(log_line)
elif level == "info":
l.info(log_line)
elif level == "warn":
l.warning(log_line)
elif level == "error":
l.error(log_line)
else:
raise AssertionError(
f"While attempting to log {log_line}, encountered the unhandled level: {level}"
)
@dataclass
class LoggerConfig:
name: str
filter: Filter = NoFilter
scrubber: Scrubber = NoScrubber
line_format: LineFormat = LineFormat.PlainText
level: EventLevel = EventLevel.WARN
invocation_id: Optional[str] = None
use_colors: bool = False
output_stream: Optional[TextIO] = None
output_file_name: Optional[str] = None
output_file_max_bytes: Optional[int] = 10 * 1024 * 1024 # 10 mb
logger: Optional[Any] = None
class _Logger:
def __init__(self, config: LoggerConfig) -> None:
self.name: str = config.name
self.filter: Filter = config.filter
self.scrubber: Scrubber = config.scrubber
self.level: EventLevel = config.level
self.invocation_id: Optional[str] = config.invocation_id
self._python_logger: Optional[logging.Logger] = config.logger
if config.output_stream is not None:
stream_handler = logging.StreamHandler(config.output_stream)
self._python_logger = self._get_python_log_for_handler(stream_handler)
if config.output_file_name:
file_handler = RotatingFileHandler(
filename=str(config.output_file_name),
encoding="utf8",
maxBytes=config.output_file_max_bytes, # type: ignore
backupCount=5,
)
self._python_logger = self._get_python_log_for_handler(file_handler)
def _get_python_log_for_handler(self, handler: logging.Handler):
log = logging.getLogger(self.name)
log.setLevel(_log_level_map[self.level])
handler.setFormatter(logging.Formatter(fmt="%(message)s"))
log.handlers.clear()
log.propagate = False
log.addHandler(handler)
return log
def create_line(self, msg: EventMsg) -> str:
raise NotImplementedError()
def write_line(self, msg: EventMsg):
line = self.create_line(msg)
if self._python_logger is not None:
send_to_logger(self._python_logger, msg.info.level, line)
def flush(self):
if self._python_logger is not None:
for handler in self._python_logger.handlers:
handler.flush()
class _TextLogger(_Logger):
def __init__(self, config: LoggerConfig) -> None:
super().__init__(config)
self.use_colors = config.use_colors
self.use_debug_format = config.line_format == LineFormat.DebugText
def create_line(self, msg: EventMsg) -> str:
return self.create_debug_line(msg) if self.use_debug_format else self.create_info_line(msg)
def create_info_line(self, msg: EventMsg) -> str:
ts: str = datetime.utcnow().strftime("%H:%M:%S")
scrubbed_msg: str = self.scrubber(msg.info.msg) # type: ignore
return f"{self._get_color_tag()}{ts} {scrubbed_msg}"
def create_debug_line(self, msg: EventMsg) -> str:
log_line: str = ""
# Create a separator if this is the beginning of an invocation
# TODO: This is an ugly hack, get rid of it if we can
ts: str = timestamp_to_datetime_string(msg.info.ts)
if msg.info.name == "MainReportVersion":
separator = 30 * "="
log_line = f"\n\n{separator} {ts} | {self.invocation_id} {separator}\n"
scrubbed_msg: str = self.scrubber(msg.info.msg) # type: ignore
level = msg.info.level
log_line += (
f"{self._get_color_tag()}{ts} [{level:<5}]{self._get_thread_name()} {scrubbed_msg}"
)
return log_line
def _get_color_tag(self) -> str:
return "" if not self.use_colors else Style.RESET_ALL
def _get_thread_name(self) -> str:
thread_name = ""
if threading.current_thread().name:
thread_name = threading.current_thread().name
thread_name = thread_name[:10]
thread_name = thread_name.ljust(10, " ")
thread_name = f" [{thread_name}]:"
return thread_name
class _JsonLogger(_Logger):
def create_line(self, msg: EventMsg) -> str:
from dbt.common.events.functions import msg_to_dict
msg_dict = msg_to_dict(msg)
raw_log_line = json.dumps(msg_dict, sort_keys=True, cls=ForgivingJSONEncoder)
line = self.scrubber(raw_log_line) # type: ignore
return line
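
Putting LoggerConfig together with the EventManager from the deleted event_manager module yields a minimal console logger; a sketch, assuming those import paths:

import sys
from dbt.common.events.base_types import EventLevel
from dbt.common.events.event_manager import EventManager
from dbt.common.events.logger import LineFormat, LoggerConfig
from dbt.common.events.types import Note

mgr = EventManager()
mgr.add_logger(
    LoggerConfig(
        name="example_stdout",
        level=EventLevel.INFO,
        line_format=LineFormat.PlainText,
        output_stream=sys.stdout,
    )
)
mgr.fire_event(Note(msg="hello from the event pipeline"))
# prints a line like: 12:34:56 hello from the event pipeline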

View File

@@ -1,121 +0,0 @@
syntax = "proto3";
package proto_types;
import "google/protobuf/timestamp.proto";
// Common event info
message EventInfo {
string name = 1;
string code = 2;
string msg = 3;
string level = 4;
string invocation_id = 5;
int32 pid = 6;
string thread = 7;
google.protobuf.Timestamp ts = 8;
map<string, string> extra = 9;
string category = 10;
}
// GenericMessage, used for deserializing only
message GenericMessage {
EventInfo info = 1;
}
// M - Deps generation
// M020
message RetryExternalCall {
int32 attempt = 1;
int32 max = 2;
}
message RetryExternalCallMsg {
EventInfo info = 1;
RetryExternalCall data = 2;
}
// M021
message RecordRetryException {
string exc = 1;
}
message RecordRetryExceptionMsg {
EventInfo info = 1;
RecordRetryException data = 2;
}
// Z - Misc
// Z005
message SystemCouldNotWrite {
string path = 1;
string reason = 2;
string exc = 3;
}
message SystemCouldNotWriteMsg {
EventInfo info = 1;
SystemCouldNotWrite data = 2;
}
// Z006
message SystemExecutingCmd {
repeated string cmd = 1;
}
message SystemExecutingCmdMsg {
EventInfo info = 1;
SystemExecutingCmd data = 2;
}
// Z007
message SystemStdOut{
string bmsg = 1;
}
message SystemStdOutMsg {
EventInfo info = 1;
SystemStdOut data = 2;
}
// Z008
message SystemStdErr {
string bmsg = 1;
}
message SystemStdErrMsg {
EventInfo info = 1;
SystemStdErr data = 2;
}
// Z009
message SystemReportReturnCode {
int32 returncode = 1;
}
message SystemReportReturnCodeMsg {
EventInfo info = 1;
SystemReportReturnCode data = 2;
}
// Z017
message Formatting {
string msg = 1;
}
message FormattingMsg {
EventInfo info = 1;
Formatting data = 2;
}
// Z050
message Note {
string msg = 1;
}
message NoteMsg {
EventInfo info = 1;
Note data = 2;
}

View File

@@ -1,124 +0,0 @@
from dbt.common.events.base_types import (
DebugLevel,
InfoLevel,
)
# The classes in this file represent the data necessary to describe a
# particular event to both human readable logs, and machine reliable
# event streams. classes extend superclasses that indicate what
# destinations they are intended for, which mypy uses to enforce
# that the necessary methods are defined.
# Event codes have prefixes which follow this table
#
# | Code | Description |
# |:----:|:-------------------:|
# | A | Pre-project loading |
# | D | Deprecations |
# | E | DB adapter |
# | I | Project parsing |
# | M | Deps generation |
# | P | Artifacts |
# | Q | Node execution |
# | W | Node testing |
# | Z | Misc |
# | T | Test only |
#
# The basic idea is that event codes roughly translate to the natural order of running a dbt task
# =======================================================
# M - Deps generation
# =======================================================
class RetryExternalCall(DebugLevel):
def code(self) -> str:
return "M020"
def message(self) -> str:
return f"Retrying external call. Attempt: {self.attempt} Max attempts: {self.max}"
class RecordRetryException(DebugLevel):
def code(self) -> str:
return "M021"
def message(self) -> str:
return f"External call exception: {self.exc}"
# =======================================================
# Z - Misc
# =======================================================
class SystemCouldNotWrite(DebugLevel):
def code(self) -> str:
return "Z005"
def message(self) -> str:
return (
f"Could not write to path {self.path}({len(self.path)} characters): "
f"{self.reason}\nexception: {self.exc}"
)
class SystemExecutingCmd(DebugLevel):
def code(self) -> str:
return "Z006"
def message(self) -> str:
return f'Executing "{" ".join(self.cmd)}"'
class SystemStdOut(DebugLevel):
def code(self) -> str:
return "Z007"
def message(self) -> str:
return f'STDOUT: "{str(self.bmsg)}"'
class SystemStdErr(DebugLevel):
def code(self) -> str:
return "Z008"
def message(self) -> str:
return f'STDERR: "{str(self.bmsg)}"'
class SystemReportReturnCode(DebugLevel):
def code(self) -> str:
return "Z009"
def message(self) -> str:
return f"command return code={self.returncode}"
# We use events to create console output, but also think of them as a sequence of important and
# meaningful occurrences to be used for debugging and monitoring. The Formatting event helps ease
# the tension between these two goals by allowing empty lines, heading separators, and other
# formatting to be written to the console while being ignored for other purposes. For
# general information that isn't simple formatting, the Note event should be used instead.
class Formatting(InfoLevel):
def code(self) -> str:
return "Z017"
def message(self) -> str:
return self.msg
class Note(InfoLevel):
"""The Note event provides a way to log messages which aren't likely to be
useful as more structured events. For console formatting text like empty
lines and separator bars, use the Formatting event instead."""
def code(self) -> str:
return "Z050"
def message(self) -> str:
return self.msg
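
The intended division of labor, as a two-line sketch (fire_event from the deleted functions module):

from dbt.common.events.functions import fire_event

fire_event(Formatting(msg=""))  # console layout only: a blank separator line
fire_event(Note(msg="Skipping disabled models"))  # meaningful but unstructured info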

View File

@@ -1,69 +0,0 @@
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: types.proto
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import builder as _builder
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
from google.protobuf import timestamp_pb2 as google_dot_protobuf_dot_timestamp__pb2
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0btypes.proto\x12\x0bproto_types\x1a\x1fgoogle/protobuf/timestamp.proto\"\x91\x02\n\tEventInfo\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x0c\n\x04\x63ode\x18\x02 \x01(\t\x12\x0b\n\x03msg\x18\x03 \x01(\t\x12\r\n\x05level\x18\x04 \x01(\t\x12\x15\n\rinvocation_id\x18\x05 \x01(\t\x12\x0b\n\x03pid\x18\x06 \x01(\x05\x12\x0e\n\x06thread\x18\x07 \x01(\t\x12&\n\x02ts\x18\x08 \x01(\x0b\x32\x1a.google.protobuf.Timestamp\x12\x30\n\x05\x65xtra\x18\t \x03(\x0b\x32!.proto_types.EventInfo.ExtraEntry\x12\x10\n\x08\x63\x61tegory\x18\n \x01(\t\x1a,\n\nExtraEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\"6\n\x0eGenericMessage\x12$\n\x04info\x18\x01 \x01(\x0b\x32\x16.proto_types.EventInfo\"1\n\x11RetryExternalCall\x12\x0f\n\x07\x61ttempt\x18\x01 \x01(\x05\x12\x0b\n\x03max\x18\x02 \x01(\x05\"j\n\x14RetryExternalCallMsg\x12$\n\x04info\x18\x01 \x01(\x0b\x32\x16.proto_types.EventInfo\x12,\n\x04\x64\x61ta\x18\x02 \x01(\x0b\x32\x1e.proto_types.RetryExternalCall\"#\n\x14RecordRetryException\x12\x0b\n\x03\x65xc\x18\x01 \x01(\t\"p\n\x17RecordRetryExceptionMsg\x12$\n\x04info\x18\x01 \x01(\x0b\x32\x16.proto_types.EventInfo\x12/\n\x04\x64\x61ta\x18\x02 \x01(\x0b\x32!.proto_types.RecordRetryException\"@\n\x13SystemCouldNotWrite\x12\x0c\n\x04path\x18\x01 \x01(\t\x12\x0e\n\x06reason\x18\x02 \x01(\t\x12\x0b\n\x03\x65xc\x18\x03 \x01(\t\"n\n\x16SystemCouldNotWriteMsg\x12$\n\x04info\x18\x01 \x01(\x0b\x32\x16.proto_types.EventInfo\x12.\n\x04\x64\x61ta\x18\x02 \x01(\x0b\x32 .proto_types.SystemCouldNotWrite\"!\n\x12SystemExecutingCmd\x12\x0b\n\x03\x63md\x18\x01 \x03(\t\"l\n\x15SystemExecutingCmdMsg\x12$\n\x04info\x18\x01 \x01(\x0b\x32\x16.proto_types.EventInfo\x12-\n\x04\x64\x61ta\x18\x02 \x01(\x0b\x32\x1f.proto_types.SystemExecutingCmd\"\x1c\n\x0cSystemStdOut\x12\x0c\n\x04\x62msg\x18\x01 \x01(\t\"`\n\x0fSystemStdOutMsg\x12$\n\x04info\x18\x01 \x01(\x0b\x32\x16.proto_types.EventInfo\x12\'\n\x04\x64\x61ta\x18\x02 \x01(\x0b\x32\x19.proto_types.SystemStdOut\"\x1c\n\x0cSystemStdErr\x12\x0c\n\x04\x62msg\x18\x01 \x01(\t\"`\n\x0fSystemStdErrMsg\x12$\n\x04info\x18\x01 \x01(\x0b\x32\x16.proto_types.EventInfo\x12\'\n\x04\x64\x61ta\x18\x02 \x01(\x0b\x32\x19.proto_types.SystemStdErr\",\n\x16SystemReportReturnCode\x12\x12\n\nreturncode\x18\x01 \x01(\x05\"t\n\x19SystemReportReturnCodeMsg\x12$\n\x04info\x18\x01 \x01(\x0b\x32\x16.proto_types.EventInfo\x12\x31\n\x04\x64\x61ta\x18\x02 \x01(\x0b\x32#.proto_types.SystemReportReturnCode\"\x19\n\nFormatting\x12\x0b\n\x03msg\x18\x01 \x01(\t\"\\\n\rFormattingMsg\x12$\n\x04info\x18\x01 \x01(\x0b\x32\x16.proto_types.EventInfo\x12%\n\x04\x64\x61ta\x18\x02 \x01(\x0b\x32\x17.proto_types.Formatting\"\x13\n\x04Note\x12\x0b\n\x03msg\x18\x01 \x01(\t\"P\n\x07NoteMsg\x12$\n\x04info\x18\x01 \x01(\x0b\x32\x16.proto_types.EventInfo\x12\x1f\n\x04\x64\x61ta\x18\x02 \x01(\x0b\x32\x11.proto_types.Noteb\x06proto3')
_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'types_pb2', _globals)
if _descriptor._USE_C_DESCRIPTORS == False:
DESCRIPTOR._options = None
_EVENTINFO_EXTRAENTRY._options = None
_EVENTINFO_EXTRAENTRY._serialized_options = b'8\001'
_globals['_EVENTINFO']._serialized_start=62
_globals['_EVENTINFO']._serialized_end=335
_globals['_EVENTINFO_EXTRAENTRY']._serialized_start=291
_globals['_EVENTINFO_EXTRAENTRY']._serialized_end=335
_globals['_GENERICMESSAGE']._serialized_start=337
_globals['_GENERICMESSAGE']._serialized_end=391
_globals['_RETRYEXTERNALCALL']._serialized_start=393
_globals['_RETRYEXTERNALCALL']._serialized_end=442
_globals['_RETRYEXTERNALCALLMSG']._serialized_start=444
_globals['_RETRYEXTERNALCALLMSG']._serialized_end=550
_globals['_RECORDRETRYEXCEPTION']._serialized_start=552
_globals['_RECORDRETRYEXCEPTION']._serialized_end=587
_globals['_RECORDRETRYEXCEPTIONMSG']._serialized_start=589
_globals['_RECORDRETRYEXCEPTIONMSG']._serialized_end=701
_globals['_SYSTEMCOULDNOTWRITE']._serialized_start=703
_globals['_SYSTEMCOULDNOTWRITE']._serialized_end=767
_globals['_SYSTEMCOULDNOTWRITEMSG']._serialized_start=769
_globals['_SYSTEMCOULDNOTWRITEMSG']._serialized_end=879
_globals['_SYSTEMEXECUTINGCMD']._serialized_start=881
_globals['_SYSTEMEXECUTINGCMD']._serialized_end=914
_globals['_SYSTEMEXECUTINGCMDMSG']._serialized_start=916
_globals['_SYSTEMEXECUTINGCMDMSG']._serialized_end=1024
_globals['_SYSTEMSTDOUT']._serialized_start=1026
_globals['_SYSTEMSTDOUT']._serialized_end=1054
_globals['_SYSTEMSTDOUTMSG']._serialized_start=1056
_globals['_SYSTEMSTDOUTMSG']._serialized_end=1152
_globals['_SYSTEMSTDERR']._serialized_start=1154
_globals['_SYSTEMSTDERR']._serialized_end=1182
_globals['_SYSTEMSTDERRMSG']._serialized_start=1184
_globals['_SYSTEMSTDERRMSG']._serialized_end=1280
_globals['_SYSTEMREPORTRETURNCODE']._serialized_start=1282
_globals['_SYSTEMREPORTRETURNCODE']._serialized_end=1326
_globals['_SYSTEMREPORTRETURNCODEMSG']._serialized_start=1328
_globals['_SYSTEMREPORTRETURNCODEMSG']._serialized_end=1444
_globals['_FORMATTING']._serialized_start=1446
_globals['_FORMATTING']._serialized_end=1471
_globals['_FORMATTINGMSG']._serialized_start=1473
_globals['_FORMATTINGMSG']._serialized_end=1565
_globals['_NOTE']._serialized_start=1567
_globals['_NOTE']._serialized_end=1586
_globals['_NOTEMSG']._serialized_start=1588
_globals['_NOTEMSG']._serialized_end=1668
# @@protoc_insertion_point(module_scope)

View File

@@ -1,7 +0,0 @@
from dbt.common.exceptions.base import * # noqa
from dbt.common.exceptions.events import * # noqa
from dbt.common.exceptions.macros import * # noqa
from dbt.common.exceptions.contracts import * # noqa
from dbt.common.exceptions.connection import * # noqa
from dbt.common.exceptions.system import * # noqa
from dbt.common.exceptions.jinja import * # noqa

View File

@@ -1,275 +0,0 @@
import builtins
from typing import List, Any, Optional
import os
from dbt.common.constants import SECRET_ENV_PREFIX
from dbt.common.dataclass_schema import ValidationError
def env_secrets() -> List[str]:
return [v for k, v in os.environ.items() if k.startswith(SECRET_ENV_PREFIX) and v.strip()]
def scrub_secrets(msg: str, secrets: List[str]) -> str:
scrubbed = str(msg)
for secret in secrets:
scrubbed = scrubbed.replace(secret, "*****")
return scrubbed
class DbtBaseException(Exception):
CODE = -32000
MESSAGE = "Server Error"
def data(self):
# if overriding, make sure the result is json-serializable.
return {
"type": self.__class__.__name__,
"message": str(self),
}
class DbtInternalError(DbtBaseException):
def __init__(self, msg: str):
self.stack: List = []
self.msg = scrub_secrets(msg, env_secrets())
@property
def type(self):
return "Internal"
def process_stack(self):
lines = []
stack = self.stack
first = True
if len(stack) > 1:
lines.append("")
for item in stack:
msg = "called by"
if first:
msg = "in"
first = False
lines.append(f"> {msg}")
return lines
def __str__(self):
if hasattr(self.msg, "split"):
split_msg = self.msg.split("\n")
else:
split_msg = str(self.msg).split("\n")
lines = ["{}".format(self.type + " Error")] + split_msg
lines += self.process_stack()
return lines[0] + "\n" + "\n".join([" " + line for line in lines[1:]])
class DbtRuntimeError(RuntimeError, DbtBaseException):
CODE = 10001
MESSAGE = "Runtime error"
def __init__(self, msg: str, node=None) -> None:
self.stack: List = []
self.node = node
self.msg = scrub_secrets(msg, env_secrets())
def add_node(self, node=None):
if node is not None and node is not self.node:
if self.node is not None:
self.stack.append(self.node)
self.node = node
@property
def type(self):
return "Runtime"
def node_to_string(self, node: Any):
"""
Given a node-like object we attempt to create the best identifier we can
"""
result = ""
if hasattr(node, "resource_type"):
result += node.resource_type
if hasattr(node, "name"):
result += f" {node.name}"
if hasattr(node, "original_file_path"):
result += f" ({node.original_file_path})"
return result.strip() if result != "" else "<Unknown>"
def process_stack(self):
lines = []
stack = self.stack + [self.node]
first = True
if len(stack) > 1:
lines.append("")
for item in stack:
msg = "called by"
if first:
msg = "in"
first = False
lines.append(f"> {msg} {self.node_to_string(item)}")
return lines
def validator_error_message(self, exc: builtins.Exception):
"""Given a dbt.dataclass_schema.ValidationError (which is basically a
jsonschema.ValidationError), return the relevant parts as a string
"""
if not isinstance(exc, ValidationError):
return str(exc)
path = "[%s]" % "][".join(map(repr, exc.relative_path))
return f"at path {path}: {exc.message}"
def __str__(self, prefix: str = "! "):
node_string = ""
if self.node is not None:
node_string = f" in {self.node_to_string(self.node)}"
if hasattr(self.msg, "split"):
split_msg = self.msg.split("\n")
else:
split_msg = str(self.msg).split("\n")
lines = ["{}{}".format(self.type + " Error", node_string)] + split_msg
lines += self.process_stack()
return lines[0] + "\n" + "\n".join([" " + line for line in lines[1:]])
def data(self):
result = DbtBaseException.data(self)
if self.node is None:
return result
result.update(
{
"raw_code": self.node.raw_code,
# the node isn't always compiled, but if it is, include that!
"compiled_code": getattr(self.node, "compiled_code", None),
}
)
return result
class CompilationError(DbtRuntimeError):
CODE = 10004
MESSAGE = "Compilation Error"
@property
def type(self):
return "Compilation"
def _fix_dupe_msg(self, path_1: str, path_2: str, name: str, type_name: str) -> str:
if path_1 == path_2:
return (
f"remove one of the {type_name} entries for {name} in this file:\n - {path_1!s}\n"
)
else:
return (
f"remove the {type_name} entry for {name} in one of these files:\n"
f" - {path_1!s}\n{path_2!s}"
)
class RecursionError(DbtRuntimeError):
pass
class DbtConfigError(DbtRuntimeError):
CODE = 10007
MESSAGE = "DBT Configuration Error"
# ToDo: Can we remove project?
def __init__(self, msg: str, project=None, result_type="invalid_project", path=None) -> None:
self.project = project
super().__init__(msg)
self.result_type = result_type
self.path = path
def __str__(self, prefix="! ") -> str:
msg = super().__str__(prefix)
if self.path is None:
return msg
else:
return f"{msg}\n\nError encountered in {self.path}"
class NotImplementedError(DbtBaseException):
def __init__(self, msg: str) -> None:
self.msg = msg
self.formatted_msg = f"ERROR: {self.msg}"
super().__init__(self.formatted_msg)
class SemverError(Exception):
def __init__(self, msg: Optional[str] = None) -> None:
self.msg = msg
if msg is not None:
super().__init__(msg)
else:
super().__init__()
class VersionsNotCompatibleError(SemverError):
pass
class DbtValidationError(DbtRuntimeError):
CODE = 10005
MESSAGE = "Validation Error"
class DbtDatabaseError(DbtRuntimeError):
CODE = 10003
MESSAGE = "Database Error"
def process_stack(self):
lines = []
if hasattr(self.node, "build_path") and self.node.build_path:
lines.append(f"compiled Code at {self.node.build_path}")
return lines + DbtRuntimeError.process_stack(self)
@property
def type(self):
return "Database"
class UnexpectedNullError(DbtDatabaseError):
def __init__(self, field_name: str, source):
self.field_name = field_name
self.source = source
msg = (
f"Expected a non-null value when querying field '{self.field_name}' of table "
f" {self.source} but received value 'null' instead"
)
super().__init__(msg)
class CommandError(DbtRuntimeError):
def __init__(self, cwd: str, cmd: List[str], msg: str = "Error running command") -> None:
cmd_scrubbed = list(scrub_secrets(cmd_txt, env_secrets()) for cmd_txt in cmd)
super().__init__(msg)
self.cwd = cwd
self.cmd = cmd_scrubbed
self.args = (cwd, cmd_scrubbed, msg)
def __str__(self):
if len(self.cmd) == 0:
return f"{self.msg}: No arguments given"
return f'{self.msg}: "{self.cmd[0]}"'
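
How the scrubbing pieces fit together, sketched under the assumption that the secret prefix in dbt.common.constants is DBT_ENV_SECRET_:

import os
from dbt.common.exceptions import DbtRuntimeError, env_secrets, scrub_secrets

os.environ["DBT_ENV_SECRET_TOKEN"] = "s3cr3t"
scrub_secrets("login failed for token s3cr3t", env_secrets())
# -> "login failed for token *****"

# DbtRuntimeError scrubs its message on construction and renders with a type prefix:
str(DbtRuntimeError("relation not found"))
# -> "Runtime Error\n  relation not found"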

View File

@@ -1,7 +0,0 @@
class ConnectionError(Exception):
"""
There was a problem with the connection that returned a bad response,
timed out, or resulted in a file that is corrupt.
"""
pass

View File

@@ -1,17 +0,0 @@
from typing import Any
from dbt.common.exceptions import CompilationError
# this is part of the context and also raised in dbt.contracts.relation.py
class DataclassNotDictError(CompilationError):
def __init__(self, obj: Any):
self.obj = obj
super().__init__(msg=self.get_message())
def get_message(self) -> str:
msg = (
f'The object ("{self.obj}") was used as a dictionary. This '
"capability has been removed from objects of this type."
)
return msg

View File

@@ -1,9 +0,0 @@
from dbt.common.exceptions import CompilationError, scrub_secrets, env_secrets
# event level exception
class EventCompilationError(CompilationError):
def __init__(self, msg: str, node) -> None:
self.msg = scrub_secrets(msg, env_secrets())
self.node = node
super().__init__(msg=self.msg)

View File

@@ -1,85 +0,0 @@
from dbt.common.exceptions import CompilationError
class BlockDefinitionNotAtTopError(CompilationError):
def __init__(self, tag_parser, tag_start) -> None:
self.tag_parser = tag_parser
self.tag_start = tag_start
super().__init__(msg=self.get_message())
def get_message(self) -> str:
position = self.tag_parser.linepos(self.tag_start)
msg = (
f"Got a block definition inside control flow at {position}. "
"All dbt block definitions must be at the top level"
)
return msg
class MissingCloseTagError(CompilationError):
def __init__(self, block_type_name: str, linecount: int) -> None:
self.block_type_name = block_type_name
self.linecount = linecount
super().__init__(msg=self.get_message())
def get_message(self) -> str:
msg = f"Reached EOF without finding a close tag for {self.block_type_name} (searched from line {self.linecount})"
return msg
class MissingControlFlowStartTagError(CompilationError):
def __init__(self, tag, expected_tag: str, tag_parser) -> None:
self.tag = tag
self.expected_tag = expected_tag
self.tag_parser = tag_parser
super().__init__(msg=self.get_message())
def get_message(self) -> str:
linepos = self.tag_parser.linepos(self.tag.start)
msg = (
f"Got an unexpected control flow end tag, got {self.tag.block_type_name} but "
f"expected {self.expected_tag} next (@ {linepos})"
)
return msg
class NestedTagsError(CompilationError):
def __init__(self, outer, inner) -> None:
self.outer = outer
self.inner = inner
super().__init__(msg=self.get_message())
def get_message(self) -> str:
msg = (
f"Got nested tags: {self.outer.block_type_name} (started at {self.outer.start}) did "
f"not have a matching {{{{% end{self.outer.block_type_name} %}}}} before a "
f"subsequent {self.inner.block_type_name} was found (started at {self.inner.start})"
)
return msg
class UnexpectedControlFlowEndTagError(CompilationError):
def __init__(self, tag, expected_tag: str, tag_parser) -> None:
self.tag = tag
self.expected_tag = expected_tag
self.tag_parser = tag_parser
super().__init__(msg=self.get_message())
def get_message(self) -> str:
linepos = self.tag_parser.linepos(self.tag.start)
msg = (
f"Got an unexpected control flow end tag, got {self.tag.block_type_name} but "
f"never saw a preceeding {self.expected_tag} (@ {linepos})"
)
return msg
class UnexpectedMacroEOFError(CompilationError):
def __init__(self, expected_name: str, actual_name: str) -> None:
self.expected_name = expected_name
self.actual_name = actual_name
super().__init__(msg=self.get_message())
def get_message(self) -> str:
msg = f'unexpected EOF, expected {self.expected_name}, got "{self.actual_name}"'
return msg

View File

@@ -1,110 +0,0 @@
from typing import Any
from dbt.common.exceptions import CompilationError, DbtBaseException
class MacroReturn(DbtBaseException):
"""
Hack of all hacks
This is not actually an exception.
It's how we return a value from a macro.
"""
def __init__(self, value) -> None:
self.value = value
class UndefinedMacroError(CompilationError):
def __str__(self, prefix: str = "! ") -> str:
msg = super().__str__(prefix)
return (
f"{msg}. This can happen when calling a macro that does "
"not exist. Check for typos and/or install package dependencies "
'with "dbt deps".'
)
class UndefinedCompilationError(CompilationError):
def __init__(self, name: str, node) -> None:
self.name = name
self.node = node
self.msg = f"{self.name} is undefined"
super().__init__(msg=self.msg)
class CaughtMacroError(CompilationError):
def __init__(self, exc) -> None:
self.exc = exc
super().__init__(msg=str(exc))
class CaughtMacroErrorWithNodeError(CompilationError):
def __init__(self, exc, node) -> None:
self.exc = exc
self.node = node
super().__init__(msg=str(exc))
class JinjaRenderingError(CompilationError):
pass
class MaterializationArgError(CompilationError):
def __init__(self, name: str, argument: str) -> None:
self.name = name
self.argument = argument
super().__init__(msg=self.get_message())
def get_message(self) -> str:
msg = f"materialization '{self.name}' received unknown argument '{self.argument}'."
return msg
class MacroNameNotStringError(CompilationError):
def __init__(self, kwarg_value) -> None:
self.kwarg_value = kwarg_value
super().__init__(msg=self.get_message())
def get_message(self) -> str:
msg = (
f"The macro_name parameter ({self.kwarg_value}) "
"to adapter.dispatch was not a string"
)
return msg
class MacrosSourcesUnWriteableError(CompilationError):
def __init__(self, node) -> None:
self.node = node
msg = 'cannot "write" macros or sources'
super().__init__(msg=msg)
class MacroArgTypeError(CompilationError):
def __init__(self, method_name: str, arg_name: str, got_value: Any, expected_type) -> None:
self.method_name = method_name
self.arg_name = arg_name
self.got_value = got_value
self.expected_type = expected_type
super().__init__(msg=self.get_message())
def get_message(self) -> str:
got_type = type(self.got_value)
msg = (
f"'adapter.{self.method_name}' expects argument "
f"'{self.arg_name}' to be of type '{self.expected_type}', instead got "
f"{self.got_value} ({got_type})"
)
return msg
class MacroResultError(CompilationError):
def __init__(self, freshness_macro_name: str, table):
self.freshness_macro_name = freshness_macro_name
self.table = table
super().__init__(msg=self.get_message())
def get_message(self) -> str:
msg = f'Got an invalid result from "{self.freshness_macro_name}" macro: {[tuple(r) for r in self.table]}'
return msg

View File

@@ -1,50 +0,0 @@
from typing import List, Union, Any
from dbt.common.exceptions import CompilationError, CommandError, scrub_secrets, env_secrets
class SymbolicLinkError(CompilationError):
def __init__(self) -> None:
super().__init__(msg=self.get_message())
def get_message(self) -> str:
msg = (
"dbt encountered an error when attempting to create a symbolic link. "
"If this error persists, please create an issue at: \n\n"
"https://github.com/dbt-labs/dbt-core"
)
return msg
class ExecutableError(CommandError):
def __init__(self, cwd: str, cmd: List[str], msg: str) -> None:
super().__init__(cwd, cmd, msg)
class WorkingDirectoryError(CommandError):
def __init__(self, cwd: str, cmd: List[str], msg: str) -> None:
super().__init__(cwd, cmd, msg)
def __str__(self):
return f'{self.msg}: "{self.cwd}"'
class CommandResultError(CommandError):
def __init__(
self,
cwd: str,
cmd: List[str],
returncode: Union[int, Any],
stdout: bytes,
stderr: bytes,
msg: str = "Got a non-zero returncode",
) -> None:
super().__init__(cwd, cmd, msg)
self.returncode = returncode
self.stdout = scrub_secrets(stdout.decode("utf-8"), env_secrets())
self.stderr = scrub_secrets(stderr.decode("utf-8"), env_secrets())
self.args = (cwd, self.cmd, returncode, self.stdout, self.stderr, msg)
def __str__(self):
return f"{self.msg} running: {self.cmd}"

View File

@@ -1,122 +0,0 @@
# never name this package "types", or mypy will crash in ugly ways
# necessary for annotating constructors
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Tuple, AbstractSet, Union
from typing import Callable, cast, Generic, Optional, TypeVar, List, NewType, Set
from dbt.common.dataclass_schema import (
dbtClassMixin,
ValidationError,
StrEnum,
)
Port = NewType("Port", int)
class NVEnum(StrEnum):
novalue = "novalue"
def __eq__(self, other):
return isinstance(other, NVEnum)
@dataclass
class NoValue(dbtClassMixin):
"""Sometimes, you want a way to say none that isn't None"""
novalue: NVEnum = field(default_factory=lambda: NVEnum.novalue)
@dataclass
class IncludeExclude(dbtClassMixin):
INCLUDE_ALL = ("all", "*")
include: Union[str, List[str]]
exclude: List[str] = field(default_factory=list)
def __post_init__(self):
if isinstance(self.include, str) and self.include not in self.INCLUDE_ALL:
raise ValidationError(
f"include must be one of {self.INCLUDE_ALL} or a list of strings"
)
if self.exclude and self.include not in self.INCLUDE_ALL:
raise ValidationError(
f"exclude can only be specified if include is one of {self.INCLUDE_ALL}"
)
if isinstance(self.include, list):
self._validate_items(self.include)
if isinstance(self.exclude, list):
self._validate_items(self.exclude)
def includes(self, item_name: str):
return (
item_name in self.include or self.include in self.INCLUDE_ALL
) and item_name not in self.exclude
def _validate_items(self, items: List[str]):
pass
class WarnErrorOptions(IncludeExclude):
def __init__(
self,
include: Union[str, List[str]],
exclude: Optional[List[str]] = None,
valid_error_names: Optional[Set[str]] = None,
):
self._valid_error_names: Set[str] = valid_error_names or set()
super().__init__(include=include, exclude=(exclude or []))
def _validate_items(self, items: List[str]):
for item in items:
if item not in self._valid_error_names:
raise ValidationError(f"{item} is not a valid dbt error name.")
FQNPath = Tuple[str, ...]
PathSet = AbstractSet[FQNPath]
T = TypeVar("T")
# A data type for representing lazily evaluated values.
#
# usage:
# x = Lazy.defer(lambda: expensive_fn())
# y = x.force()
#
# inspired by the purescript data type
# https://pursuit.purescript.org/packages/purescript-lazy/5.0.0/docs/Data.Lazy
@dataclass
class Lazy(Generic[T]):
_f: Callable[[], T]
memo: Optional[T] = None
# constructor for lazy values
@classmethod
def defer(cls, f: Callable[[], T]) -> Lazy[T]:
return Lazy(f)
# workaround for open mypy issue:
# https://github.com/python/mypy/issues/6910
def _typed_eval_f(self) -> T:
return cast(Callable[[], T], getattr(self, "_f"))()
# evaluates the function if the value has not been memoized already
def force(self) -> T:
if self.memo is None:
self.memo = self._typed_eval_f()
return self.memo
# This class is used in to_target_dict, so that accesses to missing keys
# will return an empty string instead of Undefined
class DictDefaultEmptyStr(dict):
def __getitem__(self, key):
return dict.get(self, key, "")
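
Two of these types in action, assuming the module is importable as dbt.common.helper_types:

from dbt.common.helper_types import Lazy, WarnErrorOptions

opts = WarnErrorOptions(
    include=["NoNodesForSelectionCriteria"],
    valid_error_names={"NoNodesForSelectionCriteria", "Deprecation"},
)
assert opts.includes("NoNodesForSelectionCriteria")
assert not opts.includes("Deprecation")

expensive = Lazy.defer(lambda: sum(range(1_000_000)))  # nothing evaluated yet
assert expensive.force() == expensive.force()  # evaluated once, then memoized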

View File

@@ -1,12 +0,0 @@
import uuid
_INVOCATION_ID = str(uuid.uuid4())
def get_invocation_id() -> str:
return _INVOCATION_ID
def reset_invocation_id():
global _INVOCATION_ID
_INVOCATION_ID = str(uuid.uuid4())

View File

@@ -1,473 +0,0 @@
from dataclasses import dataclass
import re
from typing import List
import dbt.common.exceptions.base
from dbt.common.exceptions import VersionsNotCompatibleError
from dbt.common.dataclass_schema import dbtClassMixin, StrEnum
from typing import Optional
class Matchers(StrEnum):
GREATER_THAN = ">"
GREATER_THAN_OR_EQUAL = ">="
LESS_THAN = "<"
LESS_THAN_OR_EQUAL = "<="
EXACT = "="
@dataclass
class VersionSpecification(dbtClassMixin):
major: Optional[str] = None
minor: Optional[str] = None
patch: Optional[str] = None
prerelease: Optional[str] = None
build: Optional[str] = None
matcher: Matchers = Matchers.EXACT
_MATCHERS = r"(?P<matcher>\>=|\>|\<|\<=|=)?"
_NUM_NO_LEADING_ZEROS = r"(0|[1-9]\d*)"
_ALPHA = r"[0-9A-Za-z-]*"
_ALPHA_NO_LEADING_ZEROS = r"(?:0|[1-9]\d*|\d*[a-zA-Z-][0-9a-zA-Z-]*)"
_BASE_VERSION_REGEX = r"""
(?P<major>{num_no_leading_zeros})\.
(?P<minor>{num_no_leading_zeros})\.
(?P<patch>{num_no_leading_zeros})
""".format(
num_no_leading_zeros=_NUM_NO_LEADING_ZEROS
)
_VERSION_EXTRA_REGEX = r"""
(\-?
(?P<prerelease>
{alpha_no_leading_zeros}(\.{alpha_no_leading_zeros})*))?
(\+
(?P<build>
{alpha}(\.{alpha})*))?
""".format(
alpha_no_leading_zeros=_ALPHA_NO_LEADING_ZEROS, alpha=_ALPHA
)
_VERSION_REGEX_PAT_STR = r"""
^
{matchers}
{base_version_regex}
{version_extra_regex}
$
""".format(
matchers=_MATCHERS,
base_version_regex=_BASE_VERSION_REGEX,
version_extra_regex=_VERSION_EXTRA_REGEX,
)
_VERSION_REGEX = re.compile(_VERSION_REGEX_PAT_STR, re.VERBOSE)
def _cmp(a, b):
"""Return negative if a<b, zero if a==b, positive if a>b."""
return (a > b) - (a < b)
@dataclass
class VersionSpecifier(VersionSpecification):
def to_version_string(self, skip_matcher=False):
prerelease = ""
build = ""
matcher = ""
if self.prerelease:
prerelease = "-" + self.prerelease
if self.build:
build = "+" + self.build
if not skip_matcher:
matcher = self.matcher
return "{}{}.{}.{}{}{}".format(
matcher, self.major, self.minor, self.patch, prerelease, build
)
@classmethod
def from_version_string(cls, version_string):
match = _VERSION_REGEX.match(version_string)
if not match:
raise dbt.common.exceptions.base.SemverError(
f'"{version_string}" is not a valid semantic version.'
)
matched = {k: v for k, v in match.groupdict().items() if v is not None}
return cls.from_dict(matched)
def __str__(self):
return self.to_version_string()
def to_range(self) -> "VersionRange":
range_start: VersionSpecifier = UnboundedVersionSpecifier()
range_end: VersionSpecifier = UnboundedVersionSpecifier()
if self.matcher == Matchers.EXACT:
range_start = self
range_end = self
elif self.matcher in [Matchers.GREATER_THAN, Matchers.GREATER_THAN_OR_EQUAL]:
range_start = self
elif self.matcher in [Matchers.LESS_THAN, Matchers.LESS_THAN_OR_EQUAL]:
range_end = self
return VersionRange(start=range_start, end=range_end)
def compare(self, other):
if self.is_unbounded or other.is_unbounded:
return 0
for key in ["major", "minor", "patch", "prerelease"]:
(a, b) = (getattr(self, key), getattr(other, key))
if key == "prerelease":
if a is None and b is None:
continue
if a is None:
if self.matcher == Matchers.LESS_THAN:
# If 'a' is not a pre-release but 'b' is, and b must be
# less than a, return -1 to prevent installations of
# pre-releases with greater base version than a
# maximum specified non-pre-release version.
return -1
# Otherwise, stable releases are considered greater than
# pre-release
return 1
if b is None:
return -1
# Check the prerelease component only
prcmp = self._nat_cmp(a, b)
if prcmp != 0: # either -1 or 1
return prcmp
# else is equal and will fall through
else: # major/minor/patch, should all be numbers
if int(a) > int(b):
return 1
elif int(a) < int(b):
return -1
# else is equal and will fall through
equal = (
self.matcher == Matchers.GREATER_THAN_OR_EQUAL
and other.matcher == Matchers.LESS_THAN_OR_EQUAL
) or (
self.matcher == Matchers.LESS_THAN_OR_EQUAL
and other.matcher == Matchers.GREATER_THAN_OR_EQUAL
)
if equal:
return 0
lt = (
(self.matcher == Matchers.LESS_THAN and other.matcher == Matchers.LESS_THAN_OR_EQUAL)
or (
other.matcher == Matchers.GREATER_THAN
and self.matcher == Matchers.GREATER_THAN_OR_EQUAL
)
or (self.is_upper_bound and other.is_lower_bound)
)
if lt:
return -1
gt = (
(other.matcher == Matchers.LESS_THAN and self.matcher == Matchers.LESS_THAN_OR_EQUAL)
or (
self.matcher == Matchers.GREATER_THAN
and other.matcher == Matchers.GREATER_THAN_OR_EQUAL
)
or (self.is_lower_bound and other.is_upper_bound)
)
if gt:
return 1
return 0
def __lt__(self, other):
return self.compare(other) == -1
def __gt__(self, other):
return self.compare(other) == 1
    def __eq__(self, other):
        return self.compare(other) == 0

    def __cmp__(self, other):
        return self.compare(other)
@property
def is_unbounded(self):
return False
@property
def is_lower_bound(self):
return self.matcher in [Matchers.GREATER_THAN, Matchers.GREATER_THAN_OR_EQUAL]
@property
def is_upper_bound(self):
return self.matcher in [Matchers.LESS_THAN, Matchers.LESS_THAN_OR_EQUAL]
@property
def is_exact(self):
return self.matcher == Matchers.EXACT
@classmethod
def _nat_cmp(cls, a, b):
def cmp_prerelease_tag(a, b):
if isinstance(a, int) and isinstance(b, int):
return _cmp(a, b)
elif isinstance(a, int):
return -1
elif isinstance(b, int):
return 1
else:
return _cmp(a, b)
a, b = a or "", b or ""
a_parts, b_parts = a.split("."), b.split(".")
a_parts = [int(x) if re.match(r"^\d+$", x) else x for x in a_parts]
b_parts = [int(x) if re.match(r"^\d+$", x) else x for x in b_parts]
for sub_a, sub_b in zip(a_parts, b_parts):
cmp_result = cmp_prerelease_tag(sub_a, sub_b)
if cmp_result != 0:
return cmp_result
else:
return _cmp(len(a), len(b))
@dataclass
class VersionRange:
start: VersionSpecifier
end: VersionSpecifier
def _try_combine_exact(self, a, b):
if a.compare(b) == 0:
return a
else:
raise VersionsNotCompatibleError()
def _try_combine_lower_bound_with_exact(self, lower, exact):
comparison = lower.compare(exact)
if comparison < 0 or (comparison == 0 and lower.matcher == Matchers.GREATER_THAN_OR_EQUAL):
return exact
raise VersionsNotCompatibleError()
def _try_combine_lower_bound(self, a, b):
if b.is_unbounded:
return a
elif a.is_unbounded:
return b
if not (a.is_exact or b.is_exact):
comparison = a.compare(b) < 0
if comparison:
return b
else:
return a
elif a.is_exact:
return self._try_combine_lower_bound_with_exact(b, a)
elif b.is_exact:
return self._try_combine_lower_bound_with_exact(a, b)
def _try_combine_upper_bound_with_exact(self, upper, exact):
comparison = upper.compare(exact)
if comparison > 0 or (comparison == 0 and upper.matcher == Matchers.LESS_THAN_OR_EQUAL):
return exact
raise VersionsNotCompatibleError()
def _try_combine_upper_bound(self, a, b):
if b.is_unbounded:
return a
elif a.is_unbounded:
return b
if not (a.is_exact or b.is_exact):
comparison = a.compare(b) > 0
if comparison:
return b
else:
return a
elif a.is_exact:
return self._try_combine_upper_bound_with_exact(b, a)
elif b.is_exact:
return self._try_combine_upper_bound_with_exact(a, b)
def reduce(self, other):
start = None
if self.start.is_exact and other.start.is_exact:
start = end = self._try_combine_exact(self.start, other.start)
else:
start = self._try_combine_lower_bound(self.start, other.start)
end = self._try_combine_upper_bound(self.end, other.end)
if start.compare(end) > 0:
raise VersionsNotCompatibleError()
return VersionRange(start=start, end=end)
def __str__(self):
result = []
if self.start.is_unbounded and self.end.is_unbounded:
return "ANY"
if not self.start.is_unbounded:
result.append(self.start.to_version_string())
if not self.end.is_unbounded:
result.append(self.end.to_version_string())
return ", ".join(result)
def to_version_string_pair(self):
to_return = []
if not self.start.is_unbounded:
to_return.append(self.start.to_version_string())
if not self.end.is_unbounded:
to_return.append(self.end.to_version_string())
return to_return
class UnboundedVersionSpecifier(VersionSpecifier):
def __init__(self, *args, **kwargs) -> None:
super().__init__(
matcher=Matchers.EXACT, major=None, minor=None, patch=None, prerelease=None, build=None
)
def __str__(self):
return "*"
@property
def is_unbounded(self):
return True
@property
def is_lower_bound(self):
return False
@property
def is_upper_bound(self):
return False
@property
def is_exact(self):
return False
def reduce_versions(*args):
version_specifiers = []
for version in args:
if isinstance(version, UnboundedVersionSpecifier) or version is None:
continue
elif isinstance(version, VersionSpecifier):
version_specifiers.append(version)
elif isinstance(version, VersionRange):
if not isinstance(version.start, UnboundedVersionSpecifier):
version_specifiers.append(version.start)
if not isinstance(version.end, UnboundedVersionSpecifier):
version_specifiers.append(version.end)
else:
version_specifiers.append(VersionSpecifier.from_version_string(version))
for version_specifier in version_specifiers:
if not isinstance(version_specifier, VersionSpecifier):
raise Exception(version_specifier)
if not version_specifiers:
return VersionRange(start=UnboundedVersionSpecifier(), end=UnboundedVersionSpecifier())
try:
to_return = version_specifiers.pop().to_range()
for version_specifier in version_specifiers:
to_return = to_return.reduce(version_specifier.to_range())
except VersionsNotCompatibleError:
raise VersionsNotCompatibleError(
"Could not find a satisfactory version from options: {}".format([str(a) for a in args])
)
return to_return
def versions_compatible(*args):
if len(args) == 1:
return True
try:
reduce_versions(*args)
return True
except VersionsNotCompatibleError:
return False
def find_possible_versions(requested_range, available_versions):
possible_versions = []
for version_string in available_versions:
version = VersionSpecifier.from_version_string(version_string)
if versions_compatible(version, requested_range.start, requested_range.end):
possible_versions.append(version)
sorted_versions = sorted(possible_versions, reverse=True)
return [v.to_version_string(skip_matcher=True) for v in sorted_versions]
def resolve_to_specific_version(requested_range, available_versions):
max_version = None
max_version_string = None
for version_string in available_versions:
version = VersionSpecifier.from_version_string(version_string)
if versions_compatible(version, requested_range.start, requested_range.end) and (
max_version is None or max_version.compare(version) < 0
):
max_version = version
max_version_string = version_string
return max_version_string
def filter_installable(versions: List[str], install_prerelease: bool) -> List[str]:
installable = []
installable_dict = {}
for version_string in versions:
version = VersionSpecifier.from_version_string(version_string)
if install_prerelease or not version.prerelease:
installable.append(version)
installable_dict[str(version)] = version_string
sorted_installable = sorted(installable)
sorted_installable_original_versions = [
str(installable_dict.get(str(version))) for version in sorted_installable
]
return sorted_installable_original_versions
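For reference, a minimal sketch of how these helpers compose (illustrative only, not part of this diff; it assumes the VersionSpecifier parsing and to_range() defined earlier in this file):
# Resolve the newest installable version satisfying ">=1.2.0, <2.0.0".
requested = reduce_versions(">=1.2.0", "<2.0.0")  # -> VersionRange
available = ["1.1.0", "1.3.0", "1.4.0a1", "2.0.1"]
installable = filter_installable(available, install_prerelease=False)
# installable == ["1.1.0", "1.3.0", "2.0.1"]  (ascending, prerelease dropped)
resolve_to_specific_version(requested, installable)  # -> "1.3.0"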

View File

@@ -1,68 +0,0 @@
import textwrap
from typing import Dict
import colorama
COLORS: Dict[str, str] = {
"red": colorama.Fore.RED,
"green": colorama.Fore.GREEN,
"yellow": colorama.Fore.YELLOW,
"reset_all": colorama.Style.RESET_ALL,
}
COLOR_FG_RED = COLORS["red"]
COLOR_FG_GREEN = COLORS["green"]
COLOR_FG_YELLOW = COLORS["yellow"]
COLOR_RESET_ALL = COLORS["reset_all"]
USE_COLOR = True
PRINTER_WIDTH = 80
def color(text: str, color_code: str) -> str:
if USE_COLOR:
return "{}{}{}".format(color_code, text, COLOR_RESET_ALL)
else:
return text
def printer_width() -> int:
return PRINTER_WIDTH
def green(text: str) -> str:
return color(text, COLOR_FG_GREEN)
def yellow(text: str) -> str:
return color(text, COLOR_FG_YELLOW)
def red(text: str) -> str:
return color(text, COLOR_FG_RED)
def line_wrap_message(msg: str, subtract: int = 0, dedent: bool = True, prefix: str = "") -> str:
"""
Line wrap the given message to PRINTER_WIDTH - {subtract}. Convert double
newlines to newlines and avoid calling textwrap.fill() on them (like
markdown)
"""
width = printer_width() - subtract
if dedent:
msg = textwrap.dedent(msg)
if prefix:
msg = f"{prefix}{msg}"
    # If the input had an explicit double newline, preserve it as a paragraph
    # break (it is rejoined with a single newline below). Support Windows
    # line endings, too.
splitter = "\r\n\r\n" if "\r\n\r\n" in msg else "\n\n"
chunks = msg.split(splitter)
return "\n".join(textwrap.fill(chunk, width=width, break_on_hyphens=False) for chunk in chunks)
def warning_tag(msg: str) -> str:
return f'[{yellow("WARNING")}]: {msg}'
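A short usage sketch of the helpers above (illustrative, not part of this diff):
# Wrap a two-paragraph warning to printer_width() - 2 columns.
msg = "This run took longer than expected.\n\nCheck your warehouse credentials."
print(warning_tag(line_wrap_message(msg, subtract=2)))
# Each paragraph is filled separately; the blank line between them is
# collapsed to a single newline in the output.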

View File

@@ -1,26 +0,0 @@
from dbt.common.utils.encoding import md5, JSONEncoder, ForgivingJSONEncoder
from dbt.common.utils.casting import (
cast_to_str,
cast_to_int,
cast_dict_to_dict_of_strings,
)
from dbt.common.utils.dict import (
AttrDict,
filter_null_values,
merge,
deep_merge,
deep_merge_item,
deep_map_render,
)
from dbt.common.utils.executor import executor
from dbt.common.utils.jinja import (
get_dbt_macro_name,
get_docs_macro_name,
get_materialization_macro_name,
get_test_macro_name,
MACRO_PREFIX,
)

View File

@@ -1,25 +0,0 @@
# These helpers are useful for proto-generated classes in particular:
# protobuf defaults strings to the empty string, so Optional[str] types
# don't work for generated Python classes.
from typing import Optional
def cast_to_str(string: Optional[str]) -> str:
if string is None:
return ""
else:
return string
def cast_to_int(integer: Optional[int]) -> int:
if integer is None:
return 0
else:
return integer
def cast_dict_to_dict_of_strings(dct):
new_dct = {}
for k, v in dct.items():
new_dct[str(k)] = str(v)
return new_dct
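A tiny sketch of the intended use (illustrative, not part of this diff):
assert cast_to_str(None) == ""   # proto string fields default to ""
assert cast_to_int(None) == 0    # proto int fields default to 0
assert cast_dict_to_dict_of_strings({1: None}) == {"1": "None"}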

View File

@@ -1,33 +0,0 @@
import time
from dbt.common.events.types import RecordRetryException, RetryExternalCall
from dbt.common.exceptions import ConnectionError
from tarfile import ReadError
import requests
def connection_exception_retry(fn, max_attempts: int, attempt: int = 0):
"""Attempts to run a function that makes an external call, if the call fails
on a Requests exception or decompression issue (ReadError), it will be tried
up to 5 more times. All exceptions that Requests explicitly raises inherit from
requests.exceptions.RequestException. See https://github.com/dbt-labs/dbt-core/issues/4579
for context on this decompression issues specifically.
"""
try:
return fn()
except (
requests.exceptions.RequestException,
ReadError,
EOFError,
) as exc:
if attempt <= max_attempts - 1:
# This import needs to be inline to avoid circular dependency
from dbt.common.events.functions import fire_event
fire_event(RecordRetryException(exc=str(exc)))
fire_event(RetryExternalCall(attempt=attempt, max=max_attempts))
time.sleep(1)
return connection_exception_retry(fn, max_attempts, attempt + 1)
else:
raise ConnectionError("External connection exception occurred: " + str(exc))

View File

@@ -1,128 +0,0 @@
import copy
import datetime
from typing import Dict, Optional, TypeVar, Callable, Any, Tuple, Union, Type
from dbt.common.exceptions import DbtConfigError, RecursionError
K_T = TypeVar("K_T")
V_T = TypeVar("V_T")
def filter_null_values(input: Dict[K_T, Optional[V_T]]) -> Dict[K_T, V_T]:
return {k: v for k, v in input.items() if v is not None}
class AttrDict(dict):
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.__dict__ = self
def merge(*args):
if len(args) == 0:
return None
if len(args) == 1:
return args[0]
lst = list(args)
last = lst.pop(len(lst) - 1)
return _merge(merge(*lst), last)
def _merge(a, b):
to_return = a.copy()
to_return.update(b)
return to_return
# http://stackoverflow.com/questions/20656135/python-deep-merge-dictionary-data
def deep_merge(*args):
"""
>>> dbt.common.utils.deep_merge({'a': 1, 'b': 2, 'c': 3}, {'a': 2}, {'a': 3, 'b': 1}) # noqa
{'a': 3, 'b': 1, 'c': 3}
"""
if len(args) == 0:
return None
if len(args) == 1:
return copy.deepcopy(args[0])
lst = list(args)
last = copy.deepcopy(lst.pop(len(lst) - 1))
return _deep_merge(deep_merge(*lst), last)
def _deep_merge(destination, source):
if isinstance(source, dict):
for key, value in source.items():
deep_merge_item(destination, key, value)
return destination
def deep_merge_item(destination, key, value):
if isinstance(value, dict):
node = destination.setdefault(key, {})
destination[key] = deep_merge(node, value)
elif isinstance(value, tuple) or isinstance(value, list):
if key in destination:
destination[key] = list(value) + list(destination[key])
else:
destination[key] = value
else:
destination[key] = value
def _deep_map_render(
func: Callable[[Any, Tuple[Union[str, int], ...]], Any],
value: Any,
keypath: Tuple[Union[str, int], ...],
) -> Any:
atomic_types: Tuple[Type[Any], ...] = (int, float, str, type(None), bool, datetime.date)
ret: Any
if isinstance(value, list):
ret = [_deep_map_render(func, v, (keypath + (idx,))) for idx, v in enumerate(value)]
elif isinstance(value, dict):
ret = {k: _deep_map_render(func, v, (keypath + (str(k),))) for k, v in value.items()}
elif isinstance(value, atomic_types):
ret = func(value, keypath)
else:
container_types: Tuple[Type[Any], ...] = (list, dict)
ok_types = container_types + atomic_types
raise DbtConfigError(
"in _deep_map_render, expected one of {!r}, got {!r}".format(ok_types, type(value))
)
return ret
def deep_map_render(func: Callable[[Any, Tuple[Union[str, int], ...]], Any], value: Any) -> Any:
"""This function renders a nested dictionary derived from a yaml
file. It is used to render dbt_project.yml, profiles.yml, and
schema files.
It maps the function func() onto each non-container value in 'value'
recursively, returning a new value. As long as func does not manipulate
the value, then deep_map_render will also not manipulate it.
value should be a value returned by `yaml.safe_load` or `json.load` - the
only expected types are list, dict, native python number, str, NoneType,
and bool.
func() will be called on numbers, strings, Nones, and booleans. Its first
parameter will be the value, and the second will be its keypath, an
iterable over the __getitem__ keys needed to get to it.
:raises: If there are cycles in the value, raises a
dbt.common.exceptions.RecursionError
"""
try:
return _deep_map_render(func, value, ())
except RuntimeError as exc:
if "maximum recursion depth exceeded" in str(exc):
raise RecursionError("Cycle detected in deep_map_render")
raise
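A minimal sketch of deep_map_render in use (illustrative, not part of this diff):
def render(value, keypath):
    # Called only on leaf values; keypath is e.g. ("models", 0, "name").
    return value.upper() if isinstance(value, str) else value

rendered = deep_map_render(render, {"models": [{"name": "orders", "threads": 4}]})
# rendered == {"models": [{"name": "ORDERS", "threads": 4}]}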

View File

@@ -1,56 +0,0 @@
import datetime
import decimal
import hashlib
import json
from typing import Tuple, Type, Any
import jinja2
import sys
DECIMALS: Tuple[Type[Any], ...]
try:
    import cdecimal  # type: ignore
except ImportError:
DECIMALS = (decimal.Decimal,)
else:
DECIMALS = (decimal.Decimal, cdecimal.Decimal)
def md5(string, charset="utf-8"):
if sys.version_info >= (3, 9):
return hashlib.md5(string.encode(charset), usedforsecurity=False).hexdigest()
else:
return hashlib.md5(string.encode(charset)).hexdigest()
class JSONEncoder(json.JSONEncoder):
"""A 'custom' json encoder that does normal json encoder things, but also
handles `Decimal`s and `Undefined`s. Decimals can lose precision because
they get converted to floats. Undefined's are serialized to an empty string
"""
def default(self, obj):
if isinstance(obj, DECIMALS):
return float(obj)
elif isinstance(obj, (datetime.datetime, datetime.date, datetime.time)):
return obj.isoformat()
elif isinstance(obj, jinja2.Undefined):
return ""
elif isinstance(obj, Exception):
return repr(obj)
elif hasattr(obj, "to_dict"):
# if we have a to_dict we should try to serialize the result of
# that!
return obj.to_dict(omit_none=True)
else:
return super().default(obj)
class ForgivingJSONEncoder(JSONEncoder):
def default(self, obj):
        # let dbt's default JSON encoder handle it if possible; fall back
        # to str()
try:
return super().default(obj)
except TypeError:
return str(obj)
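A short sketch of the encoders above (illustrative, not part of this diff; json, decimal, and datetime are imported at the top of this module):
payload = {"amount": decimal.Decimal("1.50"), "at": datetime.date(2024, 1, 10)}
json.dumps(payload, cls=JSONEncoder)
# '{"amount": 1.5, "at": "2024-01-10"}'
json.dumps({"obj": object()}, cls=ForgivingJSONEncoder)
# '{"obj": "<object object at 0x...>"}'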

View File

@@ -1,67 +0,0 @@
import concurrent.futures
from contextlib import contextmanager
from typing import Protocol, Optional
class ConnectingExecutor(concurrent.futures.Executor):
def submit_connected(self, adapter, conn_name, func, *args, **kwargs):
def connected(conn_name, func, *args, **kwargs):
with self.connection_named(adapter, conn_name):
return func(*args, **kwargs)
return self.submit(connected, conn_name, func, *args, **kwargs)
# a little concurrent.futures.Executor for single-threaded mode
class SingleThreadedExecutor(ConnectingExecutor):
def submit(*args, **kwargs):
# this basic pattern comes from concurrent.futures.Executor itself,
# but without handling the `fn=` form.
if len(args) >= 2:
self, fn, *args = args
elif not args:
raise TypeError(
"descriptor 'submit' of 'SingleThreadedExecutor' object needs an argument"
)
else:
raise TypeError(
"submit expected at least 1 positional argument, got %d" % (len(args) - 1)
)
fut = concurrent.futures.Future()
try:
result = fn(*args, **kwargs)
except Exception as exc:
fut.set_exception(exc)
else:
fut.set_result(result)
return fut
@contextmanager
def connection_named(self, adapter, name):
yield
class MultiThreadedExecutor(
ConnectingExecutor,
concurrent.futures.ThreadPoolExecutor,
):
@contextmanager
def connection_named(self, adapter, name):
with adapter.connection_named(name):
yield
class ThreadedArgs(Protocol):
single_threaded: bool
class HasThreadingConfig(Protocol):
args: ThreadedArgs
threads: Optional[int]
def executor(config: HasThreadingConfig) -> ConnectingExecutor:
if config.args.single_threaded:
return SingleThreadedExecutor()
else:
return MultiThreadedExecutor(max_workers=config.threads)
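An assumed usage sketch (the config classes below are illustrative stand-ins for any object satisfying HasThreadingConfig, not part of this diff):
from dataclasses import dataclass, field

@dataclass
class _Args:
    single_threaded: bool = False

@dataclass
class _Config:
    args: _Args = field(default_factory=_Args)
    threads: int = 4

# Returns a MultiThreadedExecutor with 4 workers; with
# single_threaded=True it would return a SingleThreadedExecutor.
with executor(_Config()) as pool:
    futures = [pool.submit(pow, 2, n) for n in range(4)]
    results = [f.result() for f in futures]  # [1, 2, 4, 8]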

View File

@@ -1,8 +0,0 @@
from typing import Optional
def lowercase(value: Optional[str]) -> Optional[str]:
if value is None:
return None
else:
return value.lower()

View File

@@ -1,33 +0,0 @@
from dbt.common.exceptions import DbtInternalError
MACRO_PREFIX = "dbt_macro__"
DOCS_PREFIX = "dbt_docs__"
def get_dbt_macro_name(name):
if name is None:
raise DbtInternalError("Got None for a macro name!")
return f"{MACRO_PREFIX}{name}"
def get_dbt_docs_name(name):
if name is None:
raise DbtInternalError("Got None for a doc name!")
return f"{DOCS_PREFIX}{name}"
def get_materialization_macro_name(materialization_name, adapter_type=None, with_prefix=True):
if adapter_type is None:
adapter_type = "default"
name = f"materialization_{materialization_name}_{adapter_type}"
return get_dbt_macro_name(name) if with_prefix else name
def get_docs_macro_name(docs_name, with_prefix=True):
return get_dbt_docs_name(docs_name) if with_prefix else docs_name
def get_test_macro_name(test_name, with_prefix=True):
name = f"test_{test_name}"
return get_dbt_macro_name(name) if with_prefix else name
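A few concrete examples of the naming helpers above (illustrative, not part of this diff):
assert get_dbt_macro_name("my_macro") == "dbt_macro__my_macro"
assert get_materialization_macro_name("table") == "dbt_macro__materialization_table_default"
assert get_materialization_macro_name("table", "snowflake", with_prefix=False) == (
    "materialization_table_snowflake"
)
assert get_test_macro_name("unique") == "dbt_macro__test_unique"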

View File

@@ -7,11 +7,11 @@ import pickle
 from collections import defaultdict
 from typing import List, Dict, Any, Tuple, Optional
-from dbt.common.invocation import get_invocation_id
+from dbt_common.invocation import get_invocation_id
 from dbt.flags import get_flags
 from dbt.adapters.factory import get_adapter
 from dbt.clients import jinja
-from dbt.common.clients.system import make_directory
+from dbt_common.clients.system import make_directory
 from dbt.context.providers import generate_runtime_model_context
 from dbt.contracts.graph.manifest import Manifest, UniqueID
 from dbt.contracts.graph.nodes import (
@@ -28,12 +28,12 @@ from dbt.exceptions import (
     DbtRuntimeError,
 )
 from dbt.graph import Graph
-from dbt.common.events.functions import fire_event
-from dbt.common.events.types import Note
-from dbt.common.events.contextvars import get_node_info
+from dbt_common.events.functions import fire_event
+from dbt_common.events.types import Note
+from dbt_common.events.contextvars import get_node_info
 from dbt.events.types import WritingInjectedSQLForNode, FoundStats
 from dbt.node_types import NodeType, ModelLanguage
-from dbt.common.events.format import pluralize
+from dbt_common.events.format import pluralize
 import dbt.tracking
 import sqlparse

View File

@@ -2,10 +2,10 @@ from dataclasses import dataclass
 from typing import Any, Dict, Optional, Tuple
 import os
-from dbt.common.dataclass_schema import ValidationError
+from dbt_common.dataclass_schema import ValidationError
 from dbt.flags import get_flags
-from dbt.common.clients.system import load_file_contents
+from dbt_common.clients.system import load_file_contents
 from dbt.clients.yaml_helper import load_yaml_text
 from dbt.contracts.project import ProfileConfig
 from dbt.adapters.contracts.connection import Credentials, HasCredentials
@@ -17,8 +17,8 @@ from dbt.exceptions import (
     DbtRuntimeError,
     ProfileConfigError,
 )
-from dbt.common.exceptions import DbtValidationError
-from dbt.common.events.functions import fire_event
+from dbt_common.exceptions import DbtValidationError
+from dbt_common.events.functions import fire_event
 from .renderer import ProfileRenderer

View File

@@ -22,7 +22,7 @@ from dbt.constants import (
     PACKAGE_LOCK_HASH_KEY,
     DBT_PROJECT_FILE_NAME,
 )
-from dbt.common.clients.system import path_exists, load_file_contents
+from dbt_common.clients.system import path_exists, load_file_contents
 from dbt.clients.yaml_helper import load_yaml_text
 from dbt.adapters.contracts.connection import QueryComment
 from dbt.exceptions import (
@@ -31,10 +31,10 @@ from dbt.exceptions import (
     ProjectContractError,
     DbtRuntimeError,
 )
-from dbt.common.exceptions import SemverError
+from dbt_common.exceptions import SemverError
 from dbt.graph import SelectionSpec
-from dbt.common.helper_types import NoValue
-from dbt.common.semver import VersionSpecifier, versions_compatible
+from dbt_common.helper_types import NoValue
+from dbt_common.semver import VersionSpecifier, versions_compatible
 from dbt.version import get_installed_version
 from dbt.utils import MultiDict, md5, coerce_dict_str
 from dbt.node_types import NodeType
@@ -45,7 +45,7 @@ from dbt.contracts.project import (
     ProjectFlags,
 )
 from dbt.contracts.project import PackageConfig, ProjectPackageMetadata
-from dbt.common.dataclass_schema import ValidationError
+from dbt_common.dataclass_schema import ValidationError
 from .renderer import DbtProjectYamlRenderer, PackageRenderer
 from .selectors import (
     selector_config_from_data,

Some files were not shown because too many files have changed in this diff.