Mirror of https://github.com/dbt-labs/dbt-core, synced 2025-12-19 09:31:29 +00:00.

Compare commits: 1 commit on branch `jerco/sql-...testing-pr`

| Author | SHA1 | Date |
|---|---|---|
|  | 42058de028 |  |
@@ -41,4 +41,3 @@ first_value = 1
[bumpversion:file:plugins/snowflake/dbt/adapters/snowflake/__version__.py]

[bumpversion:file:plugins/bigquery/dbt/adapters/bigquery/__version__.py]
20  .pre-commit-config.yaml  (new file)
@@ -0,0 +1,20 @@
repos:
  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v2.3.0
    hooks:
      - id: check-yaml
      - id: end-of-file-fixer
      - id: trailing-whitespace
  - repo: https://github.com/psf/black
    rev: 20.8b1
    hooks:
      - id: black
  - repo: https://gitlab.com/PyCQA/flake8
    rev: 3.9.0
    hooks:
      - id: flake8
  - repo: https://github.com/pre-commit/mirrors-mypy
    rev: v0.812
    hooks:
      - id: mypy
        files: ^core/dbt/
@@ -139,7 +139,7 @@ brew install postgresql

First make sure that you set up your `virtualenv` as described in section _Setting up an environment_. Next, install dbt (and its dependencies) with:

```
-pip install -r editable_requirements.txt
+pip install -r requirements-editable.txt
```

When dbt is installed from source in this way, any changes you make to the dbt source code will be reflected immediately in your next `dbt` run.
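An editable install links the package to your working tree instead of copying it, which is why source edits show up on the next run. A quick, hypothetical way to confirm the link (not part of the contributing guide):

```
# Hypothetical sanity check, assuming dbt is importable in the active virtualenv:
# with an editable install, imported modules resolve to files in your checkout.
import dbt.version

print(dbt.version.__file__)  # expect a path inside your clone, e.g. .../core/dbt/version.py
```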
7  Makefile
@@ -1,10 +1,5 @@
.PHONY: install test test-unit test-integration

changed_tests := `git status --porcelain | grep '^\(M\| M\|A\| A\)' | awk '{ print $$2 }' | grep '\/test_[a-zA-Z_\-\.]\+.py'`

install:
    pip install -e .

test: .env
    @echo "Full test run starting..."
    @time docker-compose run --rm test tox
@@ -18,7 +13,7 @@ test-integration: .env
    @time docker-compose run --rm test tox -e integration-postgres-py36,integration-redshift-py36,integration-snowflake-py36,integration-bigquery-py36

test-quick: .env
-   @echo "Integration test run starting..."
+   @echo "Integration test run starting, will exit on first failure..."
    @time docker-compose run --rm test tox -e integration-postgres-py36 -- -x

# This rule creates a file named .env that is used by docker-compose for passing
@@ -139,7 +139,7 @@ jobs:
      inputs:
        versionSpec: '3.7'
        architecture: 'x64'
-   - script: python -m pip install --upgrade pip setuptools && python -m pip install -r requirements.txt && python -m pip install -r dev_requirements.txt
+   - script: python -m pip install --upgrade pip setuptools && python -m pip install -r requirements.txt && python -m pip install -r requirements-dev.txt
      displayName: Install dependencies
    - task: ShellScript@2
      inputs:
@@ -63,11 +63,13 @@ def main():
    packages = registry.packages()
    project_json = init_project_in_packages(args, packages)
    if args.project["version"] in project_json["versions"]:
-       raise Exception("Version {} already in packages JSON"
-                       .format(args.project["version"]),
-                       file=sys.stderr)
+       raise Exception(
+           "Version {} already in packages JSON".format(args.project["version"]),
+           file=sys.stderr,
+       )
    add_version_to_package(args, project_json)
    print(json.dumps(packages, indent=2))


if __name__ == "__main__":
    main()
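Worth noting about the hunk above: Black only re-wraps the `raise`; `Exception` does not accept a `file` keyword argument, so the `file=sys.stderr` carried over from an earlier `print`-style call would itself raise a `TypeError` if this branch were ever hit. A minimal sketch of the apparent intent (my assumption, not a change made in this PR):

```
import sys

def check_version(version, existing_versions):
    # Report to stderr, then fail, when the version has already been published.
    if version in existing_versions:
        message = "Version {} already in packages JSON".format(version)
        print(message, file=sys.stderr)
        raise Exception(message)

check_version("0.19.1", ["0.19.0"])    # passes silently
# check_version("0.19.0", ["0.19.0"])  # would print to stderr and raise
```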
@@ -8,10 +8,10 @@ from dbt.exceptions import RuntimeException
|
||||
@dataclass
|
||||
class Column:
|
||||
TYPE_LABELS: ClassVar[Dict[str, str]] = {
|
||||
'STRING': 'TEXT',
|
||||
'TIMESTAMP': 'TIMESTAMP',
|
||||
'FLOAT': 'FLOAT',
|
||||
'INTEGER': 'INT'
|
||||
"STRING": "TEXT",
|
||||
"TIMESTAMP": "TIMESTAMP",
|
||||
"FLOAT": "FLOAT",
|
||||
"INTEGER": "INT",
|
||||
}
|
||||
column: str
|
||||
dtype: str
|
||||
@@ -24,7 +24,7 @@ class Column:
|
||||
return cls.TYPE_LABELS.get(dtype.upper(), dtype)
|
||||
|
||||
@classmethod
|
||||
def create(cls, name, label_or_dtype: str) -> 'Column':
|
||||
def create(cls, name, label_or_dtype: str) -> "Column":
|
||||
column_type = cls.translate_type(label_or_dtype)
|
||||
return cls(name, column_type)
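The hunk above only switches quote style, but the `TYPE_LABELS` / `create()` pair is easy to gloss over: `create()` normalizes a label such as `STRING` into a concrete SQL type, and unknown labels pass through unchanged. A self-contained sketch that mirrors the lines shown here (it is not imported from dbt):

```
from dataclasses import dataclass
from typing import ClassVar, Dict

@dataclass
class Column:
    TYPE_LABELS: ClassVar[Dict[str, str]] = {
        "STRING": "TEXT",
        "TIMESTAMP": "TIMESTAMP",
        "FLOAT": "FLOAT",
        "INTEGER": "INT",
    }
    column: str
    dtype: str

    @classmethod
    def translate_type(cls, dtype: str) -> str:
        return cls.TYPE_LABELS.get(dtype.upper(), dtype)

    @classmethod
    def create(cls, name: str, label_or_dtype: str) -> "Column":
        return cls(name, cls.translate_type(label_or_dtype))

print(Column.create("order_id", "string"))  # Column(column='order_id', dtype='TEXT')
print(Column.create("amount", "numeric"))   # unknown labels pass through: dtype='numeric'
```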
|
||||
|
||||
@@ -41,14 +41,19 @@ class Column:
|
||||
if self.is_string():
|
||||
return Column.string_type(self.string_size())
|
||||
elif self.is_numeric():
|
||||
return Column.numeric_type(self.dtype, self.numeric_precision,
|
||||
self.numeric_scale)
|
||||
return Column.numeric_type(
|
||||
self.dtype, self.numeric_precision, self.numeric_scale
|
||||
)
|
||||
else:
|
||||
return self.dtype
|
||||
|
||||
def is_string(self) -> bool:
|
||||
return self.dtype.lower() in ['text', 'character varying', 'character',
|
||||
'varchar']
|
||||
return self.dtype.lower() in [
|
||||
"text",
|
||||
"character varying",
|
||||
"character",
|
||||
"varchar",
|
||||
]
|
||||
|
||||
def is_number(self):
|
||||
return any([self.is_integer(), self.is_numeric(), self.is_float()])
|
||||
@@ -56,33 +61,45 @@ class Column:
|
||||
def is_float(self):
|
||||
return self.dtype.lower() in [
|
||||
# floats
|
||||
'real', 'float4', 'float', 'double precision', 'float8'
|
||||
"real",
|
||||
"float4",
|
||||
"float",
|
||||
"double precision",
|
||||
"float8",
|
||||
]
|
||||
|
||||
def is_integer(self) -> bool:
|
||||
return self.dtype.lower() in [
|
||||
# real types
|
||||
'smallint', 'integer', 'bigint',
|
||||
'smallserial', 'serial', 'bigserial',
|
||||
"smallint",
|
||||
"integer",
|
||||
"bigint",
|
||||
"smallserial",
|
||||
"serial",
|
||||
"bigserial",
|
||||
# aliases
|
||||
'int2', 'int4', 'int8',
|
||||
'serial2', 'serial4', 'serial8',
|
||||
"int2",
|
||||
"int4",
|
||||
"int8",
|
||||
"serial2",
|
||||
"serial4",
|
||||
"serial8",
|
||||
]
|
||||
|
||||
def is_numeric(self) -> bool:
|
||||
return self.dtype.lower() in ['numeric', 'decimal']
|
||||
return self.dtype.lower() in ["numeric", "decimal"]
|
||||
|
||||
def string_size(self) -> int:
|
||||
if not self.is_string():
|
||||
raise RuntimeException("Called string_size() on non-string field!")
|
||||
|
||||
if self.dtype == 'text' or self.char_size is None:
|
||||
if self.dtype == "text" or self.char_size is None:
|
||||
# char_size should never be None. Handle it reasonably just in case
|
||||
return 256
|
||||
else:
|
||||
return int(self.char_size)
|
||||
|
||||
def can_expand_to(self, other_column: 'Column') -> bool:
|
||||
def can_expand_to(self, other_column: "Column") -> bool:
|
||||
"""returns True if this column can be expanded to the size of the
|
||||
other column"""
|
||||
if not self.is_string() or not other_column.is_string():
|
||||
@@ -110,12 +127,10 @@ class Column:
|
||||
return "<Column {} ({})>".format(self.name, self.data_type)
|
||||
|
||||
@classmethod
|
||||
def from_description(cls, name: str, raw_data_type: str) -> 'Column':
|
||||
match = re.match(r'([^(]+)(\([^)]+\))?', raw_data_type)
|
||||
def from_description(cls, name: str, raw_data_type: str) -> "Column":
|
||||
match = re.match(r"([^(]+)(\([^)]+\))?", raw_data_type)
|
||||
if match is None:
|
||||
raise RuntimeException(
|
||||
f'Could not interpret data type "{raw_data_type}"'
|
||||
)
|
||||
raise RuntimeException(f'Could not interpret data type "{raw_data_type}"')
|
||||
data_type, size_info = match.groups()
|
||||
char_size = None
|
||||
numeric_precision = None
|
||||
@@ -123,7 +138,7 @@ class Column:
|
||||
if size_info is not None:
|
||||
# strip out the parentheses
|
||||
size_info = size_info[1:-1]
|
||||
parts = size_info.split(',')
|
||||
parts = size_info.split(",")
|
||||
if len(parts) == 1:
|
||||
try:
|
||||
char_size = int(parts[0])
|
||||
@@ -148,6 +163,4 @@ class Column:
|
||||
f'could not convert "{parts[1]}" to an integer'
|
||||
)
|
||||
|
||||
return cls(
|
||||
name, data_type, char_size, numeric_precision, numeric_scale
|
||||
)
|
||||
return cls(name, data_type, char_size, numeric_precision, numeric_scale)
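`from_description` is only reformatted in this hunk, but the regex deserves a gloss: it splits a raw type such as `numeric(28,6)` into a base type plus optional size information, which then becomes either a character size or a precision/scale pair. A standalone reconstruction of that parsing (my sketch, not the dbt module, and without its error handling for non-integer sizes):

```
import re
from typing import Optional, Tuple

def parse_data_type(raw: str) -> Tuple[str, Optional[int], Optional[int], Optional[int]]:
    match = re.match(r"([^(]+)(\([^)]+\))?", raw)
    if match is None:
        raise ValueError(f'Could not interpret data type "{raw}"')
    data_type, size_info = match.groups()
    char_size = numeric_precision = numeric_scale = None
    if size_info is not None:
        parts = size_info[1:-1].split(",")  # strip the parentheses
        if len(parts) == 1:
            char_size = int(parts[0])
        elif len(parts) == 2:
            numeric_precision, numeric_scale = int(parts[0]), int(parts[1])
    return data_type, char_size, numeric_precision, numeric_scale

print(parse_data_type("character varying(64)"))  # ('character varying', 64, None, None)
print(parse_data_type("numeric(28,6)"))          # ('numeric', None, 28, 6)
print(parse_data_type("text"))                   # ('text', None, None, None)
```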
|
||||
|
||||
@@ -1,18 +1,21 @@
|
||||
import abc
|
||||
import os
|
||||
|
||||
# multiprocessing.RLock is a function returning this type
|
||||
from multiprocessing.synchronize import RLock
|
||||
from threading import get_ident
|
||||
from typing import (
|
||||
Dict, Tuple, Hashable, Optional, ContextManager, List, Union
|
||||
)
|
||||
from typing import Dict, Tuple, Hashable, Optional, ContextManager, List, Union
|
||||
|
||||
import agate
|
||||
|
||||
import dbt.exceptions
|
||||
from dbt.contracts.connection import (
|
||||
Connection, Identifier, ConnectionState,
|
||||
AdapterRequiredConfig, LazyHandle, AdapterResponse
|
||||
Connection,
|
||||
Identifier,
|
||||
ConnectionState,
|
||||
AdapterRequiredConfig,
|
||||
LazyHandle,
|
||||
AdapterResponse,
|
||||
)
|
||||
from dbt.contracts.graph.manifest import Manifest
|
||||
from dbt.adapters.base.query_headers import (
|
||||
@@ -35,6 +38,7 @@ class BaseConnectionManager(metaclass=abc.ABCMeta):
|
||||
You must also set the 'TYPE' class attribute with a class-unique constant
|
||||
string.
|
||||
"""
|
||||
|
||||
TYPE: str = NotImplemented
|
||||
|
||||
def __init__(self, profile: AdapterRequiredConfig):
|
||||
@@ -65,7 +69,7 @@ class BaseConnectionManager(metaclass=abc.ABCMeta):
|
||||
key = self.get_thread_identifier()
|
||||
if key in self.thread_connections:
|
||||
raise dbt.exceptions.InternalException(
|
||||
'In set_thread_connection, existing connection exists for {}'
|
||||
"In set_thread_connection, existing connection exists for {}"
|
||||
)
|
||||
self.thread_connections[key] = conn
|
||||
|
||||
@@ -105,18 +109,19 @@ class BaseConnectionManager(metaclass=abc.ABCMeta):
|
||||
underlying database.
|
||||
"""
|
||||
raise dbt.exceptions.NotImplementedException(
|
||||
'`exception_handler` is not implemented for this adapter!')
|
||||
"`exception_handler` is not implemented for this adapter!"
|
||||
)
|
||||
|
||||
def set_connection_name(self, name: Optional[str] = None) -> Connection:
|
||||
conn_name: str
|
||||
if name is None:
|
||||
# if a name isn't specified, we'll re-use a single handle
|
||||
# named 'master'
|
||||
conn_name = 'master'
|
||||
conn_name = "master"
|
||||
else:
|
||||
if not isinstance(name, str):
|
||||
raise dbt.exceptions.CompilerException(
|
||||
f'For connection name, got {name} - not a string!'
|
||||
f"For connection name, got {name} - not a string!"
|
||||
)
|
||||
assert isinstance(name, str)
|
||||
conn_name = name
|
||||
@@ -129,20 +134,20 @@ class BaseConnectionManager(metaclass=abc.ABCMeta):
|
||||
state=ConnectionState.INIT,
|
||||
transaction_open=False,
|
||||
handle=None,
|
||||
credentials=self.profile.credentials
|
||||
credentials=self.profile.credentials,
|
||||
)
|
||||
self.set_thread_connection(conn)
|
||||
|
||||
if conn.name == conn_name and conn.state == 'open':
|
||||
if conn.name == conn_name and conn.state == "open":
|
||||
return conn
|
||||
|
||||
logger.debug(
|
||||
'Acquiring new {} connection "{}".'.format(self.TYPE, conn_name))
|
||||
logger.debug('Acquiring new {} connection "{}".'.format(self.TYPE, conn_name))
|
||||
|
||||
if conn.state == 'open':
|
||||
if conn.state == "open":
|
||||
logger.debug(
|
||||
'Re-using an available connection from the pool (formerly {}).'
|
||||
.format(conn.name)
|
||||
"Re-using an available connection from the pool (formerly {}).".format(
|
||||
conn.name
|
||||
)
|
||||
)
|
||||
else:
|
||||
conn.handle = LazyHandle(self.open)
|
||||
@@ -154,7 +159,7 @@ class BaseConnectionManager(metaclass=abc.ABCMeta):
|
||||
def cancel_open(self) -> Optional[List[str]]:
|
||||
"""Cancel all open connections on the adapter. (passable)"""
|
||||
raise dbt.exceptions.NotImplementedException(
|
||||
'`cancel_open` is not implemented for this adapter!'
|
||||
"`cancel_open` is not implemented for this adapter!"
|
||||
)
|
||||
|
||||
@abc.abstractclassmethod
|
||||
@@ -168,7 +173,7 @@ class BaseConnectionManager(metaclass=abc.ABCMeta):
|
||||
connection should not be in either in_use or available.
|
||||
"""
|
||||
raise dbt.exceptions.NotImplementedException(
|
||||
'`open` is not implemented for this adapter!'
|
||||
"`open` is not implemented for this adapter!"
|
||||
)
|
||||
|
||||
def release(self) -> None:
|
||||
@@ -189,12 +194,14 @@ class BaseConnectionManager(metaclass=abc.ABCMeta):
|
||||
def cleanup_all(self) -> None:
|
||||
with self.lock:
|
||||
for connection in self.thread_connections.values():
|
||||
if connection.state not in {'closed', 'init'}:
|
||||
logger.debug("Connection '{}' was left open."
|
||||
.format(connection.name))
|
||||
if connection.state not in {"closed", "init"}:
|
||||
logger.debug(
|
||||
"Connection '{}' was left open.".format(connection.name)
|
||||
)
|
||||
else:
|
||||
logger.debug("Connection '{}' was properly closed."
|
||||
.format(connection.name))
|
||||
logger.debug(
|
||||
"Connection '{}' was properly closed.".format(connection.name)
|
||||
)
|
||||
self.close(connection)
|
||||
|
||||
# garbage collect these connections
|
||||
@@ -204,14 +211,14 @@ class BaseConnectionManager(metaclass=abc.ABCMeta):
|
||||
def begin(self) -> None:
|
||||
"""Begin a transaction. (passable)"""
|
||||
raise dbt.exceptions.NotImplementedException(
|
||||
'`begin` is not implemented for this adapter!'
|
||||
"`begin` is not implemented for this adapter!"
|
||||
)
|
||||
|
||||
@abc.abstractmethod
|
||||
def commit(self) -> None:
|
||||
"""Commit a transaction. (passable)"""
|
||||
raise dbt.exceptions.NotImplementedException(
|
||||
'`commit` is not implemented for this adapter!'
|
||||
"`commit` is not implemented for this adapter!"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -220,20 +227,17 @@ class BaseConnectionManager(metaclass=abc.ABCMeta):
|
||||
try:
|
||||
connection.handle.rollback()
|
||||
except Exception:
|
||||
logger.debug(
|
||||
'Failed to rollback {}'.format(connection.name),
|
||||
exc_info=True
|
||||
)
|
||||
logger.debug("Failed to rollback {}".format(connection.name), exc_info=True)
|
||||
|
||||
@classmethod
|
||||
def _close_handle(cls, connection: Connection) -> None:
|
||||
"""Perform the actual close operation."""
|
||||
# On windows, sometimes connection handles don't have a close() attr.
|
||||
if hasattr(connection.handle, 'close'):
|
||||
logger.debug(f'On {connection.name}: Close')
|
||||
if hasattr(connection.handle, "close"):
|
||||
logger.debug(f"On {connection.name}: Close")
|
||||
connection.handle.close()
|
||||
else:
|
||||
logger.debug(f'On {connection.name}: No close available on handle')
|
||||
logger.debug(f"On {connection.name}: No close available on handle")
|
||||
|
||||
@classmethod
|
||||
def _rollback(cls, connection: Connection) -> None:
|
||||
@@ -241,16 +245,16 @@ class BaseConnectionManager(metaclass=abc.ABCMeta):
|
||||
if flags.STRICT_MODE:
|
||||
if not isinstance(connection, Connection):
|
||||
raise dbt.exceptions.CompilerException(
|
||||
f'In _rollback, got {connection} - not a Connection!'
|
||||
f"In _rollback, got {connection} - not a Connection!"
|
||||
)
|
||||
|
||||
if connection.transaction_open is False:
|
||||
raise dbt.exceptions.InternalException(
|
||||
f'Tried to rollback transaction on connection '
|
||||
f"Tried to rollback transaction on connection "
|
||||
f'"{connection.name}", but it does not have one open!'
|
||||
)
|
||||
|
||||
logger.debug(f'On {connection.name}: ROLLBACK')
|
||||
logger.debug(f"On {connection.name}: ROLLBACK")
|
||||
cls._rollback_handle(connection)
|
||||
|
||||
connection.transaction_open = False
|
||||
@@ -260,7 +264,7 @@ class BaseConnectionManager(metaclass=abc.ABCMeta):
|
||||
if flags.STRICT_MODE:
|
||||
if not isinstance(connection, Connection):
|
||||
raise dbt.exceptions.CompilerException(
|
||||
f'In close, got {connection} - not a Connection!'
|
||||
f"In close, got {connection} - not a Connection!"
|
||||
)
|
||||
|
||||
# if the connection is in closed or init, there's nothing to do
|
||||
@@ -268,7 +272,7 @@ class BaseConnectionManager(metaclass=abc.ABCMeta):
|
||||
return connection
|
||||
|
||||
if connection.transaction_open and connection.handle:
|
||||
logger.debug('On {}: ROLLBACK'.format(connection.name))
|
||||
logger.debug("On {}: ROLLBACK".format(connection.name))
|
||||
cls._rollback_handle(connection)
|
||||
connection.transaction_open = False
|
||||
|
||||
@@ -302,5 +306,5 @@ class BaseConnectionManager(metaclass=abc.ABCMeta):
|
||||
:rtype: Tuple[Union[str, AdapterResponse], agate.Table]
|
||||
"""
|
||||
raise dbt.exceptions.NotImplementedException(
|
||||
'`execute` is not implemented for this adapter!'
|
||||
"`execute` is not implemented for this adapter!"
|
||||
)
|
||||
|
||||
@@ -4,17 +4,31 @@ from contextlib import contextmanager
|
||||
from datetime import datetime
|
||||
from itertools import chain
|
||||
from typing import (
|
||||
Optional, Tuple, Callable, Iterable, Type, Dict, Any, List, Mapping,
|
||||
Iterator, Union, Set
|
||||
Optional,
|
||||
Tuple,
|
||||
Callable,
|
||||
Iterable,
|
||||
Type,
|
||||
Dict,
|
||||
Any,
|
||||
List,
|
||||
Mapping,
|
||||
Iterator,
|
||||
Union,
|
||||
Set,
|
||||
)
|
||||
|
||||
import agate
|
||||
import pytz
|
||||
|
||||
from dbt.exceptions import (
|
||||
raise_database_error, raise_compiler_error, invalid_type_error,
|
||||
raise_database_error,
|
||||
raise_compiler_error,
|
||||
invalid_type_error,
|
||||
get_relation_returned_multiple_results,
|
||||
InternalException, NotImplementedException, RuntimeException,
|
||||
InternalException,
|
||||
NotImplementedException,
|
||||
RuntimeException,
|
||||
)
|
||||
from dbt import flags
|
||||
|
||||
@@ -25,9 +39,7 @@ from dbt.adapters.protocol import (
|
||||
)
|
||||
from dbt.clients.agate_helper import empty_table, merge_tables, table_from_rows
|
||||
from dbt.clients.jinja import MacroGenerator
|
||||
from dbt.contracts.graph.compiled import (
|
||||
CompileResultNode, CompiledSeedNode
|
||||
)
|
||||
from dbt.contracts.graph.compiled import CompileResultNode, CompiledSeedNode
|
||||
from dbt.contracts.graph.manifest import Manifest, MacroManifest
|
||||
from dbt.contracts.graph.parsed import ParsedSeedNode
|
||||
from dbt.exceptions import warn_or_error
|
||||
@@ -38,7 +50,10 @@ from dbt.utils import filter_null_values, executor
|
||||
from dbt.adapters.base.connections import Connection, AdapterResponse
|
||||
from dbt.adapters.base.meta import AdapterMeta, available
|
||||
from dbt.adapters.base.relation import (
|
||||
ComponentName, BaseRelation, InformationSchema, SchemaSearchMap
|
||||
ComponentName,
|
||||
BaseRelation,
|
||||
InformationSchema,
|
||||
SchemaSearchMap,
|
||||
)
|
||||
from dbt.adapters.base import Column as BaseColumn
|
||||
from dbt.adapters.cache import RelationsCache
|
||||
@@ -47,15 +62,14 @@ from dbt.adapters.cache import RelationsCache
|
||||
SeedModel = Union[ParsedSeedNode, CompiledSeedNode]
|
||||
|
||||
|
||||
GET_CATALOG_MACRO_NAME = 'get_catalog'
|
||||
FRESHNESS_MACRO_NAME = 'collect_freshness'
|
||||
GET_CATALOG_MACRO_NAME = "get_catalog"
|
||||
FRESHNESS_MACRO_NAME = "collect_freshness"
|
||||
|
||||
|
||||
def _expect_row_value(key: str, row: agate.Row):
|
||||
if key not in row.keys():
|
||||
raise InternalException(
|
||||
'Got a row without "{}" column, columns: {}'
|
||||
.format(key, row.keys())
|
||||
'Got a row without "{}" column, columns: {}'.format(key, row.keys())
|
||||
)
|
||||
return row[key]
|
||||
|
||||
@@ -64,40 +78,37 @@ def _catalog_filter_schemas(manifest: Manifest) -> Callable[[agate.Row], bool]:
|
||||
"""Return a function that takes a row and decides if the row should be
|
||||
included in the catalog output.
|
||||
"""
|
||||
schemas = frozenset((d.lower(), s.lower())
|
||||
for d, s in manifest.get_used_schemas())
|
||||
schemas = frozenset((d.lower(), s.lower()) for d, s in manifest.get_used_schemas())
|
||||
|
||||
def test(row: agate.Row) -> bool:
|
||||
table_database = _expect_row_value('table_database', row)
|
||||
table_schema = _expect_row_value('table_schema', row)
|
||||
table_database = _expect_row_value("table_database", row)
|
||||
table_schema = _expect_row_value("table_schema", row)
|
||||
# the schema may be present but None, which is not an error and should
|
||||
# be filtered out
|
||||
if table_schema is None:
|
||||
return False
|
||||
return (table_database.lower(), table_schema.lower()) in schemas
|
||||
|
||||
return test
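The closure above is collapsed onto one line by Black; its job is to keep only catalog rows whose `(database, schema)` pair is actually used by the manifest, compared case-insensitively, and to drop rows whose schema is `None`. A standalone sketch with plain dicts standing in for `agate` rows (an assumption made for brevity):

```
schemas = frozenset(
    (d.lower(), s.lower()) for d, s in [("ANALYTICS", "Marts"), ("analytics", "staging")]
)

def keep(row: dict) -> bool:
    table_schema = row["table_schema"]
    if table_schema is None:  # the schema may be present but None; not an error
        return False
    return (row["table_database"].lower(), table_schema.lower()) in schemas

rows = [
    {"table_database": "analytics", "table_schema": "MARTS", "table_name": "orders"},
    {"table_database": "analytics", "table_schema": None, "table_name": "broken"},
    {"table_database": "raw", "table_schema": "events", "table_name": "clicks"},
]
print([r["table_name"] for r in rows if keep(r)])  # ['orders']
```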
|
||||
|
||||
|
||||
def _utc(
|
||||
dt: Optional[datetime], source: BaseRelation, field_name: str
|
||||
) -> datetime:
|
||||
def _utc(dt: Optional[datetime], source: BaseRelation, field_name: str) -> datetime:
|
||||
"""If dt has a timezone, return a new datetime that's in UTC. Otherwise,
|
||||
assume the datetime is already for UTC and add the timezone.
|
||||
"""
|
||||
if dt is None:
|
||||
raise raise_database_error(
|
||||
"Expected a non-null value when querying field '{}' of table "
|
||||
" {} but received value 'null' instead".format(
|
||||
field_name,
|
||||
source))
|
||||
" {} but received value 'null' instead".format(field_name, source)
|
||||
)
|
||||
|
||||
elif not hasattr(dt, 'tzinfo'):
|
||||
elif not hasattr(dt, "tzinfo"):
|
||||
raise raise_database_error(
|
||||
"Expected a timestamp value when querying field '{}' of table "
|
||||
"{} but received value of type '{}' instead".format(
|
||||
field_name,
|
||||
source,
|
||||
type(dt).__name__))
|
||||
field_name, source, type(dt).__name__
|
||||
)
|
||||
)
|
||||
|
||||
elif dt.tzinfo:
|
||||
return dt.astimezone(pytz.UTC)
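`_utc` normalizes source-freshness timestamps: aware datetimes are converted to UTC, and naive datetimes are assumed to already be UTC and simply tagged. A simplified, standalone illustration of that contract (it omits the `None` and non-datetime error branches shown above):

```
from datetime import datetime

import pytz

def to_utc(dt: datetime) -> datetime:
    if dt.tzinfo:
        return dt.astimezone(pytz.UTC)
    return dt.replace(tzinfo=pytz.UTC)  # assume naive values are already UTC

naive = datetime(2021, 3, 1, 12, 0, 0)
aware = pytz.timezone("America/New_York").localize(naive)
print(to_utc(naive))  # 2021-03-01 12:00:00+00:00
print(to_utc(aware))  # 2021-03-01 17:00:00+00:00
```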
|
||||
@@ -107,7 +118,7 @@ def _utc(
|
||||
|
||||
def _relation_name(rel: Optional[BaseRelation]) -> str:
|
||||
if rel is None:
|
||||
return 'null relation'
|
||||
return "null relation"
|
||||
else:
|
||||
return str(rel)
|
||||
|
||||
@@ -148,6 +159,7 @@ class BaseAdapter(metaclass=AdapterMeta):
|
||||
Macros:
|
||||
- get_catalog
|
||||
"""
|
||||
|
||||
Relation: Type[BaseRelation] = BaseRelation
|
||||
Column: Type[BaseColumn] = BaseColumn
|
||||
ConnectionManager: Type[ConnectionManagerProtocol]
|
||||
@@ -181,12 +193,12 @@ class BaseAdapter(metaclass=AdapterMeta):
|
||||
self.connections.commit_if_has_connection()
|
||||
|
||||
def debug_query(self) -> None:
|
||||
self.execute('select 1 as id')
|
||||
self.execute("select 1 as id")
|
||||
|
||||
def nice_connection_name(self) -> str:
|
||||
conn = self.connections.get_if_exists()
|
||||
if conn is None or conn.name is None:
|
||||
return '<None>'
|
||||
return "<None>"
|
||||
return conn.name
|
||||
|
||||
@contextmanager
|
||||
@@ -204,13 +216,11 @@ class BaseAdapter(metaclass=AdapterMeta):
|
||||
self.connections.query_header.reset()
|
||||
|
||||
@contextmanager
|
||||
def connection_for(
|
||||
self, node: CompileResultNode
|
||||
) -> Iterator[None]:
|
||||
def connection_for(self, node: CompileResultNode) -> Iterator[None]:
|
||||
with self.connection_named(node.unique_id, node):
|
||||
yield
|
||||
|
||||
@available.parse(lambda *a, **k: ('', empty_table()))
|
||||
@available.parse(lambda *a, **k: ("", empty_table()))
|
||||
def execute(
|
||||
self, sql: str, auto_begin: bool = False, fetch: bool = False
|
||||
) -> Tuple[Union[str, AdapterResponse], agate.Table]:
|
||||
@@ -224,16 +234,10 @@ class BaseAdapter(metaclass=AdapterMeta):
|
||||
:return: A tuple of the status and the results (empty if fetch=False).
|
||||
:rtype: Tuple[Union[str, AdapterResponse], agate.Table]
|
||||
"""
|
||||
return self.connections.execute(
|
||||
sql=sql,
|
||||
auto_begin=auto_begin,
|
||||
fetch=fetch
|
||||
)
|
||||
return self.connections.execute(sql=sql, auto_begin=auto_begin, fetch=fetch)
|
||||
|
||||
@available.parse(lambda *a, **k: ('', empty_table()))
|
||||
def get_partitions_metadata(
|
||||
self, table: str
|
||||
) -> Tuple[agate.Table]:
|
||||
@available.parse(lambda *a, **k: ("", empty_table()))
|
||||
def get_partitions_metadata(self, table: str) -> Tuple[agate.Table]:
|
||||
"""Obtain partitions metadata for a BigQuery partitioned table.
|
||||
|
||||
:param str table_id: a partitioned table id, in standard SQL format.
|
||||
@@ -241,9 +245,7 @@ class BaseAdapter(metaclass=AdapterMeta):
|
||||
https://cloud.google.com/bigquery/docs/creating-partitioned-tables#getting_partition_metadata_using_meta_tables.
|
||||
:rtype: agate.Table
|
||||
"""
|
||||
return self.connections.get_partitions_metadata(
|
||||
table=table
|
||||
)
|
||||
return self.connections.get_partitions_metadata(table=table)
|
||||
|
||||
###
|
||||
# Methods that should never be overridden
|
||||
@@ -274,6 +276,7 @@ class BaseAdapter(metaclass=AdapterMeta):
|
||||
if self._macro_manifest_lazy is None:
|
||||
# avoid a circular import
|
||||
from dbt.parser.manifest import load_macro_manifest
|
||||
|
||||
manifest = load_macro_manifest(
|
||||
self.config, self.connections.set_query_header
|
||||
)
|
||||
@@ -294,8 +297,9 @@ class BaseAdapter(metaclass=AdapterMeta):
|
||||
return False
|
||||
elif (database, schema) not in self.cache:
|
||||
logger.debug(
|
||||
'On "{}": cache miss for schema "{}.{}", this is inefficient'
|
||||
.format(self.nice_connection_name(), database, schema)
|
||||
'On "{}": cache miss for schema "{}.{}", this is inefficient'.format(
|
||||
self.nice_connection_name(), database, schema
|
||||
)
|
||||
)
|
||||
return False
|
||||
else:
|
||||
@@ -310,8 +314,8 @@ class BaseAdapter(metaclass=AdapterMeta):
|
||||
self.Relation.create_from(self.config, node).without_identifier()
|
||||
for node in manifest.nodes.values()
|
||||
if (
|
||||
node.resource_type in NodeType.executable() and
|
||||
not node.is_ephemeral_model
|
||||
node.resource_type in NodeType.executable()
|
||||
and not node.is_ephemeral_model
|
||||
)
|
||||
}
|
||||
|
||||
@@ -351,9 +355,9 @@ class BaseAdapter(metaclass=AdapterMeta):
|
||||
for cache_schema in cache_schemas:
|
||||
fut = tpe.submit_connected(
|
||||
self,
|
||||
f'list_{cache_schema.database}_{cache_schema.schema}',
|
||||
f"list_{cache_schema.database}_{cache_schema.schema}",
|
||||
self.list_relations_without_caching,
|
||||
cache_schema
|
||||
cache_schema,
|
||||
)
|
||||
futures.append(fut)
|
||||
|
||||
@@ -371,9 +375,7 @@ class BaseAdapter(metaclass=AdapterMeta):
|
||||
cache_update.add((relation.database, relation.schema))
|
||||
self.cache.update_schemas(cache_update)
|
||||
|
||||
def set_relations_cache(
|
||||
self, manifest: Manifest, clear: bool = False
|
||||
) -> None:
|
||||
def set_relations_cache(self, manifest: Manifest, clear: bool = False) -> None:
|
||||
"""Run a query that gets a populated cache of the relations in the
|
||||
database and set the cache on this adapter.
|
||||
"""
|
||||
@@ -391,12 +393,12 @@ class BaseAdapter(metaclass=AdapterMeta):
|
||||
if relation is None:
|
||||
name = self.nice_connection_name()
|
||||
raise_compiler_error(
|
||||
'Attempted to cache a null relation for {}'.format(name)
|
||||
"Attempted to cache a null relation for {}".format(name)
|
||||
)
|
||||
if flags.USE_CACHE:
|
||||
self.cache.add(relation)
|
||||
# so jinja doesn't render things
|
||||
return ''
|
||||
return ""
|
||||
|
||||
@available
|
||||
def cache_dropped(self, relation: Optional[BaseRelation]) -> str:
|
||||
@@ -406,11 +408,11 @@ class BaseAdapter(metaclass=AdapterMeta):
|
||||
if relation is None:
|
||||
name = self.nice_connection_name()
|
||||
raise_compiler_error(
|
||||
'Attempted to drop a null relation for {}'.format(name)
|
||||
"Attempted to drop a null relation for {}".format(name)
|
||||
)
|
||||
if flags.USE_CACHE:
|
||||
self.cache.drop(relation)
|
||||
return ''
|
||||
return ""
|
||||
|
||||
@available
|
||||
def cache_renamed(
|
||||
@@ -426,13 +428,12 @@ class BaseAdapter(metaclass=AdapterMeta):
|
||||
src_name = _relation_name(from_relation)
|
||||
dst_name = _relation_name(to_relation)
|
||||
raise_compiler_error(
|
||||
'Attempted to rename {} to {} for {}'
|
||||
.format(src_name, dst_name, name)
|
||||
"Attempted to rename {} to {} for {}".format(src_name, dst_name, name)
|
||||
)
|
||||
|
||||
if flags.USE_CACHE:
|
||||
self.cache.rename(from_relation, to_relation)
|
||||
return ''
|
||||
return ""
|
||||
|
||||
###
|
||||
# Abstract methods for database-specific values, attributes, and types
|
||||
@@ -441,12 +442,13 @@ class BaseAdapter(metaclass=AdapterMeta):
|
||||
def date_function(cls) -> str:
|
||||
"""Get the date function used by this adapter's database."""
|
||||
raise NotImplementedException(
|
||||
'`date_function` is not implemented for this adapter!')
|
||||
"`date_function` is not implemented for this adapter!"
|
||||
)
|
||||
|
||||
@abc.abstractclassmethod
|
||||
def is_cancelable(cls) -> bool:
|
||||
raise NotImplementedException(
|
||||
'`is_cancelable` is not implemented for this adapter!'
|
||||
"`is_cancelable` is not implemented for this adapter!"
|
||||
)
|
||||
|
||||
###
|
||||
@@ -456,7 +458,7 @@ class BaseAdapter(metaclass=AdapterMeta):
|
||||
def list_schemas(self, database: str) -> List[str]:
|
||||
"""Get a list of existing schemas in database"""
|
||||
raise NotImplementedException(
|
||||
'`list_schemas` is not implemented for this adapter!'
|
||||
"`list_schemas` is not implemented for this adapter!"
|
||||
)
|
||||
|
||||
@available.parse(lambda *a, **k: False)
|
||||
@@ -467,10 +469,7 @@ class BaseAdapter(metaclass=AdapterMeta):
|
||||
and adapters should implement it if there is an optimized path (and
|
||||
there probably is)
|
||||
"""
|
||||
search = (
|
||||
s.lower() for s in
|
||||
self.list_schemas(database=database)
|
||||
)
|
||||
search = (s.lower() for s in self.list_schemas(database=database))
|
||||
return schema.lower() in search
|
||||
|
||||
###
|
||||
@@ -484,7 +483,7 @@ class BaseAdapter(metaclass=AdapterMeta):
|
||||
*Implementors must call self.cache.drop() to preserve cache state!*
|
||||
"""
|
||||
raise NotImplementedException(
|
||||
'`drop_relation` is not implemented for this adapter!'
|
||||
"`drop_relation` is not implemented for this adapter!"
|
||||
)
|
||||
|
||||
@abc.abstractmethod
|
||||
@@ -492,7 +491,7 @@ class BaseAdapter(metaclass=AdapterMeta):
|
||||
def truncate_relation(self, relation: BaseRelation) -> None:
|
||||
"""Truncate the given relation."""
|
||||
raise NotImplementedException(
|
||||
'`truncate_relation` is not implemented for this adapter!'
|
||||
"`truncate_relation` is not implemented for this adapter!"
|
||||
)
|
||||
|
||||
@abc.abstractmethod
|
||||
@@ -505,36 +504,30 @@ class BaseAdapter(metaclass=AdapterMeta):
|
||||
Implementors must call self.cache.rename() to preserve cache state.
|
||||
"""
|
||||
raise NotImplementedException(
|
||||
'`rename_relation` is not implemented for this adapter!'
|
||||
"`rename_relation` is not implemented for this adapter!"
|
||||
)
|
||||
|
||||
@abc.abstractmethod
|
||||
@available.parse_list
|
||||
def get_columns_in_relation(
|
||||
self, relation: BaseRelation
|
||||
) -> List[BaseColumn]:
|
||||
def get_columns_in_relation(self, relation: BaseRelation) -> List[BaseColumn]:
|
||||
"""Get a list of the columns in the given Relation."""
|
||||
raise NotImplementedException(
|
||||
'`get_columns_in_relation` is not implemented for this adapter!'
|
||||
"`get_columns_in_relation` is not implemented for this adapter!"
|
||||
)
|
||||
|
||||
@available.deprecated('get_columns_in_relation', lambda *a, **k: [])
|
||||
def get_columns_in_table(
|
||||
self, schema: str, identifier: str
|
||||
) -> List[BaseColumn]:
|
||||
@available.deprecated("get_columns_in_relation", lambda *a, **k: [])
|
||||
def get_columns_in_table(self, schema: str, identifier: str) -> List[BaseColumn]:
|
||||
"""DEPRECATED: Get a list of the columns in the given table."""
|
||||
relation = self.Relation.create(
|
||||
database=self.config.credentials.database,
|
||||
schema=schema,
|
||||
identifier=identifier,
|
||||
quote_policy=self.config.quoting
|
||||
quote_policy=self.config.quoting,
|
||||
)
|
||||
return self.get_columns_in_relation(relation)
|
||||
|
||||
@abc.abstractmethod
|
||||
def expand_column_types(
|
||||
self, goal: BaseRelation, current: BaseRelation
|
||||
) -> None:
|
||||
def expand_column_types(self, goal: BaseRelation, current: BaseRelation) -> None:
|
||||
"""Expand the current table's types to match the goal table. (passable)
|
||||
|
||||
:param self.Relation goal: A relation that currently exists in the
|
||||
@@ -543,7 +536,7 @@ class BaseAdapter(metaclass=AdapterMeta):
|
||||
database with columns of unspecified types.
|
||||
"""
|
||||
raise NotImplementedException(
|
||||
'`expand_target_column_types` is not implemented for this adapter!'
|
||||
"`expand_target_column_types` is not implemented for this adapter!"
|
||||
)
|
||||
|
||||
@abc.abstractmethod
|
||||
@@ -560,8 +553,7 @@ class BaseAdapter(metaclass=AdapterMeta):
|
||||
:rtype: List[self.Relation]
|
||||
"""
|
||||
raise NotImplementedException(
|
||||
'`list_relations_without_caching` is not implemented for this '
|
||||
'adapter!'
|
||||
"`list_relations_without_caching` is not implemented for this " "adapter!"
|
||||
)
|
||||
|
||||
###
|
||||
@@ -576,32 +568,33 @@ class BaseAdapter(metaclass=AdapterMeta):
|
||||
"""
|
||||
if not isinstance(from_relation, self.Relation):
|
||||
invalid_type_error(
|
||||
method_name='get_missing_columns',
|
||||
arg_name='from_relation',
|
||||
method_name="get_missing_columns",
|
||||
arg_name="from_relation",
|
||||
got_value=from_relation,
|
||||
expected_type=self.Relation)
|
||||
expected_type=self.Relation,
|
||||
)
|
||||
|
||||
if not isinstance(to_relation, self.Relation):
|
||||
invalid_type_error(
|
||||
method_name='get_missing_columns',
|
||||
arg_name='to_relation',
|
||||
method_name="get_missing_columns",
|
||||
arg_name="to_relation",
|
||||
got_value=to_relation,
|
||||
expected_type=self.Relation)
|
||||
expected_type=self.Relation,
|
||||
)
|
||||
|
||||
from_columns = {
|
||||
col.name: col for col in
|
||||
self.get_columns_in_relation(from_relation)
|
||||
col.name: col for col in self.get_columns_in_relation(from_relation)
|
||||
}
|
||||
|
||||
to_columns = {
|
||||
col.name: col for col in
|
||||
self.get_columns_in_relation(to_relation)
|
||||
col.name: col for col in self.get_columns_in_relation(to_relation)
|
||||
}
|
||||
|
||||
missing_columns = set(from_columns.keys()) - set(to_columns.keys())
|
||||
|
||||
return [
|
||||
col for (col_name, col) in from_columns.items()
|
||||
col
|
||||
for (col_name, col) in from_columns.items()
|
||||
if col_name in missing_columns
|
||||
]
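The reformatted comprehension above is the core of `get_missing_columns`: return the columns that exist on the source relation but not on the target, preserving source order. A minimal sketch with plain dicts standing in for column objects (my reconstruction, not the adapter code):

```
def get_missing_columns(from_columns, to_columns):
    from_by_name = {c["name"]: c for c in from_columns}
    to_by_name = {c["name"]: c for c in to_columns}
    missing = set(from_by_name) - set(to_by_name)
    return [col for name, col in from_by_name.items() if name in missing]

src = [{"name": "id", "dtype": "integer"}, {"name": "updated_at", "dtype": "timestamp"}]
dst = [{"name": "id", "dtype": "integer"}]
print(get_missing_columns(src, dst))  # [{'name': 'updated_at', 'dtype': 'timestamp'}]
```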
|
||||
|
||||
@@ -616,18 +609,19 @@ class BaseAdapter(metaclass=AdapterMeta):
|
||||
"""
|
||||
if not isinstance(relation, self.Relation):
|
||||
invalid_type_error(
|
||||
method_name='valid_snapshot_target',
|
||||
arg_name='relation',
|
||||
method_name="valid_snapshot_target",
|
||||
arg_name="relation",
|
||||
got_value=relation,
|
||||
expected_type=self.Relation)
|
||||
expected_type=self.Relation,
|
||||
)
|
||||
|
||||
columns = self.get_columns_in_relation(relation)
|
||||
names = set(c.name.lower() for c in columns)
|
||||
expanded_keys = ('scd_id', 'valid_from', 'valid_to')
|
||||
expanded_keys = ("scd_id", "valid_from", "valid_to")
|
||||
extra = []
|
||||
missing = []
|
||||
for legacy in expanded_keys:
|
||||
desired = 'dbt_' + legacy
|
||||
desired = "dbt_" + legacy
|
||||
if desired not in names:
|
||||
missing.append(desired)
|
||||
if legacy in names:
|
||||
@@ -637,13 +631,13 @@ class BaseAdapter(metaclass=AdapterMeta):
|
||||
if extra:
|
||||
msg = (
|
||||
'Snapshot target has ("{}") but not ("{}") - is it an '
|
||||
'unmigrated previous version archive?'
|
||||
.format('", "'.join(extra), '", "'.join(missing))
|
||||
"unmigrated previous version archive?".format(
|
||||
'", "'.join(extra), '", "'.join(missing)
|
||||
)
|
||||
)
|
||||
else:
|
||||
msg = (
|
||||
'Snapshot target is not a snapshot table (missing "{}")'
|
||||
.format('", "'.join(missing))
|
||||
msg = 'Snapshot target is not a snapshot table (missing "{}")'.format(
|
||||
'", "'.join(missing)
|
||||
)
|
||||
raise_compiler_error(msg)
|
||||
|
||||
@@ -653,17 +647,19 @@ class BaseAdapter(metaclass=AdapterMeta):
|
||||
) -> None:
|
||||
if not isinstance(from_relation, self.Relation):
|
||||
invalid_type_error(
|
||||
method_name='expand_target_column_types',
|
||||
arg_name='from_relation',
|
||||
method_name="expand_target_column_types",
|
||||
arg_name="from_relation",
|
||||
got_value=from_relation,
|
||||
expected_type=self.Relation)
|
||||
expected_type=self.Relation,
|
||||
)
|
||||
|
||||
if not isinstance(to_relation, self.Relation):
|
||||
invalid_type_error(
|
||||
method_name='expand_target_column_types',
|
||||
arg_name='to_relation',
|
||||
method_name="expand_target_column_types",
|
||||
arg_name="to_relation",
|
||||
got_value=to_relation,
|
||||
expected_type=self.Relation)
|
||||
expected_type=self.Relation,
|
||||
)
|
||||
|
||||
self.expand_column_types(from_relation, to_relation)
|
||||
|
||||
@@ -676,38 +672,41 @@ class BaseAdapter(metaclass=AdapterMeta):
|
||||
schema_relation = self.Relation.create(
|
||||
database=database,
|
||||
schema=schema,
|
||||
identifier='',
|
||||
quote_policy=self.config.quoting
|
||||
identifier="",
|
||||
quote_policy=self.config.quoting,
|
||||
).without_identifier()
|
||||
|
||||
# we can't build the relations cache because we don't have a
|
||||
# manifest so we can't run any operations.
|
||||
relations = self.list_relations_without_caching(
|
||||
schema_relation
|
||||
)
|
||||
relations = self.list_relations_without_caching(schema_relation)
|
||||
|
||||
logger.debug('with database={}, schema={}, relations={}'
|
||||
.format(database, schema, relations))
|
||||
logger.debug(
|
||||
"with database={}, schema={}, relations={}".format(
|
||||
database, schema, relations
|
||||
)
|
||||
)
|
||||
return relations
|
||||
|
||||
def _make_match_kwargs(
|
||||
self, database: str, schema: str, identifier: str
|
||||
) -> Dict[str, str]:
|
||||
quoting = self.config.quoting
|
||||
if identifier is not None and quoting['identifier'] is False:
|
||||
if identifier is not None and quoting["identifier"] is False:
|
||||
identifier = identifier.lower()
|
||||
|
||||
if schema is not None and quoting['schema'] is False:
|
||||
if schema is not None and quoting["schema"] is False:
|
||||
schema = schema.lower()
|
||||
|
||||
if database is not None and quoting['database'] is False:
|
||||
if database is not None and quoting["database"] is False:
|
||||
database = database.lower()
|
||||
|
||||
return filter_null_values({
|
||||
'database': database,
|
||||
'identifier': identifier,
|
||||
'schema': schema,
|
||||
})
|
||||
return filter_null_values(
|
||||
{
|
||||
"database": database,
|
||||
"identifier": identifier,
|
||||
"schema": schema,
|
||||
}
|
||||
)
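`_make_match_kwargs` lower-cases any relation component whose quoting is disabled and drops missing components before matching. A standalone sketch, assuming `filter_null_values` simply removes `None` entries (which is how it is used here):

```
from typing import Dict, Optional

def filter_null_values(values: Dict[str, Optional[str]]) -> Dict[str, str]:
    return {k: v for k, v in values.items() if v is not None}

def make_match_kwargs(database, schema, identifier, quoting) -> Dict[str, str]:
    if identifier is not None and quoting["identifier"] is False:
        identifier = identifier.lower()
    if schema is not None and quoting["schema"] is False:
        schema = schema.lower()
    if database is not None and quoting["database"] is False:
        database = database.lower()
    return filter_null_values(
        {"database": database, "identifier": identifier, "schema": schema}
    )

quoting = {"database": False, "schema": False, "identifier": True}
print(make_match_kwargs("Analytics", "Marts", "Orders", quoting))
# {'database': 'analytics', 'identifier': 'Orders', 'schema': 'marts'}
```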
|
||||
|
||||
def _make_match(
|
||||
self,
|
||||
@@ -733,25 +732,22 @@ class BaseAdapter(metaclass=AdapterMeta):
|
||||
) -> Optional[BaseRelation]:
|
||||
relations_list = self.list_relations(database, schema)
|
||||
|
||||
matches = self._make_match(relations_list, database, schema,
|
||||
identifier)
|
||||
matches = self._make_match(relations_list, database, schema, identifier)
|
||||
|
||||
if len(matches) > 1:
|
||||
kwargs = {
|
||||
'identifier': identifier,
|
||||
'schema': schema,
|
||||
'database': database,
|
||||
"identifier": identifier,
|
||||
"schema": schema,
|
||||
"database": database,
|
||||
}
|
||||
get_relation_returned_multiple_results(
|
||||
kwargs, matches
|
||||
)
|
||||
get_relation_returned_multiple_results(kwargs, matches)
|
||||
|
||||
elif matches:
|
||||
return matches[0]
|
||||
|
||||
return None
|
||||
|
||||
@available.deprecated('get_relation', lambda *a, **k: False)
|
||||
@available.deprecated("get_relation", lambda *a, **k: False)
|
||||
def already_exists(self, schema: str, name: str) -> bool:
|
||||
"""DEPRECATED: Return if a model already exists in the database"""
|
||||
database = self.config.credentials.database
|
||||
@@ -767,7 +763,7 @@ class BaseAdapter(metaclass=AdapterMeta):
|
||||
def create_schema(self, relation: BaseRelation):
|
||||
"""Create the given schema if it does not exist."""
|
||||
raise NotImplementedException(
|
||||
'`create_schema` is not implemented for this adapter!'
|
||||
"`create_schema` is not implemented for this adapter!"
|
||||
)
|
||||
|
||||
@abc.abstractmethod
|
||||
@@ -775,16 +771,14 @@ class BaseAdapter(metaclass=AdapterMeta):
|
||||
def drop_schema(self, relation: BaseRelation):
|
||||
"""Drop the given schema (and everything in it) if it exists."""
|
||||
raise NotImplementedException(
|
||||
'`drop_schema` is not implemented for this adapter!'
|
||||
"`drop_schema` is not implemented for this adapter!"
|
||||
)
|
||||
|
||||
@available
|
||||
@abc.abstractclassmethod
|
||||
def quote(cls, identifier: str) -> str:
|
||||
"""Quote the given identifier, as appropriate for the database."""
|
||||
raise NotImplementedException(
|
||||
'`quote` is not implemented for this adapter!'
|
||||
)
|
||||
raise NotImplementedException("`quote` is not implemented for this adapter!")
|
||||
|
||||
@available
|
||||
def quote_as_configured(self, identifier: str, quote_key: str) -> str:
|
||||
@@ -806,19 +800,17 @@ class BaseAdapter(metaclass=AdapterMeta):
|
||||
return identifier
|
||||
|
||||
@available
|
||||
def quote_seed_column(
|
||||
self, column: str, quote_config: Optional[bool]
|
||||
) -> str:
|
||||
def quote_seed_column(self, column: str, quote_config: Optional[bool]) -> str:
|
||||
# this is the default for now
|
||||
quote_columns: bool = False
|
||||
if isinstance(quote_config, bool):
|
||||
quote_columns = quote_config
|
||||
elif quote_config is None:
|
||||
deprecations.warn('column-quoting-unset')
|
||||
deprecations.warn("column-quoting-unset")
|
||||
else:
|
||||
raise_compiler_error(
|
||||
f'The seed configuration value of "quote_columns" has an '
|
||||
f'invalid type {type(quote_config)}'
|
||||
f"invalid type {type(quote_config)}"
|
||||
)
|
||||
|
||||
if quote_columns:
|
||||
@@ -831,9 +823,7 @@ class BaseAdapter(metaclass=AdapterMeta):
|
||||
# converting agate types into their sql equivalents.
|
||||
###
|
||||
@abc.abstractclassmethod
|
||||
def convert_text_type(
|
||||
cls, agate_table: agate.Table, col_idx: int
|
||||
) -> str:
|
||||
def convert_text_type(cls, agate_table: agate.Table, col_idx: int) -> str:
|
||||
"""Return the type in the database that best maps to the agate.Text
|
||||
type for the given agate table and column index.
|
||||
|
||||
@@ -842,12 +832,11 @@ class BaseAdapter(metaclass=AdapterMeta):
|
||||
:return: The name of the type in the database
|
||||
"""
|
||||
raise NotImplementedException(
|
||||
'`convert_text_type` is not implemented for this adapter!')
|
||||
"`convert_text_type` is not implemented for this adapter!"
|
||||
)
|
||||
|
||||
@abc.abstractclassmethod
|
||||
def convert_number_type(
|
||||
cls, agate_table: agate.Table, col_idx: int
|
||||
) -> str:
|
||||
def convert_number_type(cls, agate_table: agate.Table, col_idx: int) -> str:
|
||||
"""Return the type in the database that best maps to the agate.Number
|
||||
type for the given agate table and column index.
|
||||
|
||||
@@ -856,12 +845,11 @@ class BaseAdapter(metaclass=AdapterMeta):
|
||||
:return: The name of the type in the database
|
||||
"""
|
||||
raise NotImplementedException(
|
||||
'`convert_number_type` is not implemented for this adapter!')
|
||||
"`convert_number_type` is not implemented for this adapter!"
|
||||
)
|
||||
|
||||
@abc.abstractclassmethod
|
||||
def convert_boolean_type(
|
||||
cls, agate_table: agate.Table, col_idx: int
|
||||
) -> str:
|
||||
def convert_boolean_type(cls, agate_table: agate.Table, col_idx: int) -> str:
|
||||
"""Return the type in the database that best maps to the agate.Boolean
|
||||
type for the given agate table and column index.
|
||||
|
||||
@@ -870,12 +858,11 @@ class BaseAdapter(metaclass=AdapterMeta):
|
||||
:return: The name of the type in the database
|
||||
"""
|
||||
raise NotImplementedException(
|
||||
'`convert_boolean_type` is not implemented for this adapter!')
|
||||
"`convert_boolean_type` is not implemented for this adapter!"
|
||||
)
|
||||
|
||||
@abc.abstractclassmethod
|
||||
def convert_datetime_type(
|
||||
cls, agate_table: agate.Table, col_idx: int
|
||||
) -> str:
|
||||
def convert_datetime_type(cls, agate_table: agate.Table, col_idx: int) -> str:
|
||||
"""Return the type in the database that best maps to the agate.DateTime
|
||||
type for the given agate table and column index.
|
||||
|
||||
@@ -884,7 +871,8 @@ class BaseAdapter(metaclass=AdapterMeta):
|
||||
:return: The name of the type in the database
|
||||
"""
|
||||
raise NotImplementedException(
|
||||
'`convert_datetime_type` is not implemented for this adapter!')
|
||||
"`convert_datetime_type` is not implemented for this adapter!"
|
||||
)
|
||||
|
||||
@abc.abstractclassmethod
|
||||
def convert_date_type(cls, agate_table: agate.Table, col_idx: int) -> str:
|
||||
@@ -896,7 +884,8 @@ class BaseAdapter(metaclass=AdapterMeta):
|
||||
:return: The name of the type in the database
|
||||
"""
|
||||
raise NotImplementedException(
|
||||
'`convert_date_type` is not implemented for this adapter!')
|
||||
"`convert_date_type` is not implemented for this adapter!"
|
||||
)
|
||||
|
||||
@abc.abstractclassmethod
|
||||
def convert_time_type(cls, agate_table: agate.Table, col_idx: int) -> str:
|
||||
@@ -908,13 +897,12 @@ class BaseAdapter(metaclass=AdapterMeta):
|
||||
:return: The name of the type in the database
|
||||
"""
|
||||
raise NotImplementedException(
|
||||
'`convert_time_type` is not implemented for this adapter!')
|
||||
"`convert_time_type` is not implemented for this adapter!"
|
||||
)
|
||||
|
||||
@available
|
||||
@classmethod
|
||||
def convert_type(
|
||||
cls, agate_table: agate.Table, col_idx: int
|
||||
) -> Optional[str]:
|
||||
def convert_type(cls, agate_table: agate.Table, col_idx: int) -> Optional[str]:
|
||||
return cls.convert_agate_type(agate_table, col_idx)
|
||||
|
||||
@classmethod
|
||||
@@ -963,7 +951,7 @@ class BaseAdapter(metaclass=AdapterMeta):
|
||||
:param release: Ignored.
|
||||
"""
|
||||
if release is not False:
|
||||
deprecations.warn('execute-macro-release')
|
||||
deprecations.warn("execute-macro-release")
|
||||
if kwargs is None:
|
||||
kwargs = {}
|
||||
if context_override is None:
|
||||
@@ -977,28 +965,27 @@ class BaseAdapter(metaclass=AdapterMeta):
|
||||
)
|
||||
if macro is None:
|
||||
if project is None:
|
||||
package_name = 'any package'
|
||||
package_name = "any package"
|
||||
else:
|
||||
package_name = 'the "{}" package'.format(project)
|
||||
|
||||
raise RuntimeException(
|
||||
'dbt could not find a macro with the name "{}" in {}'
|
||||
.format(macro_name, package_name)
|
||||
'dbt could not find a macro with the name "{}" in {}'.format(
|
||||
macro_name, package_name
|
||||
)
|
||||
)
|
||||
# This causes a reference cycle, as generate_runtime_macro()
|
||||
# ends up calling get_adapter, so the import has to be here.
|
||||
from dbt.context.providers import generate_runtime_macro
|
||||
|
||||
macro_context = generate_runtime_macro(
|
||||
macro=macro,
|
||||
config=self.config,
|
||||
manifest=manifest,
|
||||
package_name=project
|
||||
macro=macro, config=self.config, manifest=manifest, package_name=project
|
||||
)
|
||||
macro_context.update(context_override)
|
||||
|
||||
macro_function = MacroGenerator(macro, macro_context)
|
||||
|
||||
with self.connections.exception_handler(f'macro {macro_name}'):
|
||||
with self.connections.exception_handler(f"macro {macro_name}"):
|
||||
result = macro_function(**kwargs)
|
||||
return result
|
||||
|
||||
@@ -1013,7 +1000,7 @@ class BaseAdapter(metaclass=AdapterMeta):
|
||||
table = table_from_rows(
|
||||
table.rows,
|
||||
table.column_names,
|
||||
text_only_columns=['table_database', 'table_schema', 'table_name']
|
||||
text_only_columns=["table_database", "table_schema", "table_name"],
|
||||
)
|
||||
return table.where(_catalog_filter_schemas(manifest))
|
||||
|
||||
@@ -1024,10 +1011,7 @@ class BaseAdapter(metaclass=AdapterMeta):
|
||||
manifest: Manifest,
|
||||
) -> agate.Table:
|
||||
|
||||
kwargs = {
|
||||
'information_schema': information_schema,
|
||||
'schemas': schemas
|
||||
}
|
||||
kwargs = {"information_schema": information_schema, "schemas": schemas}
|
||||
table = self.execute_macro(
|
||||
GET_CATALOG_MACRO_NAME,
|
||||
kwargs=kwargs,
|
||||
@@ -1039,9 +1023,7 @@ class BaseAdapter(metaclass=AdapterMeta):
|
||||
results = self._catalog_filter_table(table, manifest)
|
||||
return results
|
||||
|
||||
def get_catalog(
|
||||
self, manifest: Manifest
|
||||
) -> Tuple[agate.Table, List[Exception]]:
|
||||
def get_catalog(self, manifest: Manifest) -> Tuple[agate.Table, List[Exception]]:
|
||||
schema_map = self._get_catalog_schemas(manifest)
|
||||
|
||||
with executor(self.config) as tpe:
|
||||
@@ -1049,14 +1031,10 @@ class BaseAdapter(metaclass=AdapterMeta):
|
||||
for info, schemas in schema_map.items():
|
||||
if len(schemas) == 0:
|
||||
continue
|
||||
name = '.'.join([
|
||||
str(info.database),
|
||||
'information_schema'
|
||||
])
|
||||
name = ".".join([str(info.database), "information_schema"])
|
||||
|
||||
fut = tpe.submit_connected(
|
||||
self, name,
|
||||
self._get_one_catalog, info, schemas, manifest
|
||||
self, name, self._get_one_catalog, info, schemas, manifest
|
||||
)
|
||||
futures.append(fut)
|
||||
|
||||
@@ -1073,20 +1051,18 @@ class BaseAdapter(metaclass=AdapterMeta):
|
||||
source: BaseRelation,
|
||||
loaded_at_field: str,
|
||||
filter: Optional[str],
|
||||
manifest: Optional[Manifest] = None
|
||||
manifest: Optional[Manifest] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Calculate the freshness of sources in dbt, and return it"""
|
||||
kwargs: Dict[str, Any] = {
|
||||
'source': source,
|
||||
'loaded_at_field': loaded_at_field,
|
||||
'filter': filter,
|
||||
"source": source,
|
||||
"loaded_at_field": loaded_at_field,
|
||||
"filter": filter,
|
||||
}
|
||||
|
||||
# run the macro
|
||||
table = self.execute_macro(
|
||||
FRESHNESS_MACRO_NAME,
|
||||
kwargs=kwargs,
|
||||
manifest=manifest
|
||||
FRESHNESS_MACRO_NAME, kwargs=kwargs, manifest=manifest
|
||||
)
|
||||
# now we have a 1-row table of the maximum `loaded_at_field` value and
|
||||
# the current time according to the db.
|
||||
@@ -1106,9 +1082,9 @@ class BaseAdapter(metaclass=AdapterMeta):
|
||||
snapshotted_at = _utc(table[0][1], source, loaded_at_field)
|
||||
age = (snapshotted_at - max_loaded_at).total_seconds()
|
||||
return {
|
||||
'max_loaded_at': max_loaded_at,
|
||||
'snapshotted_at': snapshotted_at,
|
||||
'age': age,
|
||||
"max_loaded_at": max_loaded_at,
|
||||
"snapshotted_at": snapshotted_at,
|
||||
"age": age,
|
||||
}
|
||||
|
||||
def pre_model_hook(self, config: Mapping[str, Any]) -> Any:
|
||||
@@ -1138,6 +1114,7 @@ class BaseAdapter(metaclass=AdapterMeta):
|
||||
|
||||
def get_compiler(self):
|
||||
from dbt.compilation import Compiler
|
||||
|
||||
return Compiler(self.config)
|
||||
|
||||
# Methods used in adapter tests
|
||||
@@ -1148,13 +1125,13 @@ class BaseAdapter(metaclass=AdapterMeta):
|
||||
clause: str,
|
||||
where_clause: Optional[str] = None,
|
||||
) -> str:
|
||||
clause = f'update {dst_name} set {dst_column} = {clause}'
|
||||
clause = f"update {dst_name} set {dst_column} = {clause}"
|
||||
if where_clause is not None:
|
||||
clause += f' where {where_clause}'
|
||||
clause += f" where {where_clause}"
|
||||
return clause
|
||||
|
||||
def timestamp_add_sql(
|
||||
self, add_to: str, number: int = 1, interval: str = 'hour'
|
||||
self, add_to: str, number: int = 1, interval: str = "hour"
|
||||
) -> str:
|
||||
# for backwards compatibility, we're compelled to set some sort of
|
||||
# default. A lot of searching has lead me to believe that the
|
||||
@@ -1163,23 +1140,24 @@ class BaseAdapter(metaclass=AdapterMeta):
|
||||
return f"{add_to} + interval '{number} {interval}'"
|
||||
|
||||
def string_add_sql(
|
||||
self, add_to: str, value: str, location='append',
|
||||
self,
|
||||
add_to: str,
|
||||
value: str,
|
||||
location="append",
|
||||
) -> str:
|
||||
if location == 'append':
|
||||
if location == "append":
|
||||
return f"{add_to} || '{value}'"
|
||||
elif location == 'prepend':
|
||||
elif location == "prepend":
|
||||
return f"'{value}' || {add_to}"
|
||||
else:
|
||||
raise RuntimeException(
|
||||
f'Got an unexpected location value of "{location}"'
|
||||
)
|
||||
raise RuntimeException(f'Got an unexpected location value of "{location}"')
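`string_add_sql` builds a SQL string-concatenation expression; the hunk above only collapses its error branch onto one line. A self-contained sketch of the same behavior (a plain `ValueError` stands in for dbt's `RuntimeException`):

```
def string_add_sql(add_to: str, value: str, location: str = "append") -> str:
    if location == "append":
        return f"{add_to} || '{value}'"
    elif location == "prepend":
        return f"'{value}' || {add_to}"
    raise ValueError(f'Got an unexpected location value of "{location}"')

print(string_add_sql("dbt_scd_id", "_backup"))        # dbt_scd_id || '_backup'
print(string_add_sql("name", "legacy_", "prepend"))   # 'legacy_' || name
```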
|
||||
|
||||
def get_rows_different_sql(
|
||||
self,
|
||||
relation_a: BaseRelation,
|
||||
relation_b: BaseRelation,
|
||||
column_names: Optional[List[str]] = None,
|
||||
except_operator: str = 'EXCEPT',
|
||||
except_operator: str = "EXCEPT",
|
||||
) -> str:
|
||||
"""Generate SQL for a query that returns a single row with a two
|
||||
columns: the number of rows that are different between the two
|
||||
@@ -1192,7 +1170,7 @@ class BaseAdapter(metaclass=AdapterMeta):
|
||||
names = sorted((self.quote(c.name) for c in columns))
|
||||
else:
|
||||
names = sorted((self.quote(n) for n in column_names))
|
||||
columns_csv = ', '.join(names)
|
||||
columns_csv = ", ".join(names)
|
||||
|
||||
sql = COLUMNS_EQUAL_SQL.format(
|
||||
columns=columns_csv,
|
||||
@@ -1204,7 +1182,7 @@ class BaseAdapter(metaclass=AdapterMeta):
|
||||
return sql
|
||||
|
||||
|
||||
COLUMNS_EQUAL_SQL = '''
|
||||
COLUMNS_EQUAL_SQL = """
|
||||
with diff_count as (
|
||||
SELECT
|
||||
1 as id,
|
||||
@@ -1230,11 +1208,11 @@ select
|
||||
diff_count.num_missing as num_mismatched
|
||||
from row_count_diff
|
||||
join diff_count using (id)
|
||||
'''.strip()
|
||||
""".strip()
|
||||
|
||||
|
||||
def catch_as_completed(
|
||||
futures # typing: List[Future[agate.Table]]
|
||||
futures, # typing: List[Future[agate.Table]]
|
||||
) -> Tuple[agate.Table, List[Exception]]:
|
||||
|
||||
# catalogs: agate.Table = agate.Table(rows=[])
|
||||
@@ -1247,15 +1225,10 @@ def catch_as_completed(
|
||||
if exc is None:
|
||||
catalog = future.result()
|
||||
tables.append(catalog)
|
||||
elif (
|
||||
isinstance(exc, KeyboardInterrupt) or
|
||||
not isinstance(exc, Exception)
|
||||
):
|
||||
elif isinstance(exc, KeyboardInterrupt) or not isinstance(exc, Exception):
|
||||
raise exc
|
||||
else:
|
||||
warn_or_error(
|
||||
f'Encountered an error while generating catalog: {str(exc)}'
|
||||
)
|
||||
warn_or_error(f"Encountered an error while generating catalog: {str(exc)}")
|
||||
# exc is not None, derives from Exception, and isn't ctrl+c
|
||||
exceptions.append(exc)
|
||||
return merge_tables(tables), exceptions
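`catch_as_completed` drains per-schema catalog futures, re-raises interrupts, and records other exceptions instead of failing the whole catalog build. A hedged sketch of the same pattern using `concurrent.futures`, with lists of dicts in place of `agate` tables:

```
from concurrent.futures import ThreadPoolExecutor, as_completed

def fetch(schema: str):
    if schema == "broken":
        raise RuntimeError(f"could not read schema {schema}")
    return [{"schema": schema, "table": "orders"}]

tables, exceptions = [], []
with ThreadPoolExecutor(max_workers=4) as tpe:
    futures = [tpe.submit(fetch, s) for s in ("marts", "broken", "staging")]
    for future in as_completed(futures):
        exc = future.exception()
        if exc is None:
            tables.extend(future.result())
        elif isinstance(exc, KeyboardInterrupt) or not isinstance(exc, Exception):
            raise exc  # ctrl+c and other non-Exception errors propagate
        else:
            print(f"Encountered an error while generating catalog: {exc}")
            exceptions.append(exc)

print(len(tables), len(exceptions))  # 2 1
```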
|
||||
|
||||
@@ -30,9 +30,11 @@ class _Available:
|
||||
x.update(big_expensive_db_query())
|
||||
return x
|
||||
"""
|
||||
|
||||
def inner(func):
|
||||
func._parse_replacement_ = parse_replacement
|
||||
return self(func)
|
||||
|
||||
return inner
|
||||
|
||||
def deprecated(
|
||||
@@ -57,13 +59,14 @@ class _Available:
|
||||
The optional parse_replacement, if provided, will provide a parse-time
|
||||
replacement for the actual method (see `available.parse`).
|
||||
"""
|
||||
|
||||
def wrapper(func):
|
||||
func_name = func.__name__
|
||||
renamed_method(func_name, supported_name)
|
||||
|
||||
@wraps(func)
|
||||
def inner(*args, **kwargs):
|
||||
warn('adapter:{}'.format(func_name))
|
||||
warn("adapter:{}".format(func_name))
|
||||
return func(*args, **kwargs)
|
||||
|
||||
if parse_replacement:
|
||||
@@ -71,6 +74,7 @@ class _Available:
|
||||
else:
|
||||
available_function = self
|
||||
return available_function(inner)
|
||||
|
||||
return wrapper
|
||||
|
||||
def parse_none(self, func: Callable) -> Callable:
|
||||
@@ -109,14 +113,14 @@ class AdapterMeta(abc.ABCMeta):
|
||||
|
||||
# collect base class data first
|
||||
for base in bases:
|
||||
available.update(getattr(base, '_available_', set()))
|
||||
replacements.update(getattr(base, '_parse_replacements_', set()))
|
||||
available.update(getattr(base, "_available_", set()))
|
||||
replacements.update(getattr(base, "_parse_replacements_", set()))
|
||||
|
||||
# override with local data if it exists
|
||||
for name, value in namespace.items():
|
||||
if getattr(value, '_is_available_', False):
|
||||
if getattr(value, "_is_available_", False):
|
||||
available.add(name)
|
||||
parse_replacement = getattr(value, '_parse_replacement_', None)
|
||||
parse_replacement = getattr(value, "_parse_replacement_", None)
|
||||
if parse_replacement is not None:
|
||||
replacements[name] = parse_replacement
|
||||
|
||||
|
||||
@@ -8,11 +8,10 @@ from dbt.adapters.protocol import AdapterProtocol
|
||||
def project_name_from_path(include_path: str) -> str:
|
||||
# avoid an import cycle
|
||||
from dbt.config.project import Project
|
||||
|
||||
partial = Project.partial_load(include_path)
|
||||
if partial.project_name is None:
|
||||
raise CompilationException(
|
||||
f'Invalid project at {include_path}: name not set!'
|
||||
)
|
||||
raise CompilationException(f"Invalid project at {include_path}: name not set!")
|
||||
return partial.project_name
|
||||
|
||||
|
||||
@@ -23,12 +22,13 @@ class AdapterPlugin:
|
||||
:param dependencies: A list of adapter names that this adapter depends
|
||||
upon.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
adapter: Type[AdapterProtocol],
|
||||
credentials: Type[Credentials],
|
||||
include_path: str,
|
||||
dependencies: Optional[List[str]] = None
|
||||
dependencies: Optional[List[str]] = None,
|
||||
):
|
||||
|
||||
self.adapter: Type[AdapterProtocol] = adapter
|
||||
|
||||
@@ -15,7 +15,7 @@ class NodeWrapper:
|
||||
self._inner_node = node
|
||||
|
||||
def __getattr__(self, name):
|
||||
return getattr(self._inner_node, name, '')
|
||||
return getattr(self._inner_node, name, "")
|
||||
|
||||
|
||||
class _QueryComment(local):
|
||||
@@ -24,6 +24,7 @@ class _QueryComment(local):
|
||||
- the current thread's query comment.
|
||||
- a source_name indicating what set the current thread's query comment
|
||||
"""
|
||||
|
||||
def __init__(self, initial):
|
||||
self.query_comment: Optional[str] = initial
|
||||
self.append = False
|
||||
@@ -35,16 +36,16 @@ class _QueryComment(local):
|
||||
if self.append:
|
||||
# replace last ';' with '<comment>;'
|
||||
sql = sql.rstrip()
|
||||
if sql[-1] == ';':
|
||||
if sql[-1] == ";":
|
||||
sql = sql[:-1]
|
||||
return '{}\n/* {} */;'.format(sql, self.query_comment.strip())
|
||||
return "{}\n/* {} */;".format(sql, self.query_comment.strip())
|
||||
|
||||
return '{}\n/* {} */'.format(sql, self.query_comment.strip())
|
||||
return "{}\n/* {} */".format(sql, self.query_comment.strip())
|
||||
|
||||
return '/* {} */\n{}'.format(self.query_comment.strip(), sql)
|
||||
return "/* {} */\n{}".format(self.query_comment.strip(), sql)
|
||||
|
||||
def set(self, comment: Optional[str], append: bool):
|
||||
if isinstance(comment, str) and '*/' in comment:
|
||||
if isinstance(comment, str) and "*/" in comment:
|
||||
# tell the user "no" so they don't hurt themselves by writing
|
||||
# garbage
|
||||
raise RuntimeException(
|
||||
@@ -63,15 +64,17 @@ class MacroQueryStringSetter:
|
||||
self.config = config
|
||||
|
||||
comment_macro = self._get_comment_macro()
|
||||
self.generator: QueryStringFunc = lambda name, model: ''
|
||||
self.generator: QueryStringFunc = lambda name, model: ""
|
||||
# if the comment value was None or the empty string, just skip it
|
||||
if comment_macro:
|
||||
assert isinstance(comment_macro, str)
|
||||
macro = '\n'.join((
|
||||
'{%- macro query_comment_macro(connection_name, node) -%}',
|
||||
macro = "\n".join(
|
||||
(
|
||||
"{%- macro query_comment_macro(connection_name, node) -%}",
|
||||
comment_macro,
|
||||
'{% endmacro %}'
|
||||
))
|
||||
"{% endmacro %}",
|
||||
)
|
||||
)
|
||||
ctx = self._get_context()
|
||||
self.generator = QueryStringGenerator(macro, ctx)
|
||||
self.comment = _QueryComment(None)
|
||||
@@ -87,7 +90,7 @@ class MacroQueryStringSetter:
|
||||
return self.comment.add(sql)
|
||||
|
||||
def reset(self):
|
||||
self.set('master', None)
|
||||
self.set("master", None)
|
||||
|
||||
def set(self, name: str, node: Optional[CompileResultNode]):
|
||||
wrapped: Optional[NodeWrapper] = None
|
||||
|
||||
@@ -1,13 +1,16 @@
|
||||
from collections.abc import Hashable
|
||||
from dataclasses import dataclass
|
||||
from typing import (
|
||||
Optional, TypeVar, Any, Type, Dict, Union, Iterator, Tuple, Set
|
||||
)
|
||||
from typing import Optional, TypeVar, Any, Type, Dict, Union, Iterator, Tuple, Set
|
||||
|
||||
from dbt.contracts.graph.compiled import CompiledNode
|
||||
from dbt.contracts.graph.parsed import ParsedSourceDefinition, ParsedNode
|
||||
from dbt.contracts.relation import (
|
||||
RelationType, ComponentName, HasQuoting, FakeAPIObject, Policy, Path
|
||||
RelationType,
|
||||
ComponentName,
|
||||
HasQuoting,
|
||||
FakeAPIObject,
|
||||
Policy,
|
||||
Path,
|
||||
)
|
||||
from dbt.exceptions import InternalException
|
||||
from dbt.node_types import NodeType
|
||||
@@ -16,7 +19,7 @@ from dbt.utils import filter_null_values, deep_merge, classproperty
|
||||
import dbt.exceptions
|
||||
|
||||
|
||||
Self = TypeVar('Self', bound='BaseRelation')
|
||||
Self = TypeVar("Self", bound="BaseRelation")
|
||||
|
||||
|
||||
@dataclass(frozen=True, eq=False, repr=False)
|
||||
@@ -40,7 +43,7 @@ class BaseRelation(FakeAPIObject, Hashable):
|
||||
if field.name == field_name:
|
||||
return field
|
||||
# this should be unreachable
|
||||
raise ValueError(f'BaseRelation has no {field_name} field!')
|
||||
raise ValueError(f"BaseRelation has no {field_name} field!")
|
||||
|
||||
def __eq__(self, other):
|
||||
if not isinstance(other, self.__class__):
|
||||
@@ -49,20 +52,18 @@ class BaseRelation(FakeAPIObject, Hashable):
|
||||
|
||||
@classmethod
|
||||
def get_default_quote_policy(cls) -> Policy:
|
||||
return cls._get_field_named('quote_policy').default
|
||||
return cls._get_field_named("quote_policy").default
|
||||
|
||||
@classmethod
|
||||
def get_default_include_policy(cls) -> Policy:
|
||||
return cls._get_field_named('include_policy').default
|
||||
return cls._get_field_named("include_policy").default
|
||||
|
||||
def get(self, key, default=None):
|
||||
"""Override `.get` to return a metadata object so we don't break
|
||||
dbt_utils.
|
||||
"""
|
||||
if key == 'metadata':
|
||||
return {
|
||||
'type': self.__class__.__name__
|
||||
}
|
||||
if key == "metadata":
|
||||
return {"type": self.__class__.__name__}
|
||||
return super().get(key, default)
|
||||
|
||||
def matches(
|
||||
@@ -71,16 +72,19 @@ class BaseRelation(FakeAPIObject, Hashable):
|
||||
schema: Optional[str] = None,
|
||||
identifier: Optional[str] = None,
|
||||
) -> bool:
|
||||
search = filter_null_values({
|
||||
search = filter_null_values(
|
||||
{
|
||||
ComponentName.Database: database,
|
||||
ComponentName.Schema: schema,
|
||||
ComponentName.Identifier: identifier
|
||||
})
|
||||
ComponentName.Identifier: identifier,
|
||||
}
|
||||
)
|
||||
|
||||
if not search:
|
||||
# nothing was passed in
|
||||
raise dbt.exceptions.RuntimeException(
|
||||
"Tried to match relation, but no search path was passed!")
|
||||
"Tried to match relation, but no search path was passed!"
|
||||
)
|
||||
|
||||
exact_match = True
|
||||
approximate_match = True
|
||||
@@ -109,11 +113,13 @@ class BaseRelation(FakeAPIObject, Hashable):
|
||||
schema: Optional[bool] = None,
|
||||
identifier: Optional[bool] = None,
|
||||
) -> Self:
|
||||
policy = filter_null_values({
|
||||
policy = filter_null_values(
|
||||
{
|
||||
ComponentName.Database: database,
|
||||
ComponentName.Schema: schema,
|
||||
ComponentName.Identifier: identifier
|
||||
})
|
||||
ComponentName.Identifier: identifier,
|
||||
}
|
||||
)
|
||||
|
||||
new_quote_policy = self.quote_policy.replace_dict(policy)
|
||||
return self.replace(quote_policy=new_quote_policy)
|
||||
@@ -124,16 +130,18 @@ class BaseRelation(FakeAPIObject, Hashable):
|
||||
schema: Optional[bool] = None,
|
||||
identifier: Optional[bool] = None,
|
||||
) -> Self:
|
||||
policy = filter_null_values({
|
||||
policy = filter_null_values(
|
||||
{
|
||||
ComponentName.Database: database,
|
||||
ComponentName.Schema: schema,
|
||||
ComponentName.Identifier: identifier
|
||||
})
|
||||
ComponentName.Identifier: identifier,
|
||||
}
|
||||
)
|
||||
|
||||
new_include_policy = self.include_policy.replace_dict(policy)
|
||||
return self.replace(include_policy=new_include_policy)
|
||||
|
||||
def information_schema(self, view_name=None) -> 'InformationSchema':
|
||||
def information_schema(self, view_name=None) -> "InformationSchema":
|
||||
# some of our data comes from jinja, where things can be `Undefined`.
|
||||
if not isinstance(view_name, str):
|
||||
view_name = None
|
||||
@@ -143,10 +151,10 @@ class BaseRelation(FakeAPIObject, Hashable):
|
||||
info_schema = InformationSchema.from_relation(self, view_name)
|
||||
return info_schema.incorporate(path={"schema": None})
|
||||
|
||||
def information_schema_only(self) -> 'InformationSchema':
|
||||
def information_schema_only(self) -> "InformationSchema":
|
||||
return self.information_schema()
|
||||
|
||||
def without_identifier(self) -> 'BaseRelation':
|
||||
def without_identifier(self) -> "BaseRelation":
|
||||
"""Return a form of this relation that only has the database and schema
|
||||
set to included. To get the appropriately-quoted form the schema out of
|
||||
the result (for use as part of a query), use `.render()`. To get the
|
||||
@@ -157,7 +165,7 @@ class BaseRelation(FakeAPIObject, Hashable):
|
||||
return self.include(identifier=False).replace_path(identifier=None)
|
||||
|
||||
def _render_iterator(
|
||||
self
|
||||
self,
|
||||
) -> Iterator[Tuple[Optional[ComponentName], Optional[str]]]:
|
||||
|
||||
for key in ComponentName:
|
||||
@@ -170,13 +178,10 @@ class BaseRelation(FakeAPIObject, Hashable):
|
||||
|
||||
def render(self) -> str:
|
||||
# if there is nothing set, this will return the empty string.
|
||||
return '.'.join(
|
||||
part for _, part in self._render_iterator()
|
||||
if part is not None
|
||||
)
|
||||
return ".".join(part for _, part in self._render_iterator() if part is not None)
|
||||
|
||||
def quoted(self, identifier):
|
||||
return '{quote_char}{identifier}{quote_char}'.format(
|
||||
return "{quote_char}{identifier}{quote_char}".format(
|
||||
quote_char=self.quote_character,
|
||||
identifier=identifier,
|
||||
)
|
||||
@@ -186,11 +191,11 @@ class BaseRelation(FakeAPIObject, Hashable):
|
||||
cls: Type[Self], source: ParsedSourceDefinition, **kwargs: Any
|
||||
) -> Self:
|
||||
source_quoting = source.quoting.to_dict(omit_none=True)
|
||||
source_quoting.pop('column', None)
|
||||
source_quoting.pop("column", None)
|
||||
quote_policy = deep_merge(
|
||||
cls.get_default_quote_policy().to_dict(omit_none=True),
|
||||
source_quoting,
|
||||
kwargs.get('quote_policy', {}),
|
||||
kwargs.get("quote_policy", {}),
|
||||
)
|
||||
|
||||
return cls.create(
|
||||
@@ -198,12 +203,12 @@ class BaseRelation(FakeAPIObject, Hashable):
|
||||
schema=source.schema,
|
||||
identifier=source.identifier,
|
||||
quote_policy=quote_policy,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def add_ephemeral_prefix(name: str):
|
||||
return f'__dbt__cte__{name}'
|
||||
return f"__dbt__cte__{name}"
|
||||
|
||||
@classmethod
|
||||
def create_ephemeral_from_node(
|
||||
@@ -236,7 +241,8 @@ class BaseRelation(FakeAPIObject, Hashable):
|
||||
schema=node.schema,
|
||||
identifier=node.alias,
|
||||
quote_policy=quote_policy,
|
||||
**kwargs)
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def create_from(
|
||||
@@ -248,15 +254,16 @@ class BaseRelation(FakeAPIObject, Hashable):
|
||||
if node.resource_type == NodeType.Source:
|
||||
if not isinstance(node, ParsedSourceDefinition):
|
||||
raise InternalException(
|
||||
'type mismatch, expected ParsedSourceDefinition but got {}'
|
||||
.format(type(node))
|
||||
"type mismatch, expected ParsedSourceDefinition but got {}".format(
|
||||
type(node)
|
||||
)
|
||||
)
|
||||
return cls.create_from_source(node, **kwargs)
|
||||
else:
|
||||
if not isinstance(node, (ParsedNode, CompiledNode)):
|
||||
raise InternalException(
|
||||
'type mismatch, expected ParsedNode or CompiledNode but '
|
||||
'got {}'.format(type(node))
|
||||
"type mismatch, expected ParsedNode or CompiledNode but "
|
||||
"got {}".format(type(node))
|
||||
)
|
||||
return cls.create_from_node(config, node, **kwargs)
|
||||
|
||||
@@ -269,14 +276,16 @@ class BaseRelation(FakeAPIObject, Hashable):
|
||||
type: Optional[RelationType] = None,
|
||||
**kwargs,
|
||||
) -> Self:
|
||||
kwargs.update({
|
||||
'path': {
|
||||
'database': database,
|
||||
'schema': schema,
|
||||
'identifier': identifier,
|
||||
kwargs.update(
|
||||
{
|
||||
"path": {
|
||||
"database": database,
|
||||
"schema": schema,
|
||||
"identifier": identifier,
|
||||
},
|
||||
'type': type,
|
||||
})
|
||||
"type": type,
|
||||
}
|
||||
)
|
||||
return cls.from_dict(kwargs)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
@@ -342,7 +351,7 @@ class BaseRelation(FakeAPIObject, Hashable):
|
||||
return RelationType
|
||||
|
||||
|
||||
Info = TypeVar('Info', bound='InformationSchema')
|
||||
Info = TypeVar("Info", bound="InformationSchema")
|
||||
|
||||
|
||||
@dataclass(frozen=True, eq=False, repr=False)
|
||||
@@ -352,7 +361,7 @@ class InformationSchema(BaseRelation):
|
||||
def __post_init__(self):
|
||||
if not isinstance(self.information_schema_view, (type(None), str)):
|
||||
raise dbt.exceptions.CompilationException(
|
||||
'Got an invalid name: {}'.format(self.information_schema_view)
|
||||
"Got an invalid name: {}".format(self.information_schema_view)
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -362,7 +371,7 @@ class InformationSchema(BaseRelation):
|
||||
return Path(
|
||||
database=relation.database,
|
||||
schema=relation.schema,
|
||||
identifier='INFORMATION_SCHEMA',
|
||||
identifier="INFORMATION_SCHEMA",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -393,9 +402,7 @@ class InformationSchema(BaseRelation):
|
||||
relation: BaseRelation,
|
||||
information_schema_view: Optional[str],
|
||||
) -> Info:
|
||||
include_policy = cls.get_include_policy(
|
||||
relation, information_schema_view
|
||||
)
|
||||
include_policy = cls.get_include_policy(relation, information_schema_view)
|
||||
quote_policy = cls.get_quote_policy(relation, information_schema_view)
|
||||
path = cls.get_path(relation, information_schema_view)
|
||||
return cls(
|
||||
@@ -417,6 +424,7 @@ class SchemaSearchMap(Dict[InformationSchema, Set[Optional[str]]]):
|
||||
search for what schemas. The schema values are all lowercased to avoid
|
||||
duplication.
|
||||
"""
|
||||
|
||||
def add(self, relation: BaseRelation):
|
||||
key = relation.information_schema_only()
|
||||
if key not in self:
|
||||
@@ -426,9 +434,7 @@ class SchemaSearchMap(Dict[InformationSchema, Set[Optional[str]]]):
|
||||
schema = relation.schema.lower()
|
||||
self[key].add(schema)
|
||||
|
||||
def search(
|
||||
self
|
||||
) -> Iterator[Tuple[InformationSchema, Optional[str]]]:
|
||||
def search(self) -> Iterator[Tuple[InformationSchema, Optional[str]]]:
|
||||
for information_schema_name, schemas in self.items():
|
||||
for schema in schemas:
|
||||
yield information_schema_name, schema
|
||||
@@ -442,14 +448,13 @@ class SchemaSearchMap(Dict[InformationSchema, Set[Optional[str]]]):
|
||||
dbt.exceptions.raise_compiler_error(str(seen))
|
||||
|
||||
for information_schema_name, schema in self.search():
|
||||
path = {
|
||||
'database': information_schema_name.database,
|
||||
'schema': schema
|
||||
}
|
||||
new.add(information_schema_name.incorporate(
|
||||
path = {"database": information_schema_name.database, "schema": schema}
|
||||
new.add(
|
||||
information_schema_name.incorporate(
|
||||
path=path,
|
||||
quote_policy={'database': False},
|
||||
include_policy={'database': False},
|
||||
))
|
||||
quote_policy={"database": False},
|
||||
include_policy={"database": False},
|
||||
)
|
||||
)
|
||||
|
||||
return new
|
||||
|
||||
@@ -7,7 +7,7 @@ from dbt.logger import CACHE_LOGGER as logger
from dbt.utils import lowercase
import dbt.exceptions

_ReferenceKey = namedtuple('_ReferenceKey', 'database schema identifier')
_ReferenceKey = namedtuple("_ReferenceKey", "database schema identifier")


def _make_key(relation) -> _ReferenceKey:
@@ -15,9 +15,11 @@ def _make_key(relation) -> _ReferenceKey:
to keep track of quoting
"""
# databases and schemas can both be None
return _ReferenceKey(lowercase(relation.database),
return _ReferenceKey(
lowercase(relation.database),
lowercase(relation.schema),
lowercase(relation.identifier))
lowercase(relation.identifier),
)


def dot_separated(key: _ReferenceKey) -> str:
@@ -25,7 +27,7 @@ def dot_separated(key: _ReferenceKey) -> str:

:param _ReferenceKey key: The key to stringify.
"""
return '.'.join(map(str, key))
return ".".join(map(str, key))
class _CachedRelation:
|
||||
@@ -37,13 +39,14 @@ class _CachedRelation:
|
||||
that refer to this relation.
|
||||
:attr BaseRelation inner: The underlying dbt relation.
|
||||
"""
|
||||
|
||||
def __init__(self, inner):
|
||||
self.referenced_by = {}
|
||||
self.inner = inner
|
||||
|
||||
def __str__(self) -> str:
|
||||
return (
|
||||
'_CachedRelation(database={}, schema={}, identifier={}, inner={})'
|
||||
"_CachedRelation(database={}, schema={}, identifier={}, inner={})"
|
||||
).format(self.database, self.schema, self.identifier, self.inner)
|
||||
|
||||
@property
|
||||
@@ -78,7 +81,7 @@ class _CachedRelation:
|
||||
"""
|
||||
return _make_key(self)
|
||||
|
||||
def add_reference(self, referrer: '_CachedRelation'):
|
||||
def add_reference(self, referrer: "_CachedRelation"):
|
||||
"""Add a reference from referrer to self, indicating that if this node
|
||||
were drop...cascaded, the referrer would be dropped as well.
|
||||
|
||||
@@ -122,9 +125,9 @@ class _CachedRelation:
|
||||
# table_name is ever anything but the identifier (via .create())
|
||||
self.inner = self.inner.incorporate(
|
||||
path={
|
||||
'database': new_relation.inner.database,
|
||||
'schema': new_relation.inner.schema,
|
||||
'identifier': new_relation.inner.identifier
|
||||
"database": new_relation.inner.database,
|
||||
"schema": new_relation.inner.schema,
|
||||
"identifier": new_relation.inner.identifier,
|
||||
},
|
||||
)
|
||||
|
||||
@@ -140,8 +143,9 @@ class _CachedRelation:
|
||||
"""
|
||||
if new_key in self.referenced_by:
|
||||
dbt.exceptions.raise_cache_inconsistent(
|
||||
'in rename of "{}" -> "{}", new name is in the cache already'
|
||||
.format(old_key, new_key)
|
||||
'in rename of "{}" -> "{}", new name is in the cache already'.format(
|
||||
old_key, new_key
|
||||
)
|
||||
)
|
||||
|
||||
if old_key not in self.referenced_by:
|
||||
@@ -172,13 +176,16 @@ class RelationsCache:
|
||||
The adapters also hold this lock while filling the cache.
|
||||
:attr Set[str] schemas: The set of known/cached schemas, all lowercased.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.relations: Dict[_ReferenceKey, _CachedRelation] = {}
|
||||
self.lock = threading.RLock()
|
||||
self.schemas: Set[Tuple[Optional[str], Optional[str]]] = set()
|
||||
|
||||
def add_schema(
|
||||
self, database: Optional[str], schema: Optional[str],
|
||||
self,
|
||||
database: Optional[str],
|
||||
schema: Optional[str],
|
||||
) -> None:
|
||||
"""Add a schema to the set of known schemas (case-insensitive)
|
||||
|
||||
@@ -188,7 +195,9 @@ class RelationsCache:
|
||||
self.schemas.add((lowercase(database), lowercase(schema)))
|
||||
|
||||
def drop_schema(
|
||||
self, database: Optional[str], schema: Optional[str],
|
||||
self,
|
||||
database: Optional[str],
|
||||
schema: Optional[str],
|
||||
) -> None:
|
||||
"""Drop the given schema and remove it from the set of known schemas.
|
||||
|
||||
@@ -263,15 +272,15 @@ class RelationsCache:
|
||||
return
|
||||
if referenced is None:
|
||||
dbt.exceptions.raise_cache_inconsistent(
|
||||
'in add_link, referenced link key {} not in cache!'
|
||||
.format(referenced_key)
|
||||
"in add_link, referenced link key {} not in cache!".format(
|
||||
referenced_key
|
||||
)
|
||||
)
|
||||
|
||||
dependent = self.relations.get(dependent_key)
|
||||
if dependent is None:
|
||||
dbt.exceptions.raise_cache_inconsistent(
|
||||
'in add_link, dependent link key {} not in cache!'
|
||||
.format(dependent_key)
|
||||
"in add_link, dependent link key {} not in cache!".format(dependent_key)
|
||||
)
|
||||
|
||||
assert dependent is not None # we just raised!
|
||||
@@ -298,28 +307,23 @@ class RelationsCache:
|
||||
# referring to a table outside our control. There's no need to make
|
||||
# a link - we will never drop the referenced relation during a run.
|
||||
logger.debug(
|
||||
'{dep!s} references {ref!s} but {ref.database}.{ref.schema} '
|
||||
'is not in the cache, skipping assumed external relation'
|
||||
.format(dep=dependent, ref=ref_key)
|
||||
"{dep!s} references {ref!s} but {ref.database}.{ref.schema} "
|
||||
"is not in the cache, skipping assumed external relation".format(
|
||||
dep=dependent, ref=ref_key
|
||||
)
|
||||
)
|
||||
return
|
||||
if ref_key not in self.relations:
|
||||
# Insert a dummy "external" relation.
|
||||
referenced = referenced.replace(
|
||||
type=referenced.External
|
||||
)
|
||||
referenced = referenced.replace(type=referenced.External)
|
||||
self.add(referenced)
|
||||
|
||||
dep_key = _make_key(dependent)
|
||||
if dep_key not in self.relations:
|
||||
# Insert a dummy "external" relation.
|
||||
dependent = dependent.replace(
|
||||
type=referenced.External
|
||||
)
|
||||
dependent = dependent.replace(type=referenced.External)
|
||||
self.add(dependent)
|
||||
logger.debug(
|
||||
'adding link, {!s} references {!s}'.format(dep_key, ref_key)
|
||||
)
|
||||
logger.debug("adding link, {!s} references {!s}".format(dep_key, ref_key))
|
||||
with self.lock:
|
||||
self._add_link(ref_key, dep_key)
|
||||
|
||||
@@ -330,14 +334,14 @@ class RelationsCache:
|
||||
:param BaseRelation relation: The underlying relation.
|
||||
"""
|
||||
cached = _CachedRelation(relation)
|
||||
logger.debug('Adding relation: {!s}'.format(cached))
|
||||
logger.debug("Adding relation: {!s}".format(cached))
|
||||
|
||||
lazy_log('before adding: {!s}', self.dump_graph)
|
||||
lazy_log("before adding: {!s}", self.dump_graph)
|
||||
|
||||
with self.lock:
|
||||
self._setdefault(cached)
|
||||
|
||||
lazy_log('after adding: {!s}', self.dump_graph)
|
||||
lazy_log("after adding: {!s}", self.dump_graph)
|
||||
|
||||
def _remove_refs(self, keys):
|
||||
"""Removes all references to all entries in keys. This does not
|
||||
@@ -359,13 +363,10 @@ class RelationsCache:
|
||||
:param _CachedRelation dropped: An existing _CachedRelation to drop.
|
||||
"""
|
||||
if dropped not in self.relations:
|
||||
logger.debug('dropped a nonexistent relationship: {!s}'
|
||||
.format(dropped))
|
||||
logger.debug("dropped a nonexistent relationship: {!s}".format(dropped))
|
||||
return
|
||||
consequences = self.relations[dropped].collect_consequences()
|
||||
logger.debug(
|
||||
'drop {} is cascading to {}'.format(dropped, consequences)
|
||||
)
|
||||
logger.debug("drop {} is cascading to {}".format(dropped, consequences))
|
||||
self._remove_refs(consequences)
|
||||
|
||||
def drop(self, relation):
|
||||
@@ -380,7 +381,7 @@ class RelationsCache:
|
||||
:param str identifier: The identifier of the relation to drop.
|
||||
"""
|
||||
dropped = _make_key(relation)
|
||||
logger.debug('Dropping relation: {!s}'.format(dropped))
|
||||
logger.debug("Dropping relation: {!s}".format(dropped))
|
||||
with self.lock:
|
||||
self._drop_cascade_relation(dropped)
|
||||
|
||||
@@ -404,8 +405,9 @@ class RelationsCache:
|
||||
for cached in self.relations.values():
|
||||
if cached.is_referenced_by(old_key):
|
||||
logger.debug(
|
||||
'updated reference from {0} -> {2} to {1} -> {2}'
|
||||
.format(old_key, new_key, cached.key())
|
||||
"updated reference from {0} -> {2} to {1} -> {2}".format(
|
||||
old_key, new_key, cached.key()
|
||||
)
|
||||
)
|
||||
cached.rename_key(old_key, new_key)
|
||||
|
||||
@@ -430,14 +432,16 @@ class RelationsCache:
|
||||
"""
|
||||
if new_key in self.relations:
|
||||
dbt.exceptions.raise_cache_inconsistent(
|
||||
'in rename, new key {} already in cache: {}'
|
||||
.format(new_key, list(self.relations.keys()))
|
||||
"in rename, new key {} already in cache: {}".format(
|
||||
new_key, list(self.relations.keys())
|
||||
)
|
||||
)
|
||||
|
||||
if old_key not in self.relations:
|
||||
logger.debug(
|
||||
'old key {} not found in self.relations, assuming temporary'
|
||||
.format(old_key)
|
||||
"old key {} not found in self.relations, assuming temporary".format(
|
||||
old_key
|
||||
)
|
||||
)
|
||||
return False
|
||||
return True
|
||||
@@ -456,11 +460,9 @@ class RelationsCache:
|
||||
"""
|
||||
old_key = _make_key(old)
|
||||
new_key = _make_key(new)
|
||||
logger.debug('Renaming relation {!s} to {!s}'.format(
|
||||
old_key, new_key
|
||||
))
|
||||
logger.debug("Renaming relation {!s} to {!s}".format(old_key, new_key))
|
||||
|
||||
lazy_log('before rename: {!s}', self.dump_graph)
|
||||
lazy_log("before rename: {!s}", self.dump_graph)
|
||||
|
||||
with self.lock:
|
||||
if self._check_rename_constraints(old_key, new_key):
|
||||
@@ -468,7 +470,7 @@ class RelationsCache:
|
||||
else:
|
||||
self._setdefault(_CachedRelation(new))
|
||||
|
||||
lazy_log('after rename: {!s}', self.dump_graph)
|
||||
lazy_log("after rename: {!s}", self.dump_graph)
|
||||
|
||||
def get_relations(
|
||||
self, database: Optional[str], schema: Optional[str]
|
||||
@@ -483,14 +485,14 @@ class RelationsCache:
|
||||
schema = lowercase(schema)
|
||||
with self.lock:
|
||||
results = [
|
||||
r.inner for r in self.relations.values()
|
||||
if (lowercase(r.schema) == schema and
|
||||
lowercase(r.database) == database)
|
||||
r.inner
|
||||
for r in self.relations.values()
|
||||
if (lowercase(r.schema) == schema and lowercase(r.database) == database)
|
||||
]
|
||||
|
||||
if None in results:
|
||||
dbt.exceptions.raise_cache_inconsistent(
|
||||
'in get_relations, a None relation was found in the cache!'
|
||||
"in get_relations, a None relation was found in the cache!"
|
||||
)
|
||||
return results
|
||||
|
||||
|
||||
@@ -50,9 +50,7 @@ class AdapterContainer:
|
||||
adapter = self.get_adapter_class_by_name(name)
|
||||
return adapter.Relation
|
||||
|
||||
def get_config_class_by_name(
|
||||
self, name: str
|
||||
) -> Type[AdapterConfig]:
|
||||
def get_config_class_by_name(self, name: str) -> Type[AdapterConfig]:
|
||||
adapter = self.get_adapter_class_by_name(name)
|
||||
return adapter.AdapterSpecificConfigs
|
||||
|
||||
@@ -62,24 +60,24 @@ class AdapterContainer:
|
||||
# singletons
|
||||
try:
|
||||
# mypy doesn't think modules have any attributes.
|
||||
mod: Any = import_module('.' + name, 'dbt.adapters')
|
||||
mod: Any = import_module("." + name, "dbt.adapters")
|
||||
except ModuleNotFoundError as exc:
|
||||
# if we failed to import the target module in particular, inform
|
||||
# the user about it via a runtime error
|
||||
if exc.name == 'dbt.adapters.' + name:
|
||||
raise RuntimeException(f'Could not find adapter type {name}!')
|
||||
logger.info(f'Error importing adapter: {exc}')
|
||||
if exc.name == "dbt.adapters." + name:
|
||||
raise RuntimeException(f"Could not find adapter type {name}!")
|
||||
logger.info(f"Error importing adapter: {exc}")
|
||||
# otherwise, the error had to have come from some underlying
|
||||
# library. Log the stack trace.
|
||||
logger.debug('', exc_info=True)
|
||||
logger.debug("", exc_info=True)
|
||||
raise
|
||||
plugin: AdapterPlugin = mod.Plugin
|
||||
plugin_type = plugin.adapter.type()
|
||||
|
||||
if plugin_type != name:
|
||||
raise RuntimeException(
|
||||
f'Expected to find adapter with type named {name}, got '
|
||||
f'adapter with type {plugin_type}'
|
||||
f"Expected to find adapter with type named {name}, got "
|
||||
f"adapter with type {plugin_type}"
|
||||
)
|
||||
|
||||
with self.lock:
|
||||
@@ -109,8 +107,7 @@ class AdapterContainer:
|
||||
return self.adapters[adapter_name]
|
||||
|
||||
def reset_adapters(self):
|
||||
"""Clear the adapters. This is useful for tests, which change configs.
|
||||
"""
|
||||
"""Clear the adapters. This is useful for tests, which change configs."""
|
||||
with self.lock:
|
||||
for adapter in self.adapters.values():
|
||||
adapter.cleanup_connections()
|
||||
@@ -140,9 +137,7 @@ class AdapterContainer:
|
||||
try:
|
||||
plugin = self.plugins[plugin_name]
|
||||
except KeyError:
|
||||
raise InternalException(
|
||||
f'No plugin found for {plugin_name}'
|
||||
) from None
|
||||
raise InternalException(f"No plugin found for {plugin_name}") from None
|
||||
plugins.append(plugin)
|
||||
seen.add(plugin_name)
|
||||
if plugin.dependencies is None:
|
||||
@@ -166,7 +161,7 @@ class AdapterContainer:
|
||||
path = self.packages[package_name]
|
||||
except KeyError:
|
||||
raise InternalException(
|
||||
f'No internal package listing found for {package_name}'
|
||||
f"No internal package listing found for {package_name}"
|
||||
)
|
||||
paths.append(path)
|
||||
return paths
|
||||
@@ -187,8 +182,7 @@ def get_adapter(config: AdapterRequiredConfig):
|
||||
|
||||
|
||||
def reset_adapters():
|
||||
"""Clear the adapters. This is useful for tests, which change configs.
|
||||
"""
|
||||
"""Clear the adapters. This is useful for tests, which change configs."""
|
||||
FACTORY.reset_adapters()
|
||||
|
||||
|
||||
|
||||
@@ -1,17 +1,27 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import (
|
||||
Type, Hashable, Optional, ContextManager, List, Generic, TypeVar, ClassVar,
|
||||
Tuple, Union, Dict, Any
|
||||
Type,
|
||||
Hashable,
|
||||
Optional,
|
||||
ContextManager,
|
||||
List,
|
||||
Generic,
|
||||
TypeVar,
|
||||
ClassVar,
|
||||
Tuple,
|
||||
Union,
|
||||
Dict,
|
||||
Any,
|
||||
)
|
||||
from typing_extensions import Protocol
|
||||
|
||||
import agate
|
||||
|
||||
from dbt.contracts.connection import (
|
||||
Connection, AdapterRequiredConfig, AdapterResponse
|
||||
)
|
||||
from dbt.contracts.connection import Connection, AdapterRequiredConfig, AdapterResponse
|
||||
from dbt.contracts.graph.compiled import (
|
||||
CompiledNode, ManifestNode, NonSourceCompiledNode
|
||||
CompiledNode,
|
||||
ManifestNode,
|
||||
NonSourceCompiledNode,
|
||||
)
|
||||
from dbt.contracts.graph.parsed import ParsedNode, ParsedSourceDefinition
|
||||
from dbt.contracts.graph.model_config import BaseConfig
|
||||
@@ -34,7 +44,7 @@ class ColumnProtocol(Protocol):
|
||||
pass
|
||||
|
||||
|
||||
Self = TypeVar('Self', bound='RelationProtocol')
|
||||
Self = TypeVar("Self", bound="RelationProtocol")
|
||||
|
||||
|
||||
class RelationProtocol(Protocol):
|
||||
@@ -64,19 +74,11 @@ class CompilerProtocol(Protocol):
...


AdapterConfig_T = TypeVar(
'AdapterConfig_T', bound=AdapterConfig
)
ConnectionManager_T = TypeVar(
'ConnectionManager_T', bound=ConnectionManagerProtocol
)
Relation_T = TypeVar(
'Relation_T', bound=RelationProtocol
)
Column_T = TypeVar(
'Column_T', bound=ColumnProtocol
)
Compiler_T = TypeVar('Compiler_T', bound=CompilerProtocol)
AdapterConfig_T = TypeVar("AdapterConfig_T", bound=AdapterConfig)
ConnectionManager_T = TypeVar("ConnectionManager_T", bound=ConnectionManagerProtocol)
Relation_T = TypeVar("Relation_T", bound=RelationProtocol)
Column_T = TypeVar("Column_T", bound=ColumnProtocol)
Compiler_T = TypeVar("Compiler_T", bound=CompilerProtocol)
class AdapterProtocol(
|
||||
@@ -87,7 +89,7 @@ class AdapterProtocol(
|
||||
Relation_T,
|
||||
Column_T,
|
||||
Compiler_T,
|
||||
]
|
||||
],
|
||||
):
|
||||
AdapterSpecificConfigs: ClassVar[Type[AdapterConfig_T]]
|
||||
Column: ClassVar[Type[Column_T]]
|
||||
|
||||
@@ -7,9 +7,7 @@ import agate
|
||||
import dbt.clients.agate_helper
|
||||
import dbt.exceptions
|
||||
from dbt.adapters.base import BaseConnectionManager
|
||||
from dbt.contracts.connection import (
|
||||
Connection, ConnectionState, AdapterResponse
|
||||
)
|
||||
from dbt.contracts.connection import Connection, ConnectionState, AdapterResponse
|
||||
from dbt.logger import GLOBAL_LOGGER as logger
|
||||
from dbt import flags
|
||||
|
||||
@@ -23,11 +21,12 @@ class SQLConnectionManager(BaseConnectionManager):
|
||||
- get_response
|
||||
- open
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def cancel(self, connection: Connection):
|
||||
"""Cancel the given connection."""
|
||||
raise dbt.exceptions.NotImplementedException(
|
||||
'`cancel` is not implemented for this adapter!'
|
||||
"`cancel` is not implemented for this adapter!"
|
||||
)
|
||||
|
||||
def cancel_open(self) -> List[str]:
|
||||
@@ -41,8 +40,8 @@ class SQLConnectionManager(BaseConnectionManager):
|
||||
# if the connection failed, the handle will be None so we have
|
||||
# nothing to cancel.
|
||||
if (
|
||||
connection.handle is not None and
|
||||
connection.state == ConnectionState.OPEN
|
||||
connection.handle is not None
|
||||
and connection.state == ConnectionState.OPEN
|
||||
):
|
||||
self.cancel(connection)
|
||||
if connection.name is not None:
|
||||
@@ -54,23 +53,22 @@ class SQLConnectionManager(BaseConnectionManager):
|
||||
sql: str,
|
||||
auto_begin: bool = True,
|
||||
bindings: Optional[Any] = None,
|
||||
abridge_sql_log: bool = False
|
||||
abridge_sql_log: bool = False,
|
||||
) -> Tuple[Connection, Any]:
|
||||
connection = self.get_thread_connection()
|
||||
if auto_begin and connection.transaction_open is False:
|
||||
self.begin()
|
||||
|
||||
logger.debug('Using {} connection "{}".'
|
||||
.format(self.TYPE, connection.name))
|
||||
logger.debug('Using {} connection "{}".'.format(self.TYPE, connection.name))
|
||||
|
||||
with self.exception_handler(sql):
|
||||
if abridge_sql_log:
|
||||
log_sql = '{}...'.format(sql[:512])
|
||||
log_sql = "{}...".format(sql[:512])
|
||||
else:
|
||||
log_sql = sql
|
||||
|
||||
logger.debug(
|
||||
'On {connection_name}: {sql}',
|
||||
"On {connection_name}: {sql}",
|
||||
connection_name=connection.name,
|
||||
sql=log_sql,
|
||||
)
|
||||
@@ -81,7 +79,7 @@ class SQLConnectionManager(BaseConnectionManager):
|
||||
logger.debug(
|
||||
"SQL status: {status} in {elapsed:0.2f} seconds",
|
||||
status=self.get_response(cursor),
|
||||
elapsed=(time.time() - pre)
|
||||
elapsed=(time.time() - pre),
|
||||
)
|
||||
|
||||
return connection, cursor
|
||||
@@ -90,14 +88,12 @@ class SQLConnectionManager(BaseConnectionManager):
|
||||
def get_response(cls, cursor: Any) -> Union[AdapterResponse, str]:
|
||||
"""Get the status of the cursor."""
|
||||
raise dbt.exceptions.NotImplementedException(
|
||||
'`get_response` is not implemented for this adapter!'
|
||||
"`get_response` is not implemented for this adapter!"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def process_results(
|
||||
cls,
|
||||
column_names: Iterable[str],
|
||||
rows: Iterable[Any]
|
||||
cls, column_names: Iterable[str], rows: Iterable[Any]
|
||||
) -> List[Dict[str, Any]]:
|
||||
|
||||
return [dict(zip(column_names, row)) for row in rows]
|
||||
@@ -112,10 +108,7 @@ class SQLConnectionManager(BaseConnectionManager):
|
||||
rows = cursor.fetchall()
|
||||
data = cls.process_results(column_names, rows)
|
||||
|
||||
return dbt.clients.agate_helper.table_from_data_flat(
|
||||
data,
|
||||
column_names
|
||||
)
|
||||
return dbt.clients.agate_helper.table_from_data_flat(data, column_names)
|
||||
|
||||
def execute(
|
||||
self, sql: str, auto_begin: bool = False, fetch: bool = False
|
||||
@@ -130,10 +123,10 @@ class SQLConnectionManager(BaseConnectionManager):
|
||||
return response, table
|
||||
|
||||
def add_begin_query(self):
|
||||
return self.add_query('BEGIN', auto_begin=False)
|
||||
return self.add_query("BEGIN", auto_begin=False)
|
||||
|
||||
def add_commit_query(self):
|
||||
return self.add_query('COMMIT', auto_begin=False)
|
||||
return self.add_query("COMMIT", auto_begin=False)
|
||||
|
||||
def begin(self):
|
||||
connection = self.get_thread_connection()
|
||||
@@ -141,13 +134,14 @@ class SQLConnectionManager(BaseConnectionManager):
|
||||
if flags.STRICT_MODE:
|
||||
if not isinstance(connection, Connection):
|
||||
raise dbt.exceptions.CompilerException(
|
||||
f'In begin, got {connection} - not a Connection!'
|
||||
f"In begin, got {connection} - not a Connection!"
|
||||
)
|
||||
|
||||
if connection.transaction_open is True:
|
||||
raise dbt.exceptions.InternalException(
|
||||
'Tried to begin a new transaction on connection "{}", but '
|
||||
'it already had one open!'.format(connection.name))
|
||||
"it already had one open!".format(connection.name)
|
||||
)
|
||||
|
||||
self.add_begin_query()
|
||||
|
||||
@@ -159,15 +153,16 @@ class SQLConnectionManager(BaseConnectionManager):
|
||||
if flags.STRICT_MODE:
|
||||
if not isinstance(connection, Connection):
|
||||
raise dbt.exceptions.CompilerException(
|
||||
f'In commit, got {connection} - not a Connection!'
|
||||
f"In commit, got {connection} - not a Connection!"
|
||||
)
|
||||
|
||||
if connection.transaction_open is False:
|
||||
raise dbt.exceptions.InternalException(
|
||||
'Tried to commit transaction on connection "{}", but '
|
||||
'it does not have one open!'.format(connection.name))
|
||||
"it does not have one open!".format(connection.name)
|
||||
)
|
||||
|
||||
logger.debug('On {}: COMMIT'.format(connection.name))
|
||||
logger.debug("On {}: COMMIT".format(connection.name))
|
||||
self.add_commit_query()
|
||||
|
||||
connection.transaction_open = False
|
||||
|
||||
@@ -10,16 +10,16 @@ from dbt.logger import GLOBAL_LOGGER as logger
|
||||
|
||||
from dbt.adapters.base.relation import BaseRelation
|
||||
|
||||
LIST_RELATIONS_MACRO_NAME = 'list_relations_without_caching'
|
||||
GET_COLUMNS_IN_RELATION_MACRO_NAME = 'get_columns_in_relation'
|
||||
LIST_SCHEMAS_MACRO_NAME = 'list_schemas'
|
||||
CHECK_SCHEMA_EXISTS_MACRO_NAME = 'check_schema_exists'
|
||||
CREATE_SCHEMA_MACRO_NAME = 'create_schema'
|
||||
DROP_SCHEMA_MACRO_NAME = 'drop_schema'
|
||||
RENAME_RELATION_MACRO_NAME = 'rename_relation'
|
||||
TRUNCATE_RELATION_MACRO_NAME = 'truncate_relation'
|
||||
DROP_RELATION_MACRO_NAME = 'drop_relation'
|
||||
ALTER_COLUMN_TYPE_MACRO_NAME = 'alter_column_type'
|
||||
LIST_RELATIONS_MACRO_NAME = "list_relations_without_caching"
|
||||
GET_COLUMNS_IN_RELATION_MACRO_NAME = "get_columns_in_relation"
|
||||
LIST_SCHEMAS_MACRO_NAME = "list_schemas"
|
||||
CHECK_SCHEMA_EXISTS_MACRO_NAME = "check_schema_exists"
|
||||
CREATE_SCHEMA_MACRO_NAME = "create_schema"
|
||||
DROP_SCHEMA_MACRO_NAME = "drop_schema"
|
||||
RENAME_RELATION_MACRO_NAME = "rename_relation"
|
||||
TRUNCATE_RELATION_MACRO_NAME = "truncate_relation"
|
||||
DROP_RELATION_MACRO_NAME = "drop_relation"
|
||||
ALTER_COLUMN_TYPE_MACRO_NAME = "alter_column_type"
|
||||
|
||||
|
||||
class SQLAdapter(BaseAdapter):
|
||||
@@ -60,30 +60,23 @@ class SQLAdapter(BaseAdapter):
|
||||
:param abridge_sql_log: If set, limit the raw sql logged to 512
|
||||
characters
|
||||
"""
|
||||
return self.connections.add_query(sql, auto_begin, bindings,
|
||||
abridge_sql_log)
|
||||
return self.connections.add_query(sql, auto_begin, bindings, abridge_sql_log)
|
||||
|
||||
@classmethod
|
||||
def convert_text_type(cls, agate_table: agate.Table, col_idx: int) -> str:
|
||||
return "text"
|
||||
|
||||
@classmethod
|
||||
def convert_number_type(
|
||||
cls, agate_table: agate.Table, col_idx: int
|
||||
) -> str:
|
||||
decimals = agate_table.aggregate(agate.MaxPrecision(col_idx))
|
||||
def convert_number_type(cls, agate_table: agate.Table, col_idx: int) -> str:
|
||||
decimals = agate_table.aggregate(agate.MaxPrecision(col_idx)) # type: ignore
|
||||
return "float8" if decimals else "integer"
|
||||
|
||||
@classmethod
|
||||
def convert_boolean_type(
|
||||
cls, agate_table: agate.Table, col_idx: int
|
||||
) -> str:
|
||||
def convert_boolean_type(cls, agate_table: agate.Table, col_idx: int) -> str:
|
||||
return "boolean"
|
||||
|
||||
@classmethod
|
||||
def convert_datetime_type(
|
||||
cls, agate_table: agate.Table, col_idx: int
|
||||
) -> str:
|
||||
def convert_datetime_type(cls, agate_table: agate.Table, col_idx: int) -> str:
|
||||
return "timestamp without time zone"
|
||||
|
||||
@classmethod
|
||||
@@ -99,31 +92,28 @@ class SQLAdapter(BaseAdapter):
|
||||
return True
|
||||
|
||||
def expand_column_types(self, goal, current):
|
||||
reference_columns = {
|
||||
c.name: c for c in
|
||||
self.get_columns_in_relation(goal)
|
||||
}
|
||||
reference_columns = {c.name: c for c in self.get_columns_in_relation(goal)}
|
||||
|
||||
target_columns = {
|
||||
c.name: c for c
|
||||
in self.get_columns_in_relation(current)
|
||||
}
|
||||
target_columns = {c.name: c for c in self.get_columns_in_relation(current)}
|
||||
|
||||
for column_name, reference_column in reference_columns.items():
|
||||
target_column = target_columns.get(column_name)
|
||||
|
||||
if target_column is not None and \
|
||||
target_column.can_expand_to(reference_column):
|
||||
if target_column is not None and target_column.can_expand_to(
|
||||
reference_column
|
||||
):
|
||||
col_string_size = reference_column.string_size()
|
||||
new_type = self.Column.string_type(col_string_size)
|
||||
logger.debug("Changing col type from {} to {} in table {}",
|
||||
target_column.data_type, new_type, current)
|
||||
logger.debug(
|
||||
"Changing col type from {} to {} in table {}",
|
||||
target_column.data_type,
|
||||
new_type,
|
||||
current,
|
||||
)
|
||||
|
||||
self.alter_column_type(current, column_name, new_type)
|
||||
|
||||
def alter_column_type(
|
||||
self, relation, column_name, new_column_type
|
||||
) -> None:
|
||||
def alter_column_type(self, relation, column_name, new_column_type) -> None:
|
||||
"""
|
||||
1. Create a new column (w/ temp name and correct type)
|
||||
2. Copy data over to it
|
||||
@@ -131,53 +121,40 @@ class SQLAdapter(BaseAdapter):
|
||||
4. Rename the new column to existing column
|
||||
"""
|
||||
kwargs = {
|
||||
'relation': relation,
|
||||
'column_name': column_name,
|
||||
'new_column_type': new_column_type,
|
||||
"relation": relation,
|
||||
"column_name": column_name,
|
||||
"new_column_type": new_column_type,
|
||||
}
|
||||
self.execute_macro(
|
||||
ALTER_COLUMN_TYPE_MACRO_NAME,
|
||||
kwargs=kwargs
|
||||
)
|
||||
self.execute_macro(ALTER_COLUMN_TYPE_MACRO_NAME, kwargs=kwargs)
|
||||
|
||||
def drop_relation(self, relation):
|
||||
if relation.type is None:
|
||||
dbt.exceptions.raise_compiler_error(
|
||||
'Tried to drop relation {}, but its type is null.'
|
||||
.format(relation))
|
||||
"Tried to drop relation {}, but its type is null.".format(relation)
|
||||
)
|
||||
|
||||
self.cache_dropped(relation)
|
||||
self.execute_macro(
|
||||
DROP_RELATION_MACRO_NAME,
|
||||
kwargs={'relation': relation}
|
||||
)
|
||||
self.execute_macro(DROP_RELATION_MACRO_NAME, kwargs={"relation": relation})
|
||||
|
||||
def truncate_relation(self, relation):
|
||||
self.execute_macro(
|
||||
TRUNCATE_RELATION_MACRO_NAME,
|
||||
kwargs={'relation': relation}
|
||||
)
|
||||
self.execute_macro(TRUNCATE_RELATION_MACRO_NAME, kwargs={"relation": relation})
|
||||
|
||||
def rename_relation(self, from_relation, to_relation):
|
||||
self.cache_renamed(from_relation, to_relation)
|
||||
|
||||
kwargs = {'from_relation': from_relation, 'to_relation': to_relation}
|
||||
self.execute_macro(
|
||||
RENAME_RELATION_MACRO_NAME,
|
||||
kwargs=kwargs
|
||||
)
|
||||
kwargs = {"from_relation": from_relation, "to_relation": to_relation}
|
||||
self.execute_macro(RENAME_RELATION_MACRO_NAME, kwargs=kwargs)
|
||||
|
||||
def get_columns_in_relation(self, relation):
|
||||
return self.execute_macro(
|
||||
GET_COLUMNS_IN_RELATION_MACRO_NAME,
|
||||
kwargs={'relation': relation}
|
||||
GET_COLUMNS_IN_RELATION_MACRO_NAME, kwargs={"relation": relation}
|
||||
)
|
||||
|
||||
def create_schema(self, relation: BaseRelation) -> None:
|
||||
relation = relation.without_identifier()
|
||||
logger.debug('Creating schema "{}"', relation)
|
||||
kwargs = {
|
||||
'relation': relation,
|
||||
"relation": relation,
|
||||
}
|
||||
self.execute_macro(CREATE_SCHEMA_MACRO_NAME, kwargs=kwargs)
|
||||
self.commit_if_has_connection()
|
||||
@@ -188,39 +165,35 @@ class SQLAdapter(BaseAdapter):
|
||||
relation = relation.without_identifier()
|
||||
logger.debug('Dropping schema "{}".', relation)
|
||||
kwargs = {
|
||||
'relation': relation,
|
||||
"relation": relation,
|
||||
}
|
||||
self.execute_macro(DROP_SCHEMA_MACRO_NAME, kwargs=kwargs)
|
||||
# we can update the cache here
|
||||
self.cache.drop_schema(relation.database, relation.schema)
|
||||
|
||||
def list_relations_without_caching(
|
||||
self, schema_relation: BaseRelation,
|
||||
self,
|
||||
schema_relation: BaseRelation,
|
||||
) -> List[BaseRelation]:
|
||||
kwargs = {'schema_relation': schema_relation}
|
||||
results = self.execute_macro(
|
||||
LIST_RELATIONS_MACRO_NAME,
|
||||
kwargs=kwargs
|
||||
)
|
||||
kwargs = {"schema_relation": schema_relation}
|
||||
results = self.execute_macro(LIST_RELATIONS_MACRO_NAME, kwargs=kwargs)
|
||||
|
||||
relations = []
|
||||
quote_policy = {
|
||||
'database': True,
|
||||
'schema': True,
|
||||
'identifier': True
|
||||
}
|
||||
quote_policy = {"database": True, "schema": True, "identifier": True}
|
||||
for _database, name, _schema, _type in results:
|
||||
try:
|
||||
_type = self.Relation.get_relation_type(_type)
|
||||
except ValueError:
|
||||
_type = self.Relation.External
|
||||
relations.append(self.Relation.create(
|
||||
relations.append(
|
||||
self.Relation.create(
|
||||
database=_database,
|
||||
schema=_schema,
|
||||
identifier=name,
|
||||
quote_policy=quote_policy,
|
||||
type=_type
|
||||
))
|
||||
type=_type,
|
||||
)
|
||||
)
|
||||
return relations
|
||||
|
||||
def quote(self, identifier):
|
||||
@@ -228,8 +201,7 @@ class SQLAdapter(BaseAdapter):
|
||||
|
||||
def list_schemas(self, database: str) -> List[str]:
|
||||
results = self.execute_macro(
|
||||
LIST_SCHEMAS_MACRO_NAME,
|
||||
kwargs={'database': database}
|
||||
LIST_SCHEMAS_MACRO_NAME, kwargs={"database": database}
|
||||
)
|
||||
|
||||
return [row[0] for row in results]
|
||||
@@ -238,13 +210,10 @@ class SQLAdapter(BaseAdapter):
|
||||
information_schema = self.Relation.create(
|
||||
database=database,
|
||||
schema=schema,
|
||||
identifier='INFORMATION_SCHEMA',
|
||||
quote_policy=self.config.quoting
|
||||
identifier="INFORMATION_SCHEMA",
|
||||
quote_policy=self.config.quoting,
|
||||
).information_schema()
|
||||
|
||||
kwargs = {'information_schema': information_schema, 'schema': schema}
|
||||
results = self.execute_macro(
|
||||
CHECK_SCHEMA_EXISTS_MACRO_NAME,
|
||||
kwargs=kwargs
|
||||
)
|
||||
kwargs = {"information_schema": information_schema, "schema": schema}
|
||||
results = self.execute_macro(CHECK_SCHEMA_EXISTS_MACRO_NAME, kwargs=kwargs)
|
||||
return results[0][0] > 0
|
||||
|
||||
@@ -10,79 +10,89 @@ def regex(pat):
|
||||
|
||||
class BlockData:
|
||||
"""raw plaintext data from the top level of the file."""
|
||||
|
||||
def __init__(self, contents):
|
||||
self.block_type_name = '__dbt__data'
|
||||
self.block_type_name = "__dbt__data"
|
||||
self.contents = contents
|
||||
self.full_block = contents
|
||||
|
||||
|
||||
class BlockTag:
|
||||
def __init__(self, block_type_name, block_name, contents=None,
|
||||
full_block=None, **kw):
|
||||
def __init__(
|
||||
self, block_type_name, block_name, contents=None, full_block=None, **kw
|
||||
):
|
||||
self.block_type_name = block_type_name
|
||||
self.block_name = block_name
|
||||
self.contents = contents
|
||||
self.full_block = full_block
|
||||
|
||||
def __str__(self):
|
||||
return 'BlockTag({!r}, {!r})'.format(self.block_type_name,
|
||||
self.block_name)
|
||||
return "BlockTag({!r}, {!r})".format(self.block_type_name, self.block_name)
|
||||
|
||||
def __repr__(self):
|
||||
return str(self)
|
||||
|
||||
@property
|
||||
def end_block_type_name(self):
|
||||
return 'end{}'.format(self.block_type_name)
|
||||
return "end{}".format(self.block_type_name)
|
||||
|
||||
def end_pat(self):
|
||||
# we don't want to use string formatting here because jinja uses most
|
||||
# of the string formatting operators in its syntax...
|
||||
pattern = ''.join((
|
||||
r'(?P<endblock>((?:\s*\{\%\-|\{\%)\s*',
|
||||
pattern = "".join(
|
||||
(
|
||||
r"(?P<endblock>((?:\s*\{\%\-|\{\%)\s*",
|
||||
self.end_block_type_name,
|
||||
r'\s*(?:\-\%\}\s*|\%\})))',
|
||||
))
|
||||
r"\s*(?:\-\%\}\s*|\%\})))",
|
||||
)
|
||||
)
|
||||
return regex(pattern)
|
||||
|
||||
|
||||
Tag = namedtuple('Tag', 'block_type_name block_name start end')
|
||||
Tag = namedtuple("Tag", "block_type_name block_name start end")
|
||||
|
||||
|
||||
_NAME_PATTERN = r'[A-Za-z_][A-Za-z_0-9]*'
|
||||
_NAME_PATTERN = r"[A-Za-z_][A-Za-z_0-9]*"
|
||||
|
||||
COMMENT_START_PATTERN = regex(r'(?:(?P<comment_start>(\s*\{\#)))')
|
||||
COMMENT_END_PATTERN = regex(r'(.*?)(\s*\#\})')
|
||||
COMMENT_START_PATTERN = regex(r"(?:(?P<comment_start>(\s*\{\#)))")
|
||||
COMMENT_END_PATTERN = regex(r"(.*?)(\s*\#\})")
|
||||
RAW_START_PATTERN = regex(
|
||||
r'(?:\s*\{\%\-|\{\%)\s*(?P<raw_start>(raw))\s*(?:\-\%\}\s*|\%\})'
|
||||
r"(?:\s*\{\%\-|\{\%)\s*(?P<raw_start>(raw))\s*(?:\-\%\}\s*|\%\})"
|
||||
)
|
||||
EXPR_START_PATTERN = regex(r'(?P<expr_start>(\{\{\s*))')
|
||||
EXPR_END_PATTERN = regex(r'(?P<expr_end>(\s*\}\}))')
|
||||
EXPR_START_PATTERN = regex(r"(?P<expr_start>(\{\{\s*))")
|
||||
EXPR_END_PATTERN = regex(r"(?P<expr_end>(\s*\}\}))")
|
||||
|
||||
BLOCK_START_PATTERN = regex(''.join((
|
||||
r'(?:\s*\{\%\-|\{\%)\s*',
|
||||
r'(?P<block_type_name>({}))'.format(_NAME_PATTERN),
|
||||
BLOCK_START_PATTERN = regex(
|
||||
"".join(
|
||||
(
|
||||
r"(?:\s*\{\%\-|\{\%)\s*",
|
||||
r"(?P<block_type_name>({}))".format(_NAME_PATTERN),
|
||||
# some blocks have a 'block name'.
|
||||
r'(?:\s+(?P<block_name>({})))?'.format(_NAME_PATTERN),
|
||||
)))
|
||||
r"(?:\s+(?P<block_name>({})))?".format(_NAME_PATTERN),
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
RAW_BLOCK_PATTERN = regex(''.join((
|
||||
r'(?:\s*\{\%\-|\{\%)\s*raw\s*(?:\-\%\}\s*|\%\})',
|
||||
r'(?:.*?)',
|
||||
r'(?:\s*\{\%\-|\{\%)\s*endraw\s*(?:\-\%\}\s*|\%\})',
|
||||
)))
|
||||
RAW_BLOCK_PATTERN = regex(
|
||||
"".join(
|
||||
(
|
||||
r"(?:\s*\{\%\-|\{\%)\s*raw\s*(?:\-\%\}\s*|\%\})",
|
||||
r"(?:.*?)",
|
||||
r"(?:\s*\{\%\-|\{\%)\s*endraw\s*(?:\-\%\}\s*|\%\})",
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
TAG_CLOSE_PATTERN = regex(r'(?:(?P<tag_close>(\-\%\}\s*|\%\})))')
|
||||
TAG_CLOSE_PATTERN = regex(r"(?:(?P<tag_close>(\-\%\}\s*|\%\})))")
|
||||
|
||||
# stolen from jinja's lexer. Note that we've consumed all prefix whitespace by
|
||||
# the time we want to use this.
|
||||
STRING_PATTERN = regex(
|
||||
r"(?P<string>('([^'\\]*(?:\\.[^'\\]*)*)'|"
|
||||
r'"([^"\\]*(?:\\.[^"\\]*)*)"))'
|
||||
r"(?P<string>('([^'\\]*(?:\\.[^'\\]*)*)'|" r'"([^"\\]*(?:\\.[^"\\]*)*)"))'
|
||||
)
|
||||
|
||||
QUOTE_START_PATTERN = regex(r'''(?P<quote>(['"]))''')
|
||||
QUOTE_START_PATTERN = regex(r"""(?P<quote>(['"]))""")
|
||||
|
||||
|
||||
class TagIterator:
@@ -99,10 +109,10 @@ class TagIterator:
end_val: int = self.pos if end is None else end
data = self.data[:end_val]
# if not found, rfind returns -1, and -1+1=0, which is perfect!
last_line_start = data.rfind('\n') + 1
last_line_start = data.rfind("\n") + 1
# it's easy to forget this, but line numbers are 1-indexed
line_number = data.count('\n') + 1
return f'{line_number}:{end_val - last_line_start}'
line_number = data.count("\n") + 1
return f"{line_number}:{end_val - last_line_start}"
def advance(self, new_position):
|
||||
self.pos = new_position
|
||||
@@ -120,7 +130,7 @@ class TagIterator:
|
||||
matches = []
|
||||
for pattern in patterns:
|
||||
# default to 'search', but sometimes we want to 'match'.
|
||||
if kwargs.get('method', 'search') == 'search':
|
||||
if kwargs.get("method", "search") == "search":
|
||||
match = self._search(pattern)
|
||||
else:
|
||||
match = self._match(pattern)
|
||||
@@ -156,22 +166,20 @@ class TagIterator:
|
||||
"""
|
||||
self.advance(match.end())
|
||||
while True:
|
||||
match = self._expect_match('}}',
|
||||
EXPR_END_PATTERN,
|
||||
QUOTE_START_PATTERN)
|
||||
if match.groupdict().get('expr_end') is not None:
|
||||
match = self._expect_match("}}", EXPR_END_PATTERN, QUOTE_START_PATTERN)
|
||||
if match.groupdict().get("expr_end") is not None:
|
||||
break
|
||||
else:
|
||||
# it's a quote. we haven't advanced for this match yet, so
|
||||
# just slurp up the whole string, no need to rewind.
|
||||
match = self._expect_match('string', STRING_PATTERN)
|
||||
match = self._expect_match("string", STRING_PATTERN)
|
||||
self.advance(match.end())
|
||||
|
||||
self.advance(match.end())
|
||||
|
||||
def handle_comment(self, match):
|
||||
self.advance(match.end())
|
||||
match = self._expect_match('#}', COMMENT_END_PATTERN)
|
||||
match = self._expect_match("#}", COMMENT_END_PATTERN)
|
||||
self.advance(match.end())
|
||||
|
||||
def _expect_block_close(self):
|
||||
@@ -188,22 +196,19 @@ class TagIterator:
|
||||
"""
|
||||
while True:
|
||||
end_match = self._expect_match(
|
||||
'tag close ("%}")',
|
||||
QUOTE_START_PATTERN,
|
||||
TAG_CLOSE_PATTERN
|
||||
'tag close ("%}")', QUOTE_START_PATTERN, TAG_CLOSE_PATTERN
|
||||
)
|
||||
self.advance(end_match.end())
|
||||
if end_match.groupdict().get('tag_close') is not None:
|
||||
if end_match.groupdict().get("tag_close") is not None:
|
||||
return
|
||||
# must be a string. Rewind to its start and advance past it.
|
||||
self.rewind()
|
||||
string_match = self._expect_match('string', STRING_PATTERN)
|
||||
string_match = self._expect_match("string", STRING_PATTERN)
|
||||
self.advance(string_match.end())
|
||||
|
||||
def handle_raw(self):
|
||||
# raw blocks are super special, they are a single complete regex
|
||||
match = self._expect_match('{% raw %}...{% endraw %}',
|
||||
RAW_BLOCK_PATTERN)
|
||||
match = self._expect_match("{% raw %}...{% endraw %}", RAW_BLOCK_PATTERN)
|
||||
self.advance(match.end())
|
||||
return match.end()
|
||||
|
||||
@@ -220,13 +225,12 @@ class TagIterator:
|
||||
"""
|
||||
groups = match.groupdict()
|
||||
# always a value
|
||||
block_type_name = groups['block_type_name']
|
||||
block_type_name = groups["block_type_name"]
|
||||
# might be None
|
||||
block_name = groups.get('block_name')
|
||||
block_name = groups.get("block_name")
|
||||
start_pos = self.pos
|
||||
if block_type_name == 'raw':
|
||||
match = self._expect_match('{% raw %}...{% endraw %}',
|
||||
RAW_BLOCK_PATTERN)
|
||||
if block_type_name == "raw":
|
||||
match = self._expect_match("{% raw %}...{% endraw %}", RAW_BLOCK_PATTERN)
|
||||
self.advance(match.end())
|
||||
else:
|
||||
self.advance(match.end())
|
||||
@@ -235,15 +239,13 @@ class TagIterator:
|
||||
block_type_name=block_type_name,
|
||||
block_name=block_name,
|
||||
start=start_pos,
|
||||
end=self.pos
|
||||
end=self.pos,
|
||||
)
|
||||
|
||||
def find_tags(self):
|
||||
while True:
|
||||
match = self._first_match(
|
||||
BLOCK_START_PATTERN,
|
||||
COMMENT_START_PATTERN,
|
||||
EXPR_START_PATTERN
|
||||
BLOCK_START_PATTERN, COMMENT_START_PATTERN, EXPR_START_PATTERN
|
||||
)
|
||||
if match is None:
|
||||
break
|
||||
@@ -252,9 +254,9 @@ class TagIterator:
|
||||
# start = self.pos
|
||||
|
||||
groups = match.groupdict()
|
||||
comment_start = groups.get('comment_start')
|
||||
expr_start = groups.get('expr_start')
|
||||
block_type_name = groups.get('block_type_name')
|
||||
comment_start = groups.get("comment_start")
|
||||
expr_start = groups.get("expr_start")
|
||||
block_type_name = groups.get("block_type_name")
|
||||
|
||||
if comment_start is not None:
|
||||
self.handle_comment(match)
|
||||
@@ -264,8 +266,8 @@ class TagIterator:
|
||||
yield self.handle_tag(match)
|
||||
else:
|
||||
raise dbt.exceptions.InternalException(
|
||||
'Invalid regex match in next_block, expected block start, '
|
||||
'expr start, or comment start'
|
||||
"Invalid regex match in next_block, expected block start, "
|
||||
"expr start, or comment start"
|
||||
)
|
||||
|
||||
def __iter__(self):
|
||||
@@ -273,21 +275,18 @@ class TagIterator:


duplicate_tags = (
'Got nested tags: {outer.block_type_name} (started at {outer.start}) did '
'not have a matching {{% end{outer.block_type_name} %}} before a '
'subsequent {inner.block_type_name} was found (started at {inner.start})'
"Got nested tags: {outer.block_type_name} (started at {outer.start}) did "
"not have a matching {{% end{outer.block_type_name} %}} before a "
"subsequent {inner.block_type_name} was found (started at {inner.start})"
)


_CONTROL_FLOW_TAGS = {
'if': 'endif',
'for': 'endfor',
"if": "endif",
"for": "endfor",
}

_CONTROL_FLOW_END_TAGS = {
v: k
for k, v in _CONTROL_FLOW_TAGS.items()
}
_CONTROL_FLOW_END_TAGS = {v: k for k, v in _CONTROL_FLOW_TAGS.items()}


class BlockIterator:
@@ -310,15 +309,15 @@ class BlockIterator:

def is_current_end(self, tag):
return (
tag.block_type_name.startswith('end') and
self.current is not None and
tag.block_type_name[3:] == self.current.block_type_name
tag.block_type_name.startswith("end")
and self.current is not None
and tag.block_type_name[3:] == self.current.block_type_name
)

def find_blocks(self, allowed_blocks=None, collect_raw_data=True):
"""Find all top-level blocks in the data."""
if allowed_blocks is None:
allowed_blocks = {'snapshot', 'macro', 'materialization', 'docs'}
allowed_blocks = {"snapshot", "macro", "materialization", "docs"}

for tag in self.tag_parser.find_tags():
if tag.block_type_name in _CONTROL_FLOW_TAGS:
@@ -329,31 +328,37 @@ class BlockIterator:
found = self.stack.pop()
else:
expected = _CONTROL_FLOW_END_TAGS[tag.block_type_name]
dbt.exceptions.raise_compiler_error((
'Got an unexpected control flow end tag, got {} but '
'never saw a preceeding {} (@ {})'
dbt.exceptions.raise_compiler_error(
(
"Got an unexpected control flow end tag, got {} but "
"never saw a preceeding {} (@ {})"
).format(
tag.block_type_name,
expected,
self.tag_parser.linepos(tag.start)
))
self.tag_parser.linepos(tag.start),
)
)
expected = _CONTROL_FLOW_TAGS[found]
if expected != tag.block_type_name:
dbt.exceptions.raise_compiler_error((
'Got an unexpected control flow end tag, got {} but '
'expected {} next (@ {})'
dbt.exceptions.raise_compiler_error(
(
"Got an unexpected control flow end tag, got {} but "
"expected {} next (@ {})"
).format(
tag.block_type_name,
expected,
self.tag_parser.linepos(tag.start)
))
self.tag_parser.linepos(tag.start),
)
)

if tag.block_type_name in allowed_blocks:
if self.stack:
dbt.exceptions.raise_compiler_error((
'Got a block definition inside control flow at {}. '
'All dbt block definitions must be at the top level'
).format(self.tag_parser.linepos(tag.start)))
dbt.exceptions.raise_compiler_error(
(
"Got a block definition inside control flow at {}. "
"All dbt block definitions must be at the top level"
).format(self.tag_parser.linepos(tag.start))
)
if self.current is not None:
dbt.exceptions.raise_compiler_error(
duplicate_tags.format(outer=self.current, inner=tag)
@@ -372,16 +377,18 @@ class BlockIterator:
block_type_name=self.current.block_type_name,
block_name=self.current.block_name,
contents=self.data[self.current.end : tag.start],
full_block=self.data[self.current.start:tag.end]
full_block=self.data[self.current.start : tag.end],
)
self.current = None

if self.current:
linecount = self.data[:self.current.end].count('\n') + 1
dbt.exceptions.raise_compiler_error((
'Reached EOF without finding a close tag for '
'{} (searched from line {})'
).format(self.current.block_type_name, linecount))
linecount = self.data[: self.current.end].count("\n") + 1
dbt.exceptions.raise_compiler_error(
(
"Reached EOF without finding a close tag for "
"{} (searched from line {})"
).format(self.current.block_type_name, linecount)
)

if collect_raw_data:
raw_data = self.data[self.last_position :]
@@ -389,5 +396,8 @@ class BlockIterator:
yield BlockData(raw_data)

def lex_for_blocks(self, allowed_blocks=None, collect_raw_data=True):
return list(self.find_blocks(allowed_blocks=allowed_blocks,
collect_raw_data=collect_raw_data))
return list(
self.find_blocks(
allowed_blocks=allowed_blocks, collect_raw_data=collect_raw_data
)
)
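For orientation, here is a minimal, hedged sketch of driving the block lexer touched in the hunks above; the sample template string is invented, and only the `BlockIterator(data).lex_for_blocks(...)` call and the `BlockTag` attributes visible in this diff are relied on:

```
from dbt.clients._jinja_blocks import BlockIterator, BlockTag

# Invented template: one top-level macro block followed by loose text.
data = "{% macro my_macro() %}select 1{% endmacro %}\nselect 2"

blocks = BlockIterator(data).lex_for_blocks(
    allowed_blocks={"snapshot", "macro", "materialization", "docs"},
    collect_raw_data=True,
)
for block in blocks:
    if isinstance(block, BlockTag):
        # a top-level dbt block, e.g. ("macro", "my_macro")
        print(block.block_type_name, block.block_name)
    else:
        # BlockData segments are the raw text between top-level blocks
        print("raw segment")
```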
@@ -10,7 +10,7 @@ from typing import Iterable, List, Dict, Union, Optional, Any
from dbt.exceptions import RuntimeException


BOM = BOM_UTF8.decode('utf-8') # '\ufeff'
BOM = BOM_UTF8.decode("utf-8") # '\ufeff'


class ISODateTime(agate.data_types.DateTime):
@@ -30,28 +30,23 @@ class ISODateTime(agate.data_types.DateTime):
except: # noqa
pass

raise agate.exceptions.CastError(
'Can not parse value "%s" as datetime.' % d
)
raise agate.exceptions.CastError('Can not parse value "%s" as datetime.' % d)


def build_type_tester(text_columns: Iterable[str]) -> agate.TypeTester:
types = [
agate.data_types.Number(null_values=('null', '')),
agate.data_types.Date(null_values=('null', ''),
date_format='%Y-%m-%d'),
agate.data_types.DateTime(null_values=('null', ''),
datetime_format='%Y-%m-%d %H:%M:%S'),
ISODateTime(null_values=('null', '')),
agate.data_types.Boolean(true_values=('true',),
false_values=('false',),
null_values=('null', '')),
agate.data_types.Text(null_values=('null', ''))
agate.data_types.Number(null_values=("null", "")),
agate.data_types.Date(null_values=("null", ""), date_format="%Y-%m-%d"),
agate.data_types.DateTime(
null_values=("null", ""), datetime_format="%Y-%m-%d %H:%M:%S"
),
ISODateTime(null_values=("null", "")),
agate.data_types.Boolean(
true_values=("true",), false_values=("false",), null_values=("null", "")
),
agate.data_types.Text(null_values=("null", "")),
]
force = {
k: agate.data_types.Text(null_values=('null', ''))
for k in text_columns
}
force = {k: agate.data_types.Text(null_values=("null", "")) for k in text_columns}
return agate.TypeTester(force=force, types=types)
@@ -115,7 +110,7 @@ def as_matrix(table):

def from_csv(abspath, text_columns):
type_tester = build_type_tester(text_columns=text_columns)
with open(abspath, encoding='utf-8') as fp:
with open(abspath, encoding="utf-8") as fp:
if fp.read(1) != BOM:
fp.seek(0)
return agate.Table.from_csv(fp, column_types=type_tester)
@@ -147,8 +142,8 @@ class ColumnTypeBuilder(Dict[str, NullableAgateType]):
elif not isinstance(value, type(existing_type)):
# actual type mismatch!
raise RuntimeException(
f'Tables contain columns with the same names ({key}), '
f'but different types ({value} vs {existing_type})'
f"Tables contain columns with the same names ({key}), "
f"but different types ({value} vs {existing_type})"
)

def finalize(self) -> Dict[str, agate.data_types.DataType]:
@@ -163,7 +158,7 @@ class ColumnTypeBuilder(Dict[str, NullableAgateType]):


def _merged_column_types(
tables: List[agate.Table]
tables: List[agate.Table],
) -> Dict[str, agate.data_types.DataType]:
# this is a lot like agate.Table.merge, but with handling for all-null
# rows being "any type".
@@ -190,10 +185,7 @@ def merge_tables(tables: List[agate.Table]) -> agate.Table:

rows: List[agate.Row] = []
for table in tables:
if (
table.column_names == column_names and
table.column_types == column_types
):
if table.column_names == column_names and table.column_types == column_types:
rows.extend(table.rows)
else:
for row in table.rows:
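A short usage sketch of the CSV helpers above; the module path `dbt.clients.agate_helper`, the file name, and the column names are assumptions for illustration:

```
from dbt.clients.agate_helper import from_csv  # module path assumed

# Load a seed-style CSV; the listed columns are forced to agate Text, while the
# remaining columns go through the TypeTester built above.
table = from_csv("data/orders.csv", text_columns=["order_id", "zip_code"])
print(table.column_names)
print(table.column_types)
```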
@@ -12,7 +12,7 @@ https://cloud.google.com/sdk/

def gcloud_installed():
try:
run_cmd('.', ['gcloud', '--version'])
run_cmd(".", ["gcloud", "--version"])
return True
except OSError as e:
logger.debug(e)
@@ -21,6 +21,6 @@ def gcloud_installed():

def setup_default_credentials():
if gcloud_installed():
run_cmd('.', ["gcloud", "auth", "application-default", "login"])
run_cmd(".", ["gcloud", "auth", "application-default", "login"])
else:
raise dbt.exceptions.RuntimeException(NOT_INSTALLED_MSG)
@@ -7,77 +7,74 @@ import dbt.exceptions


def clone(repo, cwd, dirname=None, remove_git_dir=False, branch=None):
clone_cmd = ['git', 'clone', '--depth', '1']
clone_cmd = ["git", "clone", "--depth", "1"]

if branch is not None:
clone_cmd.extend(['--branch', branch])
clone_cmd.extend(["--branch", branch])

clone_cmd.append(repo)

if dirname is not None:
clone_cmd.append(dirname)

result = run_cmd(cwd, clone_cmd, env={'LC_ALL': 'C'})
result = run_cmd(cwd, clone_cmd, env={"LC_ALL": "C"})

if remove_git_dir:
rmdir(os.path.join(dirname, '.git'))
rmdir(os.path.join(dirname, ".git"))

return result


def list_tags(cwd):
out, err = run_cmd(cwd, ['git', 'tag', '--list'], env={'LC_ALL': 'C'})
tags = out.decode('utf-8').strip().split("\n")
out, err = run_cmd(cwd, ["git", "tag", "--list"], env={"LC_ALL": "C"})
tags = out.decode("utf-8").strip().split("\n")
return tags


def _checkout(cwd, repo, branch):
logger.debug(' Checking out branch {}.'.format(branch))
logger.debug(" Checking out branch {}.".format(branch))

run_cmd(cwd, ['git', 'remote', 'set-branches', 'origin', branch])
run_cmd(cwd, ['git', 'fetch', '--tags', '--depth', '1', 'origin', branch])
run_cmd(cwd, ["git", "remote", "set-branches", "origin", branch])
run_cmd(cwd, ["git", "fetch", "--tags", "--depth", "1", "origin", branch])

tags = list_tags(cwd)

# Prefer tags to branches if one exists
if branch in tags:
spec = 'tags/{}'.format(branch)
spec = "tags/{}".format(branch)
else:
spec = 'origin/{}'.format(branch)
spec = "origin/{}".format(branch)

out, err = run_cmd(cwd, ['git', 'reset', '--hard', spec],
env={'LC_ALL': 'C'})
out, err = run_cmd(cwd, ["git", "reset", "--hard", spec], env={"LC_ALL": "C"})
return out, err


def checkout(cwd, repo, branch=None):
if branch is None:
branch = 'HEAD'
branch = "HEAD"
try:
return _checkout(cwd, repo, branch)
except dbt.exceptions.CommandResultError as exc:
stderr = exc.stderr.decode('utf-8').strip()
stderr = exc.stderr.decode("utf-8").strip()
dbt.exceptions.bad_package_spec(repo, branch, stderr)


def get_current_sha(cwd):
out, err = run_cmd(cwd, ['git', 'rev-parse', 'HEAD'], env={'LC_ALL': 'C'})
out, err = run_cmd(cwd, ["git", "rev-parse", "HEAD"], env={"LC_ALL": "C"})

return out.decode('utf-8')
return out.decode("utf-8")


def remove_remote(cwd):
return run_cmd(cwd, ['git', 'remote', 'rm', 'origin'], env={'LC_ALL': 'C'})
return run_cmd(cwd, ["git", "remote", "rm", "origin"], env={"LC_ALL": "C"})


def clone_and_checkout(repo, cwd, dirname=None, remove_git_dir=False,
branch=None):
def clone_and_checkout(repo, cwd, dirname=None, remove_git_dir=False, branch=None):
exists = None
try:
_, err = clone(repo, cwd, dirname=dirname,
remove_git_dir=remove_git_dir)
_, err = clone(repo, cwd, dirname=dirname, remove_git_dir=remove_git_dir)
except dbt.exceptions.CommandResultError as exc:
err = exc.stderr.decode('utf-8')
err = exc.stderr.decode("utf-8")
exists = re.match("fatal: destination path '(.+)' already exists", err)
if not exists: # something else is wrong, raise it
raise
@@ -86,25 +83,26 @@ def clone_and_checkout(repo, cwd, dirname=None, remove_git_dir=False,
start_sha = None
if exists:
directory = exists.group(1)
logger.debug('Updating existing dependency {}.', directory)
logger.debug("Updating existing dependency {}.", directory)
else:
matches = re.match("Cloning into '(.+)'", err.decode('utf-8'))
matches = re.match("Cloning into '(.+)'", err.decode("utf-8"))
if matches is None:
raise dbt.exceptions.RuntimeException(
f'Error cloning {repo} - never saw "Cloning into ..." from git'
)
directory = matches.group(1)
logger.debug('Pulling new dependency {}.', directory)
logger.debug("Pulling new dependency {}.", directory)
full_path = os.path.join(cwd, directory)
start_sha = get_current_sha(full_path)
checkout(full_path, repo, branch)
end_sha = get_current_sha(full_path)
if exists:
if start_sha == end_sha:
logger.debug(' Already at {}, nothing to do.', start_sha[:7])
logger.debug(" Already at {}, nothing to do.", start_sha[:7])
else:
logger.debug(' Updated checkout from {} to {}.',
start_sha[:7], end_sha[:7])
logger.debug(
" Updated checkout from {} to {}.", start_sha[:7], end_sha[:7]
)
else:
logger.debug(' Checked out at {}.', end_sha[:7])
logger.debug(" Checked out at {}.", end_sha[:7])
return directory
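As a hedged illustration, the git helpers above might be driven like this when fetching a package; the repository URL, target directory, and revision are placeholders:

```
from dbt.clients.git import clone_and_checkout  # module path assumed

# Shallow-clone (or update) the package repo under ./dbt_modules and check out
# the requested tag or branch; the directory name reported by git is returned.
directory = clone_and_checkout(
    repo="https://github.com/fishtown-analytics/dbt-utils.git",
    cwd="dbt_modules",
    branch="0.6.4",
)
print(directory)
```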
@@ -8,8 +8,17 @@ from ast import literal_eval
|
||||
from contextlib import contextmanager
|
||||
from itertools import chain, islice
|
||||
from typing import (
|
||||
List, Union, Set, Optional, Dict, Any, Iterator, Type, NoReturn, Tuple,
|
||||
Callable
|
||||
List,
|
||||
Union,
|
||||
Set,
|
||||
Optional,
|
||||
Dict,
|
||||
Any,
|
||||
Iterator,
|
||||
Type,
|
||||
NoReturn,
|
||||
Tuple,
|
||||
Callable,
|
||||
)
|
||||
|
||||
import jinja2
|
||||
@@ -20,16 +29,22 @@ import jinja2.parser
|
||||
import jinja2.sandbox
|
||||
|
||||
from dbt.utils import (
|
||||
get_dbt_macro_name, get_docs_macro_name, get_materialization_macro_name,
|
||||
deep_map
|
||||
get_dbt_macro_name,
|
||||
get_docs_macro_name,
|
||||
get_materialization_macro_name,
|
||||
deep_map,
|
||||
)
|
||||
|
||||
from dbt.clients._jinja_blocks import BlockIterator, BlockData, BlockTag
|
||||
from dbt.contracts.graph.compiled import CompiledSchemaTestNode
|
||||
from dbt.contracts.graph.parsed import ParsedSchemaTestNode
|
||||
from dbt.exceptions import (
|
||||
InternalException, raise_compiler_error, CompilationException,
|
||||
invalid_materialization_argument, MacroReturn, JinjaRenderingException
|
||||
InternalException,
|
||||
raise_compiler_error,
|
||||
CompilationException,
|
||||
invalid_materialization_argument,
|
||||
MacroReturn,
|
||||
JinjaRenderingException,
|
||||
)
|
||||
from dbt import flags
|
||||
from dbt.logger import GLOBAL_LOGGER as logger # noqa
|
||||
@@ -40,26 +55,26 @@ def _linecache_inject(source, write):
|
||||
# this is the only reliable way to accomplish this. Obviously, it's
|
||||
# really darn noisy and will fill your temporary directory
|
||||
tmp_file = tempfile.NamedTemporaryFile(
|
||||
prefix='dbt-macro-compiled-',
|
||||
suffix='.py',
|
||||
prefix="dbt-macro-compiled-",
|
||||
suffix=".py",
|
||||
delete=False,
|
||||
mode='w+',
|
||||
encoding='utf-8',
|
||||
mode="w+",
|
||||
encoding="utf-8",
|
||||
)
|
||||
tmp_file.write(source)
|
||||
filename = tmp_file.name
|
||||
else:
|
||||
# `codecs.encode` actually takes a `bytes` as the first argument if
|
||||
# the second argument is 'hex' - mypy does not know this.
|
||||
rnd = codecs.encode(os.urandom(12), 'hex') # type: ignore
|
||||
filename = rnd.decode('ascii')
|
||||
rnd = codecs.encode(os.urandom(12), "hex") # type: ignore
|
||||
filename = rnd.decode("ascii")
|
||||
|
||||
# put ourselves in the cache
|
||||
cache_entry = (
|
||||
len(source),
|
||||
None,
|
||||
[line + '\n' for line in source.splitlines()],
|
||||
filename
|
||||
[line + "\n" for line in source.splitlines()],
|
||||
filename,
|
||||
)
|
||||
# linecache does in fact have an attribute `cache`, thanks
|
||||
linecache.cache[filename] = cache_entry # type: ignore
|
||||
@@ -73,12 +88,10 @@ class MacroFuzzParser(jinja2.parser.Parser):
|
||||
# modified to fuzz macros defined in the same file. this way
|
||||
# dbt can understand the stack of macros being called.
|
||||
# - @cmcarthur
|
||||
node.name = get_dbt_macro_name(
|
||||
self.parse_assign_target(name_only=True).name)
|
||||
node.name = get_dbt_macro_name(self.parse_assign_target(name_only=True).name)
|
||||
|
||||
self.parse_signature(node)
|
||||
node.body = self.parse_statements(('name:endmacro',),
|
||||
drop_needle=True)
|
||||
node.body = self.parse_statements(("name:endmacro",), drop_needle=True)
|
||||
return node
|
||||
|
||||
|
||||
@@ -94,8 +107,8 @@ class MacroFuzzEnvironment(jinja2.sandbox.SandboxedEnvironment):
|
||||
If the value is 'write', also write the files to disk.
|
||||
WARNING: This can write a ton of data if you aren't careful.
|
||||
"""
|
||||
if filename == '<template>' and flags.MACRO_DEBUGGING:
|
||||
write = flags.MACRO_DEBUGGING == 'write'
|
||||
if filename == "<template>" and flags.MACRO_DEBUGGING:
|
||||
write = flags.MACRO_DEBUGGING == "write"
|
||||
filename = _linecache_inject(source, write)
|
||||
|
||||
return super()._compile(source, filename) # type: ignore
|
||||
@@ -138,7 +151,7 @@ def quoted_native_concat(nodes):
|
||||
head = list(islice(nodes, 2))
|
||||
|
||||
if not head:
|
||||
return ''
|
||||
return ""
|
||||
|
||||
if len(head) == 1:
|
||||
raw = head[0]
|
||||
@@ -180,9 +193,7 @@ class NativeSandboxTemplate(jinja2.nativetypes.NativeTemplate): # mypy: ignore
|
||||
vars = dict(*args, **kwargs)
|
||||
|
||||
try:
|
||||
return quoted_native_concat(
|
||||
self.root_render_func(self.new_context(vars))
|
||||
)
|
||||
return quoted_native_concat(self.root_render_func(self.new_context(vars)))
|
||||
except Exception:
|
||||
return self.environment.handle_exception()
|
||||
|
||||
@@ -221,10 +232,10 @@ class BaseMacroGenerator:
|
||||
self.context: Optional[Dict[str, Any]] = context
|
||||
|
||||
def get_template(self):
|
||||
raise NotImplementedError('get_template not implemented!')
|
||||
raise NotImplementedError("get_template not implemented!")
|
||||
|
||||
def get_name(self) -> str:
|
||||
raise NotImplementedError('get_name not implemented!')
|
||||
raise NotImplementedError("get_name not implemented!")
|
||||
|
||||
def get_macro(self):
|
||||
name = self.get_name()
|
||||
@@ -247,9 +258,7 @@ class BaseMacroGenerator:
|
||||
def call_macro(self, *args, **kwargs):
|
||||
# called from __call__ methods
|
||||
if self.context is None:
|
||||
raise InternalException(
|
||||
'Context is still None in call_macro!'
|
||||
)
|
||||
raise InternalException("Context is still None in call_macro!")
|
||||
assert self.context is not None
|
||||
|
||||
macro = self.get_macro()
|
||||
@@ -276,7 +285,7 @@ class MacroStack(threading.local):
|
||||
def pop(self, name):
|
||||
got = self.call_stack.pop()
|
||||
if got != name:
|
||||
raise InternalException(f'popped {got}, expected {name}')
|
||||
raise InternalException(f"popped {got}, expected {name}")
|
||||
|
||||
|
||||
class MacroGenerator(BaseMacroGenerator):
|
||||
@@ -285,7 +294,7 @@ class MacroGenerator(BaseMacroGenerator):
|
||||
macro,
|
||||
context: Optional[Dict[str, Any]] = None,
|
||||
node: Optional[Any] = None,
|
||||
stack: Optional[MacroStack] = None
|
||||
stack: Optional[MacroStack] = None,
|
||||
) -> None:
|
||||
super().__init__(context)
|
||||
self.macro = macro
|
||||
@@ -333,9 +342,7 @@ class MacroGenerator(BaseMacroGenerator):
|
||||
|
||||
|
||||
class QueryStringGenerator(BaseMacroGenerator):
|
||||
def __init__(
|
||||
self, template_str: str, context: Dict[str, Any]
|
||||
) -> None:
|
||||
def __init__(self, template_str: str, context: Dict[str, Any]) -> None:
|
||||
super().__init__(context)
|
||||
self.template_str: str = template_str
|
||||
env = get_environment()
|
||||
@@ -345,7 +352,7 @@ class QueryStringGenerator(BaseMacroGenerator):
|
||||
)
|
||||
|
||||
def get_name(self) -> str:
|
||||
return 'query_comment_macro'
|
||||
return "query_comment_macro"
|
||||
|
||||
def get_template(self):
|
||||
"""Don't use the template cache, we don't have a node"""
|
||||
@@ -356,45 +363,41 @@ class QueryStringGenerator(BaseMacroGenerator):
|
||||
|
||||
|
||||
class MaterializationExtension(jinja2.ext.Extension):
|
||||
tags = ['materialization']
|
||||
tags = ["materialization"]
|
||||
|
||||
def parse(self, parser):
|
||||
node = jinja2.nodes.Macro(lineno=next(parser.stream).lineno)
|
||||
materialization_name = \
|
||||
parser.parse_assign_target(name_only=True).name
|
||||
materialization_name = parser.parse_assign_target(name_only=True).name
|
||||
|
||||
adapter_name = 'default'
|
||||
adapter_name = "default"
|
||||
node.args = []
|
||||
node.defaults = []
|
||||
|
||||
while parser.stream.skip_if('comma'):
|
||||
while parser.stream.skip_if("comma"):
|
||||
target = parser.parse_assign_target(name_only=True)
|
||||
|
||||
if target.name == 'default':
|
||||
if target.name == "default":
|
||||
pass
|
||||
|
||||
elif target.name == 'adapter':
|
||||
parser.stream.expect('assign')
|
||||
elif target.name == "adapter":
|
||||
parser.stream.expect("assign")
|
||||
value = parser.parse_expression()
|
||||
adapter_name = value.value
|
||||
|
||||
else:
|
||||
invalid_materialization_argument(
|
||||
materialization_name, target.name
|
||||
)
|
||||
invalid_materialization_argument(materialization_name, target.name)
|
||||
|
||||
node.name = get_materialization_macro_name(
|
||||
materialization_name, adapter_name
|
||||
)
|
||||
node.name = get_materialization_macro_name(materialization_name, adapter_name)
|
||||
|
||||
node.body = parser.parse_statements(('name:endmaterialization',),
|
||||
drop_needle=True)
|
||||
node.body = parser.parse_statements(
|
||||
("name:endmaterialization",), drop_needle=True
|
||||
)
|
||||
|
||||
return node
|
||||
|
||||
|
||||
class DocumentationExtension(jinja2.ext.Extension):
|
||||
tags = ['docs']
|
||||
tags = ["docs"]
|
||||
|
||||
def parse(self, parser):
|
||||
node = jinja2.nodes.Macro(lineno=next(parser.stream).lineno)
|
||||
@@ -403,13 +406,12 @@ class DocumentationExtension(jinja2.ext.Extension):
|
||||
node.args = []
|
||||
node.defaults = []
|
||||
node.name = get_docs_macro_name(docs_name)
|
||||
node.body = parser.parse_statements(('name:enddocs',),
|
||||
drop_needle=True)
|
||||
node.body = parser.parse_statements(("name:enddocs",), drop_needle=True)
|
||||
return node
|
||||
|
||||
|
||||
def _is_dunder_name(name):
|
||||
return name.startswith('__') and name.endswith('__')
|
||||
return name.startswith("__") and name.endswith("__")
|
||||
|
||||
|
||||
def create_undefined(node=None):
|
||||
@@ -430,10 +432,11 @@ def create_undefined(node=None):
|
||||
return self
|
||||
|
||||
def __getattr__(self, name):
|
||||
if name == 'name' or _is_dunder_name(name):
|
||||
if name == "name" or _is_dunder_name(name):
|
||||
raise AttributeError(
|
||||
"'{}' object has no attribute '{}'"
|
||||
.format(type(self).__name__, name)
|
||||
"'{}' object has no attribute '{}'".format(
|
||||
type(self).__name__, name
|
||||
)
|
||||
)
|
||||
|
||||
self.name = name
|
||||
@@ -444,24 +447,24 @@ def create_undefined(node=None):
|
||||
return self
|
||||
|
||||
def __reduce__(self):
|
||||
raise_compiler_error(f'{self.name} is undefined', node=node)
|
||||
raise_compiler_error(f"{self.name} is undefined", node=node)
|
||||
|
||||
return Undefined
|
||||
|
||||
|
||||
NATIVE_FILTERS: Dict[str, Callable[[Any], Any]] = {
|
||||
'as_text': TextMarker,
|
||||
'as_bool': BoolMarker,
|
||||
'as_native': NativeMarker,
|
||||
'as_number': NumberMarker,
|
||||
"as_text": TextMarker,
|
||||
"as_bool": BoolMarker,
|
||||
"as_native": NativeMarker,
|
||||
"as_number": NumberMarker,
|
||||
}
|
||||
|
||||
|
||||
TEXT_FILTERS: Dict[str, Callable[[Any], Any]] = {
|
||||
'as_text': lambda x: x,
|
||||
'as_bool': lambda x: x,
|
||||
'as_native': lambda x: x,
|
||||
'as_number': lambda x: x,
|
||||
"as_text": lambda x: x,
|
||||
"as_bool": lambda x: x,
|
||||
"as_native": lambda x: x,
|
||||
"as_number": lambda x: x,
|
||||
}
|
||||
|
||||
|
||||
@@ -471,14 +474,14 @@ def get_environment(
|
||||
native: bool = False,
|
||||
) -> jinja2.Environment:
|
||||
args: Dict[str, List[Union[str, Type[jinja2.ext.Extension]]]] = {
|
||||
'extensions': ['jinja2.ext.do']
|
||||
"extensions": ["jinja2.ext.do"]
|
||||
}
|
||||
|
||||
if capture_macros:
|
||||
args['undefined'] = create_undefined(node)
|
||||
args["undefined"] = create_undefined(node)
|
||||
|
||||
args['extensions'].append(MaterializationExtension)
|
||||
args['extensions'].append(DocumentationExtension)
|
||||
args["extensions"].append(MaterializationExtension)
|
||||
args["extensions"].append(DocumentationExtension)
|
||||
|
||||
env_cls: Type[jinja2.Environment]
|
||||
text_filter: Type
|
||||
@@ -541,8 +544,8 @@ def _requote_result(raw_value: str, rendered: str) -> str:
|
||||
elif single_quoted:
|
||||
quote_char = "'"
|
||||
else:
|
||||
quote_char = ''
|
||||
return f'{quote_char}{rendered}{quote_char}'
|
||||
quote_char = ""
|
||||
return f"{quote_char}{rendered}{quote_char}"
|
||||
|
||||
|
||||
# performance note: Local benmcharking (so take it with a big grain of salt!)
|
||||
@@ -550,7 +553,7 @@ def _requote_result(raw_value: str, rendered: str) -> str:
|
||||
# checking two separate patterns, but the standard deviation is smaller with
|
||||
# one pattern. The time difference between the two was ~2 std deviations, which
|
||||
# is small enough that I've just chosen the more readable option.
|
||||
_HAS_RENDER_CHARS_PAT = re.compile(r'({[{%#]|[#}%]})')
|
||||
_HAS_RENDER_CHARS_PAT = re.compile(r"({[{%#]|[#}%]})")
|
||||
|
||||
|
||||
def get_rendered(
|
||||
@@ -567,9 +570,9 @@ def get_rendered(
|
||||
# native=True case by passing the input string to ast.literal_eval, like
|
||||
# the native renderer does.
|
||||
if (
|
||||
not native and
|
||||
isinstance(string, str) and
|
||||
_HAS_RENDER_CHARS_PAT.search(string) is None
|
||||
not native
|
||||
and isinstance(string, str)
|
||||
and _HAS_RENDER_CHARS_PAT.search(string) is None
|
||||
):
|
||||
return string
|
||||
template = get_template(
|
||||
@@ -606,12 +609,11 @@ def extract_toplevel_blocks(
|
||||
`collect_raw_data` is `True`) `BlockData` objects.
|
||||
"""
|
||||
return BlockIterator(data).lex_for_blocks(
|
||||
allowed_blocks=allowed_blocks,
|
||||
collect_raw_data=collect_raw_data
|
||||
allowed_blocks=allowed_blocks, collect_raw_data=collect_raw_data
|
||||
)
|
||||
|
||||
|
||||
SCHEMA_TEST_KWARGS_NAME = '_dbt_schema_test_kwargs'
|
||||
SCHEMA_TEST_KWARGS_NAME = "_dbt_schema_test_kwargs"
|
||||
|
||||
|
||||
def add_rendered_test_kwargs(
|
||||
@@ -623,24 +625,21 @@ def add_rendered_test_kwargs(
|
||||
renderer, then insert that value into the given context as the special test
|
||||
keyword arguments member.
|
||||
"""
|
||||
looks_like_func = r'^\s*(env_var|ref|var|source|doc)\s*\(.+\)\s*$'
|
||||
looks_like_func = r"^\s*(env_var|ref|var|source|doc)\s*\(.+\)\s*$"
|
||||
|
||||
def _convert_function(
|
||||
value: Any, keypath: Tuple[Union[str, int], ...]
|
||||
) -> Any:
|
||||
def _convert_function(value: Any, keypath: Tuple[Union[str, int], ...]) -> Any:
|
||||
if isinstance(value, str):
|
||||
if keypath == ('column_name',):
|
||||
if keypath == ("column_name",):
|
||||
# special case: Don't render column names as native, make them
|
||||
# be strings
|
||||
return value
|
||||
|
||||
if re.match(looks_like_func, value) is not None:
|
||||
# curly braces to make rendering happy
|
||||
value = f'{{{{ {value} }}}}'
|
||||
value = f"{{{{ {value} }}}}"
|
||||
|
||||
value = get_rendered(
|
||||
value, context, node, capture_macros=capture_macros,
|
||||
native=True
|
||||
value, context, node, capture_macros=capture_macros, native=True
|
||||
)
|
||||
|
||||
return value
|
||||
|
||||
@@ -6,17 +6,17 @@ from dbt.logger import GLOBAL_LOGGER as logger
import os
import time

if os.getenv('DBT_PACKAGE_HUB_URL'):
DEFAULT_REGISTRY_BASE_URL = os.getenv('DBT_PACKAGE_HUB_URL')
if os.getenv("DBT_PACKAGE_HUB_URL"):
DEFAULT_REGISTRY_BASE_URL = os.getenv("DBT_PACKAGE_HUB_URL")
else:
DEFAULT_REGISTRY_BASE_URL = 'https://hub.getdbt.com/'
DEFAULT_REGISTRY_BASE_URL = "https://hub.getdbt.com/"


def _get_url(url, registry_base_url=None):
if registry_base_url is None:
registry_base_url = DEFAULT_REGISTRY_BASE_URL

return '{}{}'.format(registry_base_url, url)
return "{}{}".format(registry_base_url, url)


def _wrap_exceptions(fn):
@@ -33,42 +33,40 @@ def _wrap_exceptions(fn):
time.sleep(1)
continue

raise RegistryException(
'Unable to connect to registry hub'
) from exc
raise RegistryException("Unable to connect to registry hub") from exc

return wrapper


@_wrap_exceptions
def _get(path, registry_base_url=None):
url = _get_url(path, registry_base_url)
logger.debug('Making package registry request: GET {}'.format(url))
logger.debug("Making package registry request: GET {}".format(url))
resp = requests.get(url)
logger.debug('Response from registry: GET {} {}'.format(url,
resp.status_code))
logger.debug("Response from registry: GET {} {}".format(url, resp.status_code))
resp.raise_for_status()
return resp.json()


def index(registry_base_url=None):
return _get('api/v1/index.json', registry_base_url)
return _get("api/v1/index.json", registry_base_url)


index_cached = memoized(index)


def packages(registry_base_url=None):
return _get('api/v1/packages.json', registry_base_url)
return _get("api/v1/packages.json", registry_base_url)


def package(name, registry_base_url=None):
return _get('api/v1/{}.json'.format(name), registry_base_url)
return _get("api/v1/{}.json".format(name), registry_base_url)


def package_version(name, version, registry_base_url=None):
return _get('api/v1/{}/{}.json'.format(name, version), registry_base_url)
return _get("api/v1/{}/{}.json".format(name, version), registry_base_url)


def get_available_versions(name):
response = package(name)
return list(response['versions'])
return list(response["versions"])
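A brief, hedged sketch of calling the registry client above; the package name is illustrative and a network connection to the hub is assumed:

```
from dbt.clients.registry import package, get_available_versions  # path assumed

# List the versions published for a hub package, via api/v1/<name>.json above.
versions = get_available_versions("fishtown-analytics/dbt_utils")
print(versions)

# The full package document can be fetched the same way.
metadata = package("fishtown-analytics/dbt_utils")
print(sorted(metadata.keys()))
```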
@@ -10,16 +10,14 @@ import sys
|
||||
import tarfile
|
||||
import requests
|
||||
import stat
|
||||
from typing import (
|
||||
Type, NoReturn, List, Optional, Dict, Any, Tuple, Callable, Union
|
||||
)
|
||||
from typing import Type, NoReturn, List, Optional, Dict, Any, Tuple, Callable, Union
|
||||
|
||||
import dbt.exceptions
|
||||
import dbt.utils
|
||||
|
||||
from dbt.logger import GLOBAL_LOGGER as logger
|
||||
|
||||
if sys.platform == 'win32':
|
||||
if sys.platform == "win32":
|
||||
from ctypes import WinDLL, c_bool
|
||||
else:
|
||||
WinDLL = None
|
||||
@@ -51,30 +49,29 @@ def find_matching(
|
||||
reobj = re.compile(regex, re.IGNORECASE)
|
||||
|
||||
for relative_path_to_search in relative_paths_to_search:
|
||||
absolute_path_to_search = os.path.join(
|
||||
root_path, relative_path_to_search)
|
||||
absolute_path_to_search = os.path.join(root_path, relative_path_to_search)
|
||||
walk_results = os.walk(absolute_path_to_search)
|
||||
|
||||
for current_path, subdirectories, local_files in walk_results:
|
||||
for local_file in local_files:
|
||||
absolute_path = os.path.join(current_path, local_file)
|
||||
relative_path = os.path.relpath(
|
||||
absolute_path, absolute_path_to_search
|
||||
)
|
||||
relative_path = os.path.relpath(absolute_path, absolute_path_to_search)
|
||||
if reobj.match(local_file):
|
||||
matching.append({
|
||||
'searched_path': relative_path_to_search,
|
||||
'absolute_path': absolute_path,
|
||||
'relative_path': relative_path,
|
||||
})
|
||||
matching.append(
|
||||
{
|
||||
"searched_path": relative_path_to_search,
|
||||
"absolute_path": absolute_path,
|
||||
"relative_path": relative_path,
|
||||
}
|
||||
)
|
||||
|
||||
return matching
|
||||
|
||||
|
||||
def load_file_contents(path: str, strip: bool = True) -> str:
|
||||
path = convert_path(path)
|
||||
with open(path, 'rb') as handle:
|
||||
to_return = handle.read().decode('utf-8')
|
||||
with open(path, "rb") as handle:
|
||||
to_return = handle.read().decode("utf-8")
|
||||
|
||||
if strip:
|
||||
to_return = to_return.strip()
|
||||
@@ -101,14 +98,14 @@ def make_directory(path: str) -> None:
|
||||
raise e
|
||||
|
||||
|
||||
def make_file(path: str, contents: str = '', overwrite: bool = False) -> bool:
|
||||
def make_file(path: str, contents: str = "", overwrite: bool = False) -> bool:
|
||||
"""
|
||||
Make a file at `path` assuming that the directory it resides in already
|
||||
exists. The file is saved with contents `contents`
|
||||
"""
|
||||
if overwrite or not os.path.exists(path):
|
||||
path = convert_path(path)
|
||||
with open(path, 'w') as fh:
|
||||
with open(path, "w") as fh:
|
||||
fh.write(contents)
|
||||
return True
|
||||
|
||||
@@ -120,7 +117,7 @@ def make_symlink(source: str, link_path: str) -> None:
|
||||
Create a symlink at `link_path` referring to `source`.
|
||||
"""
|
||||
if not supports_symlinks():
|
||||
dbt.exceptions.system_error('create a symbolic link')
|
||||
dbt.exceptions.system_error("create a symbolic link")
|
||||
|
||||
os.symlink(source, link_path)
|
||||
|
||||
@@ -129,11 +126,11 @@ def supports_symlinks() -> bool:
|
||||
return getattr(os, "symlink", None) is not None
|
||||
|
||||
|
||||
def write_file(path: str, contents: str = '') -> bool:
|
||||
def write_file(path: str, contents: str = "") -> bool:
|
||||
path = convert_path(path)
|
||||
try:
|
||||
make_directory(os.path.dirname(path))
|
||||
with open(path, 'w', encoding='utf-8') as f:
|
||||
with open(path, "w", encoding="utf-8") as f:
|
||||
f.write(str(contents))
|
||||
except Exception as exc:
|
||||
# note that you can't just catch FileNotFound, because sometimes
|
||||
@@ -142,20 +139,20 @@ def write_file(path: str, contents: str = '') -> bool:
|
||||
# sometimes windows fails to write paths that are less than the length
|
||||
# limit. So on windows, suppress all errors that happen from writing
|
||||
# to disk.
|
||||
if os.name == 'nt':
|
||||
if os.name == "nt":
|
||||
# sometimes we get a winerror of 3 which means the path was
|
||||
# definitely too long, but other times we don't and it means the
|
||||
# path was just probably too long. This is probably based on the
|
||||
# windows/python version.
|
||||
if getattr(exc, 'winerror', 0) == 3:
|
||||
reason = 'Path was too long'
|
||||
if getattr(exc, "winerror", 0) == 3:
|
||||
reason = "Path was too long"
|
||||
else:
|
||||
reason = 'Path was possibly too long'
|
||||
reason = "Path was possibly too long"
|
||||
# all our hard work and the path was still too long. Log and
|
||||
# continue.
|
||||
logger.debug(
|
||||
f'Could not write to path {path}({len(path)} characters): '
|
||||
f'{reason}\nexception: {exc}'
|
||||
f"Could not write to path {path}({len(path)} characters): "
|
||||
f"{reason}\nexception: {exc}"
|
||||
)
|
||||
else:
|
||||
raise
|
||||
@@ -189,10 +186,7 @@ def resolve_path_from_base(path_to_resolve: str, base_path: str) -> str:
|
||||
If path_to_resolve is an absolute path or a user path (~), just
|
||||
resolve it to an absolute path and return.
|
||||
"""
|
||||
return os.path.abspath(
|
||||
os.path.join(
|
||||
base_path,
|
||||
os.path.expanduser(path_to_resolve)))
|
||||
return os.path.abspath(os.path.join(base_path, os.path.expanduser(path_to_resolve)))
|
||||
|
||||
|
||||
def rmdir(path: str) -> None:
|
||||
@@ -202,7 +196,7 @@ def rmdir(path: str) -> None:
|
||||
cloned via git) can cause rmtree to throw a PermissionError exception
|
||||
"""
|
||||
path = convert_path(path)
|
||||
if sys.platform == 'win32':
|
||||
if sys.platform == "win32":
|
||||
onerror = _windows_rmdir_readonly
|
||||
else:
|
||||
onerror = None
|
||||
@@ -221,7 +215,7 @@ def _win_prepare_path(path: str) -> str:
|
||||
# letter back in.
|
||||
# Unless it starts with '\\'. In that case, the path is a UNC mount point
|
||||
# and splitdrive will be fine.
|
||||
if not path.startswith('\\\\') and path.startswith('\\'):
|
||||
if not path.startswith("\\\\") and path.startswith("\\"):
|
||||
curdrive = os.path.splitdrive(os.getcwd())[0]
|
||||
path = curdrive + path
|
||||
|
||||
@@ -236,7 +230,7 @@ def _win_prepare_path(path: str) -> str:
|
||||
|
||||
|
||||
def _supports_long_paths() -> bool:
|
||||
if sys.platform != 'win32':
|
||||
if sys.platform != "win32":
|
||||
return True
|
||||
# Eryk Sun says to use `WinDLL('ntdll')` instead of `windll.ntdll` because
|
||||
# of pointer caching in a comment here:
|
||||
@@ -244,11 +238,11 @@ def _supports_long_paths() -> bool:
|
||||
# I don't know exaclty what he means, but I am inclined to believe him as
|
||||
# he's pretty active on Python windows bugs!
|
||||
try:
|
||||
dll = WinDLL('ntdll')
|
||||
dll = WinDLL("ntdll")
|
||||
except OSError: # I don't think this happens? you need ntdll to run python
|
||||
return False
|
||||
# not all windows versions have it at all
|
||||
if not hasattr(dll, 'RtlAreLongPathsEnabled'):
|
||||
if not hasattr(dll, "RtlAreLongPathsEnabled"):
|
||||
return False
|
||||
# tell windows we want to get back a single unsigned byte (a bool).
|
||||
dll.RtlAreLongPathsEnabled.restype = c_bool
|
||||
@@ -268,7 +262,7 @@ def convert_path(path: str) -> str:
|
||||
if _supports_long_paths():
|
||||
return path
|
||||
|
||||
prefix = '\\\\?\\'
|
||||
prefix = "\\\\?\\"
|
||||
# Nothing to do
|
||||
if path.startswith(prefix):
|
||||
return path
|
||||
@@ -299,39 +293,35 @@ def path_is_symlink(path: str) -> bool:
|
||||
|
||||
def open_dir_cmd() -> str:
|
||||
# https://docs.python.org/2/library/sys.html#sys.platform
|
||||
if sys.platform == 'win32':
|
||||
return 'start'
|
||||
if sys.platform == "win32":
|
||||
return "start"
|
||||
|
||||
elif sys.platform == 'darwin':
|
||||
return 'open'
|
||||
elif sys.platform == "darwin":
|
||||
return "open"
|
||||
|
||||
else:
|
||||
return 'xdg-open'
|
||||
return "xdg-open"
|
||||
|
||||
|
||||
def _handle_posix_cwd_error(
|
||||
exc: OSError, cwd: str, cmd: List[str]
|
||||
) -> NoReturn:
|
||||
def _handle_posix_cwd_error(exc: OSError, cwd: str, cmd: List[str]) -> NoReturn:
|
||||
if exc.errno == errno.ENOENT:
|
||||
message = 'Directory does not exist'
|
||||
message = "Directory does not exist"
|
||||
elif exc.errno == errno.EACCES:
|
||||
message = 'Current user cannot access directory, check permissions'
|
||||
message = "Current user cannot access directory, check permissions"
|
||||
elif exc.errno == errno.ENOTDIR:
|
||||
message = 'Not a directory'
|
||||
message = "Not a directory"
|
||||
else:
|
||||
message = 'Unknown OSError: {} - cwd'.format(str(exc))
|
||||
message = "Unknown OSError: {} - cwd".format(str(exc))
|
||||
raise dbt.exceptions.WorkingDirectoryError(cwd, cmd, message)
|
||||
|
||||
|
||||
def _handle_posix_cmd_error(
|
||||
exc: OSError, cwd: str, cmd: List[str]
|
||||
) -> NoReturn:
|
||||
def _handle_posix_cmd_error(exc: OSError, cwd: str, cmd: List[str]) -> NoReturn:
|
||||
if exc.errno == errno.ENOENT:
|
||||
message = "Could not find command, ensure it is in the user's PATH"
|
||||
elif exc.errno == errno.EACCES:
|
||||
message = 'User does not have permissions for this command'
|
||||
message = "User does not have permissions for this command"
|
||||
else:
|
||||
message = 'Unknown OSError: {} - cmd'.format(str(exc))
|
||||
message = "Unknown OSError: {} - cmd".format(str(exc))
|
||||
raise dbt.exceptions.ExecutableError(cwd, cmd, message)
|
||||
|
||||
|
||||
@@ -356,7 +346,7 @@ def _handle_posix_error(exc: OSError, cwd: str, cmd: List[str]) -> NoReturn:
|
||||
- exc.errno == EACCES
|
||||
- exc.filename == None(?)
|
||||
"""
|
||||
if getattr(exc, 'filename', None) == cwd:
|
||||
if getattr(exc, "filename", None) == cwd:
|
||||
_handle_posix_cwd_error(exc, cwd, cmd)
|
||||
else:
|
||||
_handle_posix_cmd_error(exc, cwd, cmd)
|
||||
@@ -365,46 +355,48 @@ def _handle_posix_error(exc: OSError, cwd: str, cmd: List[str]) -> NoReturn:
|
||||
def _handle_windows_error(exc: OSError, cwd: str, cmd: List[str]) -> NoReturn:
|
||||
cls: Type[dbt.exceptions.Exception] = dbt.exceptions.CommandError
|
||||
if exc.errno == errno.ENOENT:
|
||||
message = ("Could not find command, ensure it is in the user's PATH "
|
||||
"and that the user has permissions to run it")
|
||||
message = (
|
||||
"Could not find command, ensure it is in the user's PATH "
|
||||
"and that the user has permissions to run it"
|
||||
)
|
||||
cls = dbt.exceptions.ExecutableError
|
||||
elif exc.errno == errno.ENOEXEC:
|
||||
message = ('Command was not executable, ensure it is valid')
|
||||
message = "Command was not executable, ensure it is valid"
|
||||
cls = dbt.exceptions.ExecutableError
|
||||
elif exc.errno == errno.ENOTDIR:
|
||||
message = ('Unable to cd: path does not exist, user does not have'
|
||||
' permissions, or not a directory')
|
||||
message = (
|
||||
"Unable to cd: path does not exist, user does not have"
|
||||
" permissions, or not a directory"
|
||||
)
|
||||
cls = dbt.exceptions.WorkingDirectoryError
|
||||
else:
|
||||
message = 'Unknown error: {} (errno={}: "{}")'.format(
|
||||
str(exc), exc.errno, errno.errorcode.get(exc.errno, '<Unknown!>')
|
||||
str(exc), exc.errno, errno.errorcode.get(exc.errno, "<Unknown!>")
|
||||
)
|
||||
raise cls(cwd, cmd, message)
|
||||
|
||||
|
||||
def _interpret_oserror(exc: OSError, cwd: str, cmd: List[str]) -> NoReturn:
|
||||
"""Interpret an OSError exc and raise the appropriate dbt exception.
|
||||
|
||||
"""
|
||||
"""Interpret an OSError exc and raise the appropriate dbt exception."""
|
||||
if len(cmd) == 0:
|
||||
raise dbt.exceptions.CommandError(cwd, cmd)
|
||||
|
||||
# all of these functions raise unconditionally
|
||||
if os.name == 'nt':
|
||||
if os.name == "nt":
|
||||
_handle_windows_error(exc, cwd, cmd)
|
||||
else:
|
||||
_handle_posix_error(exc, cwd, cmd)
|
||||
|
||||
# this should not be reachable, raise _something_ at least!
|
||||
raise dbt.exceptions.InternalException(
|
||||
'Unhandled exception in _interpret_oserror: {}'.format(exc)
|
||||
"Unhandled exception in _interpret_oserror: {}".format(exc)
|
||||
)
|
||||
|
||||
|
||||
def run_cmd(
|
||||
cwd: str, cmd: List[str], env: Optional[Dict[str, Any]] = None
|
||||
) -> Tuple[bytes, bytes]:
|
||||
logger.debug('Executing "{}"'.format(' '.join(cmd)))
|
||||
logger.debug('Executing "{}"'.format(" ".join(cmd)))
|
||||
if len(cmd) == 0:
|
||||
raise dbt.exceptions.CommandError(cwd, cmd)
|
||||
|
||||
@@ -417,11 +409,8 @@ def run_cmd(
|
||||
|
||||
try:
|
||||
proc = subprocess.Popen(
|
||||
cmd,
|
||||
cwd=cwd,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
env=full_env)
|
||||
cmd, cwd=cwd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=full_env
|
||||
)
|
||||
|
||||
out, err = proc.communicate()
|
||||
except OSError as exc:
|
||||
@@ -431,9 +420,8 @@ def run_cmd(
|
||||
logger.debug('STDERR: "{!s}"'.format(err))
|
||||
|
||||
if proc.returncode != 0:
|
||||
logger.debug('command return code={}'.format(proc.returncode))
|
||||
raise dbt.exceptions.CommandResultError(cwd, cmd, proc.returncode,
|
||||
out, err)
|
||||
logger.debug("command return code={}".format(proc.returncode))
|
||||
raise dbt.exceptions.CommandResultError(cwd, cmd, proc.returncode, out, err)
|
||||
|
||||
return out, err
|
||||
|
||||
@@ -442,9 +430,9 @@ def download(
|
||||
url: str, path: str, timeout: Optional[Union[float, tuple]] = None
|
||||
) -> None:
|
||||
path = convert_path(path)
|
||||
connection_timeout = timeout or float(os.getenv('DBT_HTTP_TIMEOUT', 10))
|
||||
connection_timeout = timeout or float(os.getenv("DBT_HTTP_TIMEOUT", 10))
|
||||
response = requests.get(url, timeout=connection_timeout)
|
||||
with open(path, 'wb') as handle:
|
||||
with open(path, "wb") as handle:
|
||||
for block in response.iter_content(1024 * 64):
|
||||
handle.write(block)
|
||||
|
||||
@@ -468,7 +456,7 @@ def untar_package(
|
||||
) -> None:
|
||||
tar_path = convert_path(tar_path)
|
||||
tar_dir_name = None
|
||||
with tarfile.open(tar_path, 'r') as tarball:
|
||||
with tarfile.open(tar_path, "r") as tarball:
|
||||
tarball.extractall(dest_dir)
|
||||
tar_dir_name = os.path.commonprefix(tarball.getnames())
|
||||
if rename_to:
|
||||
@@ -484,7 +472,7 @@ def chmod_and_retry(func, path, exc_info):
|
||||
We want to retry most operations here, but listdir is one that we know will
|
||||
be useless.
|
||||
"""
|
||||
if func is os.listdir or os.name != 'nt':
|
||||
if func is os.listdir or os.name != "nt":
|
||||
raise
|
||||
os.chmod(path, stat.S_IREAD | stat.S_IWRITE)
|
||||
# on error,this will raise.
|
||||
@@ -505,7 +493,7 @@ def move(src, dst):
|
||||
"""
|
||||
src = convert_path(src)
|
||||
dst = convert_path(dst)
|
||||
if os.name != 'nt':
|
||||
if os.name != "nt":
|
||||
return shutil.move(src, dst)
|
||||
|
||||
if os.path.isdir(dst):
|
||||
@@ -513,7 +501,7 @@ def move(src, dst):
|
||||
os.rename(src, dst)
|
||||
return
|
||||
|
||||
dst = os.path.join(dst, os.path.basename(src.rstrip('/\\')))
|
||||
dst = os.path.join(dst, os.path.basename(src.rstrip("/\\")))
|
||||
if os.path.exists(dst):
|
||||
raise EnvironmentError("Path '{}' already exists".format(dst))
|
||||
|
||||
@@ -522,11 +510,10 @@ def move(src, dst):
|
||||
except OSError:
|
||||
# probably different drives
|
||||
if os.path.isdir(src):
|
||||
if _absnorm(dst + '\\').startswith(_absnorm(src + '\\')):
|
||||
if _absnorm(dst + "\\").startswith(_absnorm(src + "\\")):
|
||||
# dst is inside src
|
||||
raise EnvironmentError(
|
||||
"Cannot move a directory '{}' into itself '{}'"
|
||||
.format(src, dst)
|
||||
"Cannot move a directory '{}' into itself '{}'".format(src, dst)
|
||||
)
|
||||
shutil.copytree(src, dst, symlinks=True)
|
||||
rmtree(src)
|
||||
|
||||
@@ -5,15 +5,9 @@ import yaml.scanner
|
||||
|
||||
# the C version is faster, but it doesn't always exist
|
||||
try:
|
||||
from yaml import (
|
||||
CLoader as Loader,
|
||||
CSafeLoader as SafeLoader,
|
||||
CDumper as Dumper
|
||||
)
|
||||
from yaml import CLoader as Loader, CSafeLoader as SafeLoader, CDumper as Dumper
|
||||
except ImportError:
|
||||
from yaml import ( # type: ignore # noqa: F401
|
||||
Loader, SafeLoader, Dumper
|
||||
)
|
||||
from yaml import Loader, SafeLoader, Dumper # type: ignore # noqa: F401
|
||||
|
||||
|
||||
YAML_ERROR_MESSAGE = """
|
||||
@@ -33,14 +27,14 @@ def line_no(i, line, width=3):
|
||||
|
||||
|
||||
def prefix_with_line_numbers(string, no_start, no_end):
|
||||
line_list = string.split('\n')
|
||||
line_list = string.split("\n")
|
||||
|
||||
numbers = range(no_start, no_end)
|
||||
relevant_lines = line_list[no_start:no_end]
|
||||
|
||||
return "\n".join([
|
||||
line_no(i + 1, line) for (i, line) in zip(numbers, relevant_lines)
|
||||
])
|
||||
return "\n".join(
|
||||
[line_no(i + 1, line) for (i, line) in zip(numbers, relevant_lines)]
|
||||
)
|
||||
|
||||
|
||||
def contextualized_yaml_error(raw_contents, error):
|
||||
@@ -51,9 +45,9 @@ def contextualized_yaml_error(raw_contents, error):
|
||||
|
||||
nice_error = prefix_with_line_numbers(raw_contents, min_line, max_line)
|
||||
|
||||
return YAML_ERROR_MESSAGE.format(line_number=mark.line + 1,
|
||||
nice_error=nice_error,
|
||||
raw_error=error)
|
||||
return YAML_ERROR_MESSAGE.format(
|
||||
line_number=mark.line + 1, nice_error=nice_error, raw_error=error
|
||||
)
|
||||
|
||||
|
||||
def safe_load(contents):
|
||||
@@ -64,7 +58,7 @@ def load_yaml_text(contents):
|
||||
try:
|
||||
return safe_load(contents)
|
||||
except (yaml.scanner.ScannerError, yaml.YAMLError) as e:
|
||||
if hasattr(e, 'problem_mark'):
|
||||
if hasattr(e, "problem_mark"):
|
||||
error = contextualized_yaml_error(contents, e)
|
||||
else:
|
||||
error = str(e)
|
||||
|
||||
@@ -32,28 +32,28 @@ from dbt.node_types import NodeType
|
||||
from dbt.utils import pluralize
|
||||
import dbt.tracking
|
||||
|
||||
graph_file_name = 'graph.gpickle'
|
||||
graph_file_name = "graph.gpickle"
|
||||
|
||||
|
||||
def _compiled_type_for(model: ParsedNode):
|
||||
if type(model) not in COMPILED_TYPES:
|
||||
raise InternalException(
|
||||
f'Asked to compile {type(model)} node, but it has no compiled form'
|
||||
f"Asked to compile {type(model)} node, but it has no compiled form"
|
||||
)
|
||||
return COMPILED_TYPES[type(model)]
|
||||
|
||||
|
||||
def print_compile_stats(stats):
|
||||
names = {
|
||||
NodeType.Model: 'model',
|
||||
NodeType.Test: 'test',
|
||||
NodeType.Snapshot: 'snapshot',
|
||||
NodeType.Analysis: 'analysis',
|
||||
NodeType.Macro: 'macro',
|
||||
NodeType.Operation: 'operation',
|
||||
NodeType.Seed: 'seed file',
|
||||
NodeType.Source: 'source',
|
||||
NodeType.Exposure: 'exposure',
|
||||
NodeType.Model: "model",
|
||||
NodeType.Test: "test",
|
||||
NodeType.Snapshot: "snapshot",
|
||||
NodeType.Analysis: "analysis",
|
||||
NodeType.Macro: "macro",
|
||||
NodeType.Operation: "operation",
|
||||
NodeType.Seed: "seed file",
|
||||
NodeType.Source: "source",
|
||||
NodeType.Exposure: "exposure",
|
||||
}
|
||||
|
||||
results = {k: 0 for k in names.keys()}
|
||||
@@ -64,10 +64,9 @@ def print_compile_stats(stats):
|
||||
resource_counts = {k.pluralize(): v for k, v in results.items()}
|
||||
dbt.tracking.track_resource_counts(resource_counts)
|
||||
|
||||
stat_line = ", ".join([
|
||||
pluralize(ct, names.get(t)) for t, ct in results.items()
|
||||
if t in names
|
||||
])
|
||||
stat_line = ", ".join(
|
||||
[pluralize(ct, names.get(t)) for t, ct in results.items() if t in names]
|
||||
)
|
||||
|
||||
logger.info("Found {}".format(stat_line))
|
||||
|
||||
@@ -166,9 +165,7 @@ class Compiler:
|
||||
extra_context: Dict[str, Any],
|
||||
) -> Dict[str, Any]:
|
||||
|
||||
context = generate_runtime_model(
|
||||
node, self.config, manifest
|
||||
)
|
||||
context = generate_runtime_model(node, self.config, manifest)
|
||||
context.update(extra_context)
|
||||
if isinstance(node, CompiledSchemaTestNode):
|
||||
# for test nodes, add a special keyword args value to the context
|
||||
@@ -183,8 +180,7 @@ class Compiler:
|
||||
|
||||
def _get_relation_name(self, node: ParsedNode):
|
||||
relation_name = None
|
||||
if (node.resource_type in NodeType.refable() and
|
||||
not node.is_ephemeral_model):
|
||||
if node.resource_type in NodeType.refable() and not node.is_ephemeral_model:
|
||||
adapter = get_adapter(self.config)
|
||||
relation_cls = adapter.Relation
|
||||
relation_name = str(relation_cls.create_from(self.config, node))
|
||||
@@ -227,32 +223,29 @@ class Compiler:
|
||||
|
||||
with_stmt = None
|
||||
for token in parsed.tokens:
|
||||
if token.is_keyword and token.normalized == 'WITH':
|
||||
if token.is_keyword and token.normalized == "WITH":
|
||||
with_stmt = token
|
||||
break
|
||||
|
||||
if with_stmt is None:
|
||||
# no with stmt, add one, and inject CTEs right at the beginning
|
||||
first_token = parsed.token_first()
|
||||
with_stmt = sqlparse.sql.Token(sqlparse.tokens.Keyword, 'with')
|
||||
with_stmt = sqlparse.sql.Token(sqlparse.tokens.Keyword, "with")
|
||||
parsed.insert_before(first_token, with_stmt)
|
||||
else:
|
||||
# stmt exists, add a comma (which will come after injected CTEs)
|
||||
trailing_comma = sqlparse.sql.Token(
|
||||
sqlparse.tokens.Punctuation, ','
|
||||
)
|
||||
trailing_comma = sqlparse.sql.Token(sqlparse.tokens.Punctuation, ",")
|
||||
parsed.insert_after(with_stmt, trailing_comma)
|
||||
|
||||
token = sqlparse.sql.Token(
|
||||
sqlparse.tokens.Keyword,
|
||||
", ".join(c.sql for c in ctes)
|
||||
sqlparse.tokens.Keyword, ", ".join(c.sql for c in ctes)
|
||||
)
|
||||
parsed.insert_after(with_stmt, token)
|
||||
|
||||
return str(parsed)
|
||||
|
||||
def _get_dbt_test_name(self) -> str:
|
||||
return 'dbt__cte__internal_test'
|
||||
return "dbt__cte__internal_test"
|
||||
|
||||
# This method is called by the 'compile_node' method. Starting
|
||||
# from the node that it is passed in, it will recursively call
|
||||
@@ -268,9 +261,7 @@ class Compiler:
|
||||
) -> Tuple[NonSourceCompiledNode, List[InjectedCTE]]:
|
||||
|
||||
if model.compiled_sql is None:
|
||||
raise RuntimeException(
|
||||
'Cannot inject ctes into an unparsed node', model
|
||||
)
|
||||
raise RuntimeException("Cannot inject ctes into an unparsed node", model)
|
||||
if model.extra_ctes_injected:
|
||||
return (model, model.extra_ctes)
|
||||
|
||||
@@ -296,19 +287,18 @@ class Compiler:
|
||||
else:
|
||||
if cte.id not in manifest.nodes:
|
||||
raise InternalException(
|
||||
f'During compilation, found a cte reference that '
f'could not be resolved: {cte.id}'
f"During compilation, found a cte reference that "
f"could not be resolved: {cte.id}"
)
cte_model = manifest.nodes[cte.id]

if not cte_model.is_ephemeral_model:
raise InternalException(f'{cte.id} is not ephemeral')
raise InternalException(f"{cte.id} is not ephemeral")

# This model has already been compiled, so it's been
# through here before
if getattr(cte_model, 'compiled', False):
assert isinstance(cte_model,
tuple(COMPILED_TYPES.values()))
if getattr(cte_model, "compiled", False):
assert isinstance(cte_model, tuple(COMPILED_TYPES.values()))
cte_model = cast(NonSourceCompiledNode, cte_model)
new_prepended_ctes = cte_model.extra_ctes

@@ -316,11 +306,9 @@ class Compiler:
else:
# This is an ephemeral parsed model that we can compile.
# Compile and update the node
cte_model = self._compile_node(
cte_model, manifest, extra_context)
cte_model = self._compile_node(cte_model, manifest, extra_context)
# recursively call this method
cte_model, new_prepended_ctes = \
self._recursively_prepend_ctes(
cte_model, new_prepended_ctes = self._recursively_prepend_ctes(
cte_model, manifest, extra_context
)
# Save compiled SQL file and sync manifest
@@ -330,7 +318,7 @@ class Compiler:
_extend_prepended_ctes(prepended_ctes, new_prepended_ctes)

new_cte_name = self.add_ephemeral_prefix(cte_model.name)
sql = f' {new_cte_name} as (\n{cte_model.compiled_sql}\n)'
sql = f" {new_cte_name} as (\n{cte_model.compiled_sql}\n)"

_add_prepended_cte(prepended_ctes, InjectedCTE(id=cte.id, sql=sql))

@@ -371,11 +359,10 @@ class Compiler:
# compiled_sql, and do the regular prepend logic from CTEs.
name = self._get_dbt_test_name()
cte = InjectedCTE(
id=name,
sql=f' {name} as (\n{compiled_node.compiled_sql}\n)'
id=name, sql=f" {name} as (\n{compiled_node.compiled_sql}\n)"
)
compiled_node.extra_ctes.append(cte)
compiled_node.compiled_sql = f'\nselect count(*) from {name}'
compiled_node.compiled_sql = f"\nselect count(*) from {name}"

return compiled_node

@@ -395,17 +382,17 @@ class Compiler:
logger.debug("Compiling {}".format(node.unique_id))

data = node.to_dict(omit_none=True)
data.update({
'compiled': False,
'compiled_sql': None,
'extra_ctes_injected': False,
'extra_ctes': [],
})
data.update(
{
"compiled": False,
"compiled_sql": None,
"extra_ctes_injected": False,
"extra_ctes": [],
}
)
compiled_node = _compiled_type_for(node).from_dict(data)

context = self._create_node_context(
compiled_node, manifest, extra_context
)
context = self._create_node_context(compiled_node, manifest, extra_context)

compiled_node.compiled_sql = jinja.get_rendered(
node.raw_sql,
@@ -419,9 +406,7 @@ class Compiler:

# add ctes for specific test nodes, and also for
# possible future use in adapters
compiled_node = self._add_ctes(
compiled_node, manifest, extra_context
)
compiled_node = self._add_ctes(compiled_node, manifest, extra_context)

return compiled_node

@@ -431,21 +416,17 @@ class Compiler:
if flags.WRITE_JSON:
linker.write_graph(graph_path, manifest)

def link_node(
self, linker: Linker, node: GraphMemberNode, manifest: Manifest
):
def link_node(self, linker: Linker, node: GraphMemberNode, manifest: Manifest):
linker.add_node(node.unique_id)

for dependency in node.depends_on_nodes:
if dependency in manifest.nodes:
linker.dependency(
node.unique_id,
(manifest.nodes[dependency].unique_id)
node.unique_id, (manifest.nodes[dependency].unique_id)
)
elif dependency in manifest.sources:
linker.dependency(
node.unique_id,
(manifest.sources[dependency].unique_id)
node.unique_id, (manifest.sources[dependency].unique_id)
)
else:
dependency_not_found(node, dependency)
@@ -480,16 +461,13 @@ class Compiler:

# writes the "compiled_sql" into the target/compiled directory
def _write_node(self, node: NonSourceCompiledNode) -> ManifestNode:
if (not node.extra_ctes_injected or
node.resource_type == NodeType.Snapshot):
if not node.extra_ctes_injected or node.resource_type == NodeType.Snapshot:
return node
logger.debug(f'Writing injected SQL for node "{node.unique_id}"')

if node.compiled_sql:
node.build_path = node.write_node(
self.config.target_path,
'compiled',
node.compiled_sql
self.config.target_path, "compiled", node.compiled_sql
)
return node

@@ -507,9 +485,7 @@ class Compiler:
) -> NonSourceCompiledNode:
node = self._compile_node(node, manifest, extra_context)

node, _ = self._recursively_prepend_ctes(
node, manifest, extra_context
)
node, _ = self._recursively_prepend_ctes(node, manifest, extra_context)
if write:
self._write_node(node)
return node

@@ -20,10 +20,8 @@ from dbt.utils import coerce_dict_str
from .renderer import ProfileRenderer

DEFAULT_THREADS = 1
DEFAULT_PROFILES_DIR = os.path.join(os.path.expanduser('~'), '.dbt')
PROFILES_DIR = os.path.expanduser(
os.getenv('DBT_PROFILES_DIR', DEFAULT_PROFILES_DIR)
)
DEFAULT_PROFILES_DIR = os.path.join(os.path.expanduser("~"), ".dbt")
PROFILES_DIR = os.path.expanduser(os.getenv("DBT_PROFILES_DIR", DEFAULT_PROFILES_DIR))

INVALID_PROFILE_MESSAGE = """
dbt encountered an error while trying to read your profiles.yml file.
@@ -43,11 +41,13 @@ Here, [profile name] should be replaced with a profile name
defined in your profiles.yml file. You can find profiles.yml here:

{profiles_file}/profiles.yml
""".format(profiles_file=PROFILES_DIR)
""".format(
profiles_file=PROFILES_DIR
)


def read_profile(profiles_dir: str) -> Dict[str, Any]:
path = os.path.join(profiles_dir, 'profiles.yml')
path = os.path.join(profiles_dir, "profiles.yml")

contents = None
if os.path.isfile(path):
@@ -55,12 +55,8 @@ def read_profile(profiles_dir: str) -> Dict[str, Any]:
contents = load_file_contents(path, strip=False)
yaml_content = load_yaml_text(contents)
if not yaml_content:
msg = f'The profiles.yml file at {path} is empty'
raise DbtProfileError(
INVALID_PROFILE_MESSAGE.format(
error_string=msg
)
)
msg = f"The profiles.yml file at {path} is empty"
raise DbtProfileError(INVALID_PROFILE_MESSAGE.format(error_string=msg))
return yaml_content
except ValidationException as e:
msg = INVALID_PROFILE_MESSAGE.format(error_string=e)
@@ -73,7 +69,7 @@ def read_user_config(directory: str) -> UserConfig:
try:
profile = read_profile(directory)
if profile:
user_cfg = coerce_dict_str(profile.get('config', {}))
user_cfg = coerce_dict_str(profile.get("config", {}))
if user_cfg is not None:
UserConfig.validate(user_cfg)
return UserConfig.from_dict(user_cfg)
@@ -92,9 +88,7 @@ class Profile(HasCredentials):
threads: int
credentials: Credentials

def to_profile_info(
self, serialize_credentials: bool = False
) -> Dict[str, Any]:
def to_profile_info(self, serialize_credentials: bool = False) -> Dict[str, Any]:
"""Unlike to_project_config, this dict is not a mirror of any existing
on-disk data structure. It's used when creating a new profile from an
existing one.
@@ -104,34 +98,35 @@ class Profile(HasCredentials):
:returns dict: The serialized profile.
"""
result = {
'profile_name': self.profile_name,
'target_name': self.target_name,
'config': self.config,
'threads': self.threads,
'credentials': self.credentials,
"profile_name": self.profile_name,
"target_name": self.target_name,
"config": self.config,
"threads": self.threads,
"credentials": self.credentials,
}
if serialize_credentials:
result['config'] = self.config.to_dict(omit_none=True)
result['credentials'] = self.credentials.to_dict(omit_none=True)
result["config"] = self.config.to_dict(omit_none=True)
result["credentials"] = self.credentials.to_dict(omit_none=True)
return result

def to_target_dict(self) -> Dict[str, Any]:
target = dict(
self.credentials.connection_info(with_aliases=True)
target = dict(self.credentials.connection_info(with_aliases=True))
target.update(
{
"type": self.credentials.type,
"threads": self.threads,
"name": self.target_name,
"target_name": self.target_name,
"profile_name": self.profile_name,
"config": self.config.to_dict(omit_none=True),
}
)
target.update({
'type': self.credentials.type,
'threads': self.threads,
'name': self.target_name,
'target_name': self.target_name,
'profile_name': self.profile_name,
'config': self.config.to_dict(omit_none=True),
})
return target

def __eq__(self, other: object) -> bool:
if not (isinstance(other, self.__class__) and
isinstance(self, other.__class__)):
if not (
isinstance(other, self.__class__) and isinstance(self, other.__class__)
):
return NotImplemented
return self.to_profile_info() == other.to_profile_info()

@@ -151,14 +146,17 @@ class Profile(HasCredentials):
) -> Credentials:
# avoid an import cycle
from dbt.adapters.factory import load_plugin

# credentials carry their 'type' in their actual type, not their
# attributes. We do want this in order to pick our Credentials class.
if 'type' not in profile:
if "type" not in profile:
raise DbtProfileError(
'required field "type" not found in profile {} and target {}'
.format(profile_name, target_name))
'required field "type" not found in profile {} and target {}'.format(
profile_name, target_name
)
)

typename = profile.pop('type')
typename = profile.pop("type")
try:
cls = load_plugin(typename)
data = cls.translate_aliases(profile)
@@ -167,8 +165,9 @@ class Profile(HasCredentials):
except (RuntimeException, ValidationError) as e:
msg = str(e) if isinstance(e, RuntimeException) else e.message
raise DbtProfileError(
'Credentials in profile "{}", target "{}" invalid: {}'
.format(profile_name, target_name, msg)
'Credentials in profile "{}", target "{}" invalid: {}'.format(
profile_name, target_name, msg
)
) from e

return credentials
@@ -189,19 +188,21 @@ class Profile(HasCredentials):
def _get_profile_data(
profile: Dict[str, Any], profile_name: str, target_name: str
) -> Dict[str, Any]:
if 'outputs' not in profile:
if "outputs" not in profile:
raise DbtProfileError(
"outputs not specified in profile '{}'".format(profile_name)
)
outputs = profile['outputs']
outputs = profile["outputs"]

if target_name not in outputs:
outputs = '\n'.join(' - {}'.format(output)
for output in outputs)
msg = ("The profile '{}' does not have a target named '{}'. The "
"valid target names for this profile are:\n{}"
.format(profile_name, target_name, outputs))
raise DbtProfileError(msg, result_type='invalid_target')
outputs = "\n".join(" - {}".format(output) for output in outputs)
msg = (
"The profile '{}' does not have a target named '{}'. The "
"valid target names for this profile are:\n{}".format(
profile_name, target_name, outputs
)
)
raise DbtProfileError(msg, result_type="invalid_target")
profile_data = outputs[target_name]

if not isinstance(profile_data, dict):
@@ -209,7 +210,7 @@ class Profile(HasCredentials):
f"output '{target_name}' of profile '{profile_name}' is "
f"misconfigured in profiles.yml"
)
raise DbtProfileError(msg, result_type='invalid_target')
raise DbtProfileError(msg, result_type="invalid_target")

return profile_data

@@ -220,8 +221,8 @@ class Profile(HasCredentials):
threads: int,
profile_name: str,
target_name: str,
user_cfg: Optional[Dict[str, Any]] = None
) -> 'Profile':
user_cfg: Optional[Dict[str, Any]] = None,
) -> "Profile":
"""Create a profile from an existing set of Credentials and the
remaining information.

@@ -244,7 +245,7 @@ class Profile(HasCredentials):
target_name=target_name,
config=config,
threads=threads,
credentials=credentials
credentials=credentials,
)
profile.validate()
return profile
@@ -269,19 +270,18 @@ class Profile(HasCredentials):
# name to extract a profile that we can render.
if target_override is not None:
target_name = target_override
elif 'target' in raw_profile:
elif "target" in raw_profile:
# render the target if it was parsed from yaml
target_name = renderer.render_value(raw_profile['target'])
target_name = renderer.render_value(raw_profile["target"])
else:
target_name = 'default'
target_name = "default"
logger.debug(
"target not specified in profile '{}', using '{}'"
.format(profile_name, target_name)
"target not specified in profile '{}', using '{}'".format(
profile_name, target_name
)
)

raw_profile_data = cls._get_profile_data(
raw_profile, profile_name, target_name
)
raw_profile_data = cls._get_profile_data(raw_profile, profile_name, target_name)

try:
profile_data = renderer.render_data(raw_profile_data)
@@ -298,7 +298,7 @@ class Profile(HasCredentials):
user_cfg: Optional[Dict[str, Any]] = None,
target_override: Optional[str] = None,
threads_override: Optional[int] = None,
) -> 'Profile':
) -> "Profile":
"""Create a profile from its raw profile information.

(this is an intermediate step, mostly useful for unit testing)
@@ -319,7 +319,7 @@ class Profile(HasCredentials):
"""
# user_cfg is not rendered.
if user_cfg is None:
user_cfg = raw_profile.get('config')
user_cfg = raw_profile.get("config")
# TODO: should it be, and the values coerced to bool?
target_name, profile_data = cls.render_profile(
raw_profile, profile_name, target_override, renderer
@@ -327,7 +327,7 @@ class Profile(HasCredentials):

# valid connections never include the number of threads, but it's
# stored on a per-connection level in the raw configs
threads = profile_data.pop('threads', DEFAULT_THREADS)
threads = profile_data.pop("threads", DEFAULT_THREADS)
if threads_override is not None:
threads = threads_override

@@ -340,7 +340,7 @@ class Profile(HasCredentials):
profile_name=profile_name,
target_name=target_name,
threads=threads,
user_cfg=user_cfg
user_cfg=user_cfg,
)

@classmethod
@@ -351,7 +351,7 @@ class Profile(HasCredentials):
renderer: ProfileRenderer,
target_override: Optional[str] = None,
threads_override: Optional[int] = None,
) -> 'Profile':
) -> "Profile":
"""
:param raw_profiles: The profile data, from disk as yaml.
:param profile_name: The profile name to use.
@@ -375,15 +375,9 @@ class Profile(HasCredentials):
# don't render keys, so we can pluck that out
raw_profile = raw_profiles[profile_name]
if not raw_profile:
msg = (
f'Profile {profile_name} in profiles.yml is empty'
)
raise DbtProfileError(
INVALID_PROFILE_MESSAGE.format(
error_string=msg
)
)
user_cfg = raw_profiles.get('config')
msg = f"Profile {profile_name} in profiles.yml is empty"
raise DbtProfileError(INVALID_PROFILE_MESSAGE.format(error_string=msg))
user_cfg = raw_profiles.get("config")

return cls.from_raw_profile_info(
raw_profile=raw_profile,
@@ -400,7 +394,7 @@ class Profile(HasCredentials):
args: Any,
renderer: ProfileRenderer,
project_profile_name: Optional[str],
) -> 'Profile':
) -> "Profile":
"""Given the raw profiles as read from disk and the name of the desired
profile if specified, return the profile component of the runtime
config.
@@ -415,15 +409,16 @@ class Profile(HasCredentials):
target could not be found.
:returns Profile: The new Profile object.
"""
threads_override = getattr(args, 'threads', None)
target_override = getattr(args, 'target', None)
threads_override = getattr(args, "threads", None)
target_override = getattr(args, "target", None)
raw_profiles = read_profile(args.profiles_dir)
profile_name = cls.pick_profile_name(getattr(args, 'profile', None),
project_profile_name)
profile_name = cls.pick_profile_name(
getattr(args, "profile", None), project_profile_name
)
return cls.from_raw_profiles(
raw_profiles=raw_profiles,
profile_name=profile_name,
renderer=renderer,
target_override=target_override,
threads_override=threads_override
threads_override=threads_override,
)

@@ -2,7 +2,13 @@ from copy import deepcopy
from dataclasses import dataclass, field
from itertools import chain
from typing import (
List, Dict, Any, Optional, TypeVar, Union, Mapping,
List,
Dict,
Any,
Optional,
TypeVar,
Union,
Mapping,
)
from typing_extensions import Protocol, runtime_checkable

@@ -82,9 +88,7 @@ def _load_yaml(path):


def package_data_from_root(project_root):
package_filepath = resolve_path_from_base(
'packages.yml', project_root
)
package_filepath = resolve_path_from_base("packages.yml", project_root)

if path_exists(package_filepath):
packages_dict = _load_yaml(package_filepath)
@@ -95,7 +99,7 @@ def package_data_from_root(project_root):

def package_config_from_data(packages_data: Dict[str, Any]):
if not packages_data:
packages_data = {'packages': []}
packages_data = {"packages": []}

try:
PackageConfig.validate(packages_data)
@@ -118,7 +122,7 @@ def _parse_versions(versions: Union[List[str], str]) -> List[VersionSpecifier]:
Regardless, this will return a list of VersionSpecifiers
"""
if isinstance(versions, str):
versions = versions.split(',')
versions = versions.split(",")
return [VersionSpecifier.from_version_string(v) for v in versions]


@@ -129,11 +133,12 @@ def _all_source_paths(
analysis_paths: List[str],
macro_paths: List[str],
) -> List[str]:
return list(chain(source_paths, data_paths, snapshot_paths, analysis_paths,
macro_paths))
return list(
chain(source_paths, data_paths, snapshot_paths, analysis_paths, macro_paths)
)


T = TypeVar('T')
T = TypeVar("T")


def value_or(value: Optional[T], default: T) -> T:
@@ -146,21 +151,18 @@ def value_or(value: Optional[T], default: T) -> T:
def _raw_project_from(project_root: str) -> Dict[str, Any]:

project_root = os.path.normpath(project_root)
project_yaml_filepath = os.path.join(project_root, 'dbt_project.yml')
project_yaml_filepath = os.path.join(project_root, "dbt_project.yml")

# get the project.yml contents
if not path_exists(project_yaml_filepath):
raise DbtProjectError(
'no dbt_project.yml found at expected path {}'
.format(project_yaml_filepath)
"no dbt_project.yml found at expected path {}".format(project_yaml_filepath)
)

project_dict = _load_yaml(project_yaml_filepath)

if not isinstance(project_dict, dict):
raise DbtProjectError(
'dbt_project.yml does not parse to a dictionary'
)
raise DbtProjectError("dbt_project.yml does not parse to a dictionary")

return project_dict

@@ -169,7 +171,7 @@ def _query_comment_from_cfg(
cfg_query_comment: Union[QueryComment, NoValue, str, None]
) -> QueryComment:
if not cfg_query_comment:
return QueryComment(comment='')
return QueryComment(comment="")

if isinstance(cfg_query_comment, str):
return QueryComment(comment=cfg_query_comment)
@@ -186,9 +188,7 @@ def validate_version(dbt_version: List[VersionSpecifier], project_name: str):
if not versions_compatible(*dbt_version):
msg = IMPOSSIBLE_VERSION_ERROR.format(
package=project_name,
version_spec=[
x.to_version_string() for x in dbt_version
]
version_spec=[x.to_version_string() for x in dbt_version],
)
raise DbtProjectError(msg)

@@ -196,9 +196,7 @@ def validate_version(dbt_version: List[VersionSpecifier], project_name: str):
msg = INVALID_VERSION_ERROR.format(
package=project_name,
installed=installed.to_version_string(),
version_spec=[
x.to_version_string() for x in dbt_version
]
version_spec=[x.to_version_string() for x in dbt_version],
)
raise DbtProjectError(msg)

@@ -207,8 +205,8 @@ def _get_required_version(
project_dict: Dict[str, Any],
verify_version: bool,
) -> List[VersionSpecifier]:
dbt_raw_version: Union[List[str], str] = '>=0.0.0'
required = project_dict.get('require-dbt-version')
dbt_raw_version: Union[List[str], str] = ">=0.0.0"
required = project_dict.get("require-dbt-version")
if required is not None:
dbt_raw_version = required

@@ -219,11 +217,11 @@ def _get_required_version(

if verify_version:
# no name is also an error that we want to raise
if 'name' not in project_dict:
if "name" not in project_dict:
raise DbtProjectError(
'Required "name" field not present in project',
)
validate_version(dbt_version, project_dict['name'])
validate_version(dbt_version, project_dict["name"])

return dbt_version

@@ -231,34 +229,36 @@ def _get_required_version(
@dataclass
class RenderComponents:
project_dict: Dict[str, Any] = field(
metadata=dict(description='The project dictionary')
metadata=dict(description="The project dictionary")
)
packages_dict: Dict[str, Any] = field(
metadata=dict(description='The packages dictionary')
metadata=dict(description="The packages dictionary")
)
selectors_dict: Dict[str, Any] = field(
metadata=dict(description='The selectors dictionary')
metadata=dict(description="The selectors dictionary")
)


@dataclass
class PartialProject(RenderComponents):
profile_name: Optional[str] = field(metadata=dict(
description='The unrendered profile name in the project, if set'
))
project_name: Optional[str] = field(metadata=dict(
description=(
'The name of the project. This should always be set and will not '
'be rendered'
profile_name: Optional[str] = field(
metadata=dict(description="The unrendered profile name in the project, if set")
)
project_name: Optional[str] = field(
metadata=dict(
description=(
"The name of the project. This should always be set and will not "
"be rendered"
)
)
)
))
project_root: str = field(
metadata=dict(description='The root directory of the project'),
metadata=dict(description="The root directory of the project"),
)
verify_version: bool = field(
metadata=dict(description=(
'If True, verify the dbt version matches the required version'
))
metadata=dict(
description=("If True, verify the dbt version matches the required version")
)
)

def render_profile_name(self, renderer) -> Optional[str]:
@@ -271,9 +271,7 @@ class PartialProject(RenderComponents):
renderer: DbtProjectYamlRenderer,
) -> RenderComponents:

rendered_project = renderer.render_project(
self.project_dict, self.project_root
)
rendered_project = renderer.render_project(self.project_dict, self.project_root)
rendered_packages = renderer.render_packages(self.packages_dict)
rendered_selectors = renderer.render_selectors(self.selectors_dict)

@@ -283,16 +281,16 @@ class PartialProject(RenderComponents):
selectors_dict=rendered_selectors,
)

def render(self, renderer: DbtProjectYamlRenderer) -> 'Project':
def render(self, renderer: DbtProjectYamlRenderer) -> "Project":
try:
rendered = self.get_rendered(renderer)
return self.create_project(rendered)
except DbtProjectError as exc:
if exc.path is None:
exc.path = os.path.join(self.project_root, 'dbt_project.yml')
exc.path = os.path.join(self.project_root, "dbt_project.yml")
raise

def create_project(self, rendered: RenderComponents) -> 'Project':
def create_project(self, rendered: RenderComponents) -> "Project":
unrendered = RenderComponents(
project_dict=self.project_dict,
packages_dict=self.packages_dict,
@@ -305,9 +303,7 @@ class PartialProject(RenderComponents):

try:
ProjectContract.validate(rendered.project_dict)
cfg = ProjectContract.from_dict(
rendered.project_dict
)
cfg = ProjectContract.from_dict(rendered.project_dict)
except ValidationError as e:
raise DbtProjectError(validator_error_message(e)) from e
# name/version are required in the Project definition, so we can assume
@@ -317,31 +313,30 @@ class PartialProject(RenderComponents):
# this is added at project_dict parse time and should always be here
# once we see it.
if cfg.project_root is None:
raise DbtProjectError('cfg must have a project root!')
raise DbtProjectError("cfg must have a project root!")
else:
project_root = cfg.project_root
# this is only optional in the sense that if it's not present, it needs
# to have been a cli argument.
profile_name = cfg.profile
# these are all the defaults
source_paths: List[str] = value_or(cfg.source_paths, ['models'])
macro_paths: List[str] = value_or(cfg.macro_paths, ['macros'])
data_paths: List[str] = value_or(cfg.data_paths, ['data'])
test_paths: List[str] = value_or(cfg.test_paths, ['test'])
source_paths: List[str] = value_or(cfg.source_paths, ["models"])
macro_paths: List[str] = value_or(cfg.macro_paths, ["macros"])
data_paths: List[str] = value_or(cfg.data_paths, ["data"])
test_paths: List[str] = value_or(cfg.test_paths, ["test"])
analysis_paths: List[str] = value_or(cfg.analysis_paths, [])
snapshot_paths: List[str] = value_or(cfg.snapshot_paths, ['snapshots'])
snapshot_paths: List[str] = value_or(cfg.snapshot_paths, ["snapshots"])

all_source_paths: List[str] = _all_source_paths(
source_paths, data_paths, snapshot_paths, analysis_paths,
macro_paths
source_paths, data_paths, snapshot_paths, analysis_paths, macro_paths
)

docs_paths: List[str] = value_or(cfg.docs_paths, all_source_paths)
asset_paths: List[str] = value_or(cfg.asset_paths, [])
target_path: str = value_or(cfg.target_path, 'target')
target_path: str = value_or(cfg.target_path, "target")
clean_targets: List[str] = value_or(cfg.clean_targets, [target_path])
log_path: str = value_or(cfg.log_path, 'logs')
modules_path: str = value_or(cfg.modules_path, 'dbt_modules')
log_path: str = value_or(cfg.log_path, "logs")
modules_path: str = value_or(cfg.modules_path, "dbt_modules")
# in the default case we'll populate this once we know the adapter type
# It would be nice to just pass along a Quoting here, but that would
# break many things
@@ -373,11 +368,12 @@ class PartialProject(RenderComponents):
packages = package_config_from_data(rendered.packages_dict)
selectors = selector_config_from_data(rendered.selectors_dict)
manifest_selectors: Dict[str, Any] = {}
if rendered.selectors_dict and rendered.selectors_dict['selectors']:
if rendered.selectors_dict and rendered.selectors_dict["selectors"]:
# this is a dict with a single key 'selectors' pointing to a list
# of dicts.
manifest_selectors = SelectorDict.parse_from_selectors_list(
rendered.selectors_dict['selectors'])
rendered.selectors_dict["selectors"]
)

project = Project(
project_name=name,
@@ -426,10 +422,9 @@ class PartialProject(RenderComponents):
*,
verify_version: bool = False,
):
"""Construct a partial project from its constituent dicts.
"""
project_name = project_dict.get('name')
profile_name = project_dict.get('profile')
"""Construct a partial project from its constituent dicts."""
project_name = project_dict.get("name")
profile_name = project_dict.get("profile")

return cls(
profile_name=profile_name,
@@ -444,14 +439,14 @@ class PartialProject(RenderComponents):
@classmethod
def from_project_root(
cls, project_root: str, *, verify_version: bool = False
) -> 'PartialProject':
) -> "PartialProject":
project_root = os.path.normpath(project_root)
project_dict = _raw_project_from(project_root)
config_version = project_dict.get('config-version', 1)
config_version = project_dict.get("config-version", 1)
if config_version != 2:
raise DbtProjectError(
f'Invalid config version: {config_version}, expected 2',
path=os.path.join(project_root, 'dbt_project.yml')
f"Invalid config version: {config_version}, expected 2",
path=os.path.join(project_root, "dbt_project.yml"),
)

packages_dict = package_data_from_root(project_root)
@@ -468,15 +463,10 @@ class PartialProject(RenderComponents):
class VarProvider:
"""Var providers are tied to a particular Project."""

def __init__(
self,
vars: Dict[str, Dict[str, Any]]
) -> None:
def __init__(self, vars: Dict[str, Dict[str, Any]]) -> None:
self.vars = vars

def vars_for(
self, node: IsFQNResource, adapter_type: str
) -> Mapping[str, Any]:
def vars_for(self, node: IsFQNResource, adapter_type: str) -> Mapping[str, Any]:
# in v2, vars are only either project or globally scoped
merged = MultiDict([self.vars])
merged.add(self.vars.get(node.package_name, {}))
@@ -525,8 +515,11 @@ class Project:
@property
def all_source_paths(self) -> List[str]:
return _all_source_paths(
self.source_paths, self.data_paths, self.snapshot_paths,
self.analysis_paths, self.macro_paths
self.source_paths,
self.data_paths,
self.snapshot_paths,
self.analysis_paths,
self.macro_paths,
)

def __str__(self):
@@ -534,11 +527,13 @@ class Project:
return str(cfg)

def __eq__(self, other):
if not (isinstance(other, self.__class__) and
isinstance(self, other.__class__)):
if not (
isinstance(other, self.__class__) and isinstance(self, other.__class__)
):
return False
return self.to_project_config(with_packages=True) == \
other.to_project_config(with_packages=True)
return self.to_project_config(with_packages=True) == other.to_project_config(
with_packages=True
)

def to_project_config(self, with_packages=False):
"""Return a dict representation of the config that could be written to
@@ -548,38 +543,39 @@ class Project:
file in the root.
:returns dict: The serialized profile.
"""
result = deepcopy({
'name': self.project_name,
'version': self.version,
'project-root': self.project_root,
'profile': self.profile_name,
'source-paths': self.source_paths,
'macro-paths': self.macro_paths,
'data-paths': self.data_paths,
'test-paths': self.test_paths,
'analysis-paths': self.analysis_paths,
'docs-paths': self.docs_paths,
'asset-paths': self.asset_paths,
'target-path': self.target_path,
'snapshot-paths': self.snapshot_paths,
'clean-targets': self.clean_targets,
'log-path': self.log_path,
'quoting': self.quoting,
'models': self.models,
'on-run-start': self.on_run_start,
'on-run-end': self.on_run_end,
'seeds': self.seeds,
'snapshots': self.snapshots,
'sources': self.sources,
'vars': self.vars.to_dict(),
'require-dbt-version': [
result = deepcopy(
{
"name": self.project_name,
"version": self.version,
"project-root": self.project_root,
"profile": self.profile_name,
"source-paths": self.source_paths,
"macro-paths": self.macro_paths,
"data-paths": self.data_paths,
"test-paths": self.test_paths,
"analysis-paths": self.analysis_paths,
"docs-paths": self.docs_paths,
"asset-paths": self.asset_paths,
"target-path": self.target_path,
"snapshot-paths": self.snapshot_paths,
"clean-targets": self.clean_targets,
"log-path": self.log_path,
"quoting": self.quoting,
"models": self.models,
"on-run-start": self.on_run_start,
"on-run-end": self.on_run_end,
"seeds": self.seeds,
"snapshots": self.snapshots,
"sources": self.sources,
"vars": self.vars.to_dict(),
"require-dbt-version": [
v.to_version_string() for v in self.dbt_version
],
'config-version': self.config_version,
})
"config-version": self.config_version,
}
)
if self.query_comment:
result['query-comment'] = \
self.query_comment.to_dict(omit_none=True)
result["query-comment"] = self.query_comment.to_dict(omit_none=True)

if with_packages:
result.update(self.packages.to_dict(omit_none=True))
@@ -610,8 +606,8 @@ class Project:
selectors_dict: Dict[str, Any],
renderer: DbtProjectYamlRenderer,
*,
verify_version: bool = False
) -> 'Project':
verify_version: bool = False,
) -> "Project":
partial = PartialProject.from_dicts(
project_root=project_root,
project_dict=project_dict,
@@ -628,17 +624,17 @@ class Project:
renderer: DbtProjectYamlRenderer,
*,
verify_version: bool = False,
) -> 'Project':
) -> "Project":
partial = cls.partial_load(project_root, verify_version=verify_version)
return partial.render(renderer)

def hashed_name(self):
return hashlib.md5(self.project_name.encode('utf-8')).hexdigest()
return hashlib.md5(self.project_name.encode("utf-8")).hexdigest()

def get_selector(self, name: str) -> SelectionSpec:
if name not in self.selectors:
raise RuntimeException(
f'Could not find selector named {name}, expected one of '
f'{list(self.selectors)}'
f"Could not find selector named {name}, expected one of "
f"{list(self.selectors)}"
)
return self.selectors[name]

@@ -2,9 +2,7 @@ from typing import Dict, Any, Tuple, Optional, Union, Callable

from dbt.clients.jinja import get_rendered, catch_jinja

from dbt.exceptions import (
DbtProjectError, CompilationException, RecursionException
)
from dbt.exceptions import DbtProjectError, CompilationException, RecursionException
from dbt.node_types import NodeType
from dbt.utils import deep_map

@@ -18,7 +16,7 @@ class BaseRenderer:

@property
def name(self):
return 'Rendering'
return "Rendering"

def should_render_keypath(self, keypath: Keypath) -> bool:
return True
@@ -29,9 +27,7 @@ class BaseRenderer:

return self.render_value(value, keypath)

def render_value(
self, value: Any, keypath: Optional[Keypath] = None
) -> Any:
def render_value(self, value: Any, keypath: Optional[Keypath] = None) -> Any:
# keypath is ignored.
# if it wasn't read as a string, ignore it
if not isinstance(value, str):
@@ -40,18 +36,16 @@ class BaseRenderer:
with catch_jinja():
return get_rendered(value, self.context, native=True)
except CompilationException as exc:
msg = f'Could not render {value}: {exc.msg}'
msg = f"Could not render {value}: {exc.msg}"
raise CompilationException(msg) from exc

def render_data(
self, data: Dict[str, Any]
) -> Dict[str, Any]:
def render_data(self, data: Dict[str, Any]) -> Dict[str, Any]:
try:
return deep_map(self.render_entry, data)
except RecursionException:
raise DbtProjectError(
f'Cycle detected: {self.name} input has a reference to itself',
project=data
f"Cycle detected: {self.name} input has a reference to itself",
project=data,
)


@@ -78,15 +72,15 @@ class ProjectPostprocessor(Dict[Keypath, Callable[[Any], Any]]):
def __init__(self):
super().__init__()

self[('on-run-start',)] = _list_if_none_or_string
self[('on-run-end',)] = _list_if_none_or_string
self[("on-run-start",)] = _list_if_none_or_string
self[("on-run-end",)] = _list_if_none_or_string

for k in ('models', 'seeds', 'snapshots'):
for k in ("models", "seeds", "snapshots"):
self[(k,)] = _dict_if_none
self[(k, 'vars')] = _dict_if_none
self[(k, 'pre-hook')] = _list_if_none_or_string
self[(k, 'post-hook')] = _list_if_none_or_string
self[('seeds', 'column_types')] = _dict_if_none
self[(k, "vars")] = _dict_if_none
self[(k, "pre-hook")] = _list_if_none_or_string
self[(k, "post-hook")] = _list_if_none_or_string
self[("seeds", "column_types")] = _dict_if_none

def postprocess(self, value: Any, key: Keypath) -> Any:
if key in self:
@@ -101,7 +95,7 @@ class DbtProjectYamlRenderer(BaseRenderer):

@property
def name(self):
'Project config'
"Project config"

def get_package_renderer(self) -> BaseRenderer:
return PackageRenderer(self.context)
@@ -116,7 +110,7 @@ class DbtProjectYamlRenderer(BaseRenderer):
) -> Dict[str, Any]:
"""Render the project and insert the project root after rendering."""
rendered_project = self.render_data(project)
rendered_project['project-root'] = project_root
rendered_project["project-root"] = project_root
return rendered_project

def render_packages(self, packages: Dict[str, Any]):
@@ -138,20 +132,19 @@ class DbtProjectYamlRenderer(BaseRenderer):

first = keypath[0]
# run hooks are not rendered
if first in {'on-run-start', 'on-run-end', 'query-comment'}:
if first in {"on-run-start", "on-run-end", "query-comment"}:
return False

# don't render vars blocks until runtime
if first == 'vars':
if first == "vars":
return False

if first in {'seeds', 'models', 'snapshots', 'seeds'}:
if first in {"seeds", "models", "snapshots", "seeds"}:
keypath_parts = {
(k.lstrip('+') if isinstance(k, str) else k)
for k in keypath
(k.lstrip("+") if isinstance(k, str) else k) for k in keypath
}
# model-level hooks
if 'pre-hook' in keypath_parts or 'post-hook' in keypath_parts:
if "pre-hook" in keypath_parts or "post-hook" in keypath_parts:
return False

return True
@@ -160,17 +153,15 @@ class DbtProjectYamlRenderer(BaseRenderer):
class ProfileRenderer(BaseRenderer):
@property
def name(self):
'Profile'
"Profile"


class SchemaYamlRenderer(BaseRenderer):
DOCUMENTABLE_NODES = frozenset(
n.pluralize() for n in NodeType.documentable()
)
DOCUMENTABLE_NODES = frozenset(n.pluralize() for n in NodeType.documentable())

@property
def name(self):
return 'Rendering yaml'
return "Rendering yaml"

def _is_norender_key(self, keypath: Keypath) -> bool:
"""
@@ -185,13 +176,13 @@ class SchemaYamlRenderer(BaseRenderer):

Return True if it's tests or description - those aren't rendered
"""
if len(keypath) >= 2 and keypath[1] in ('tests', 'description'):
if len(keypath) >= 2 and keypath[1] in ("tests", "description"):
return True

if (
len(keypath) >= 4 and
keypath[1] == 'columns' and
keypath[3] in ('tests', 'description')
len(keypath) >= 4
and keypath[1] == "columns"
and keypath[3] in ("tests", "description")
):
return True

@@ -209,13 +200,13 @@ class SchemaYamlRenderer(BaseRenderer):
return True

if keypath[0] == NodeType.Source.pluralize():
if keypath[2] == 'description':
if keypath[2] == "description":
return False
if keypath[2] == 'tables':
if keypath[2] == "tables":
if self._is_norender_key(keypath[3:]):
return False
elif keypath[0] == NodeType.Macro.pluralize():
if keypath[2] == 'arguments':
if keypath[2] == "arguments":
if self._is_norender_key(keypath[3:]):
return False
elif self._is_norender_key(keypath[1:]):
@@ -229,10 +220,10 @@ class SchemaYamlRenderer(BaseRenderer):
class PackageRenderer(BaseRenderer):
@property
def name(self):
return 'Packages config'
return "Packages config"


class SelectorRenderer(BaseRenderer):
@property
def name(self):
return 'Selector config'
return "Selector config"

@@ -4,8 +4,16 @@ from copy import deepcopy
from dataclasses import dataclass, fields
from pathlib import Path
from typing import (
Dict, Any, Optional, Mapping, Iterator, Iterable, Tuple, List, MutableSet,
Type
Dict,
Any,
Optional,
Mapping,
Iterator,
Iterable,
Tuple,
List,
MutableSet,
Type,
)

from .profile import Profile
@@ -15,7 +23,7 @@ from .utils import parse_cli_vars
from dbt import tracking
from dbt.adapters.factory import get_relation_class_by_name, get_include_paths
from dbt.helper_types import FQNPath, PathSet
from dbt.context.base import generate_base_context
from dbt.context import generate_base_context
from dbt.context.target import generate_target_context
from dbt.contracts.connection import AdapterRequiredConfig, Credentials
from dbt.contracts.graph.manifest import ManifestMetadata
@@ -30,15 +38,13 @@ from dbt.exceptions import (
DbtProjectError,
validator_error_message,
warn_or_error,
raise_compiler_error
raise_compiler_error,
)

from dbt.dataclass_schema import ValidationError


def _project_quoting_dict(
proj: Project, profile: Profile
) -> Dict[ComponentName, bool]:
def _project_quoting_dict(proj: Project, profile: Profile) -> Dict[ComponentName, bool]:
src: Dict[str, Any] = profile.credentials.translate_aliases(proj.quoting)
result: Dict[ComponentName, bool] = {}
for key in ComponentName:
@@ -54,7 +60,7 @@ class RuntimeConfig(Project, Profile, AdapterRequiredConfig):
args: Any
profile_name: str
cli_vars: Dict[str, Any]
dependencies: Optional[Mapping[str, 'RuntimeConfig']] = None
dependencies: Optional[Mapping[str, "RuntimeConfig"]] = None

def __post_init__(self):
self.validate()
@@ -65,8 +71,8 @@ class RuntimeConfig(Project, Profile, AdapterRequiredConfig):
project: Project,
profile: Profile,
args: Any,
dependencies: Optional[Mapping[str, 'RuntimeConfig']] = None,
) -> 'RuntimeConfig':
dependencies: Optional[Mapping[str, "RuntimeConfig"]] = None,
) -> "RuntimeConfig":
"""Instantiate a RuntimeConfig from its components.

:param profile: A parsed dbt Profile.
@@ -80,7 +86,7 @@ class RuntimeConfig(Project, Profile, AdapterRequiredConfig):
.replace_dict(_project_quoting_dict(project, profile))
).to_dict(omit_none=True)

cli_vars: Dict[str, Any] = parse_cli_vars(getattr(args, 'vars', '{}'))
cli_vars: Dict[str, Any] = parse_cli_vars(getattr(args, "vars", "{}"))

return cls(
project_name=project.project_name,
@@ -123,7 +129,7 @@ class RuntimeConfig(Project, Profile, AdapterRequiredConfig):
dependencies=dependencies,
)

def new_project(self, project_root: str) -> 'RuntimeConfig':
def new_project(self, project_root: str) -> "RuntimeConfig":
"""Given a new project root, read in its project dictionary, supply the
existing project's profile info, and create a new project file.

@@ -142,7 +148,7 @@ class RuntimeConfig(Project, Profile, AdapterRequiredConfig):
project = Project.from_project_root(
project_root,
renderer,
verify_version=getattr(self.args, 'version_check', False),
verify_version=getattr(self.args, "version_check", False),
)

cfg = self.from_parts(
@@ -165,7 +171,7 @@ class RuntimeConfig(Project, Profile, AdapterRequiredConfig):
"""
result = self.to_project_config(with_packages=True)
result.update(self.to_profile_info(serialize_credentials=True))
result['cli_vars'] = deepcopy(self.cli_vars)
result["cli_vars"] = deepcopy(self.cli_vars)
return result

def validate(self):
@@ -185,30 +191,21 @@ class RuntimeConfig(Project, Profile, AdapterRequiredConfig):
profile_renderer: ProfileRenderer,
profile_name: Optional[str],
) -> Profile:
return Profile.render_from_args(
args, profile_renderer, profile_name
)
return Profile.render_from_args(args, profile_renderer, profile_name)

@classmethod
def collect_parts(
cls: Type['RuntimeConfig'], args: Any
) -> Tuple[Project, Profile]:
def collect_parts(cls: Type["RuntimeConfig"], args: Any) -> Tuple[Project, Profile]:
# profile_name from the project
project_root = args.project_dir if args.project_dir else os.getcwd()
version_check = getattr(args, 'version_check', False)
partial = Project.partial_load(
project_root,
verify_version=version_check
)
version_check = getattr(args, "version_check", False)
partial = Project.partial_load(project_root, verify_version=version_check)

# build the profile using the base renderer and the one fact we know
cli_vars: Dict[str, Any] = parse_cli_vars(getattr(args, 'vars', '{}'))
cli_vars: Dict[str, Any] = parse_cli_vars(getattr(args, "vars", "{}"))
profile_renderer = ProfileRenderer(generate_base_context(cli_vars))
profile_name = partial.render_profile_name(profile_renderer)

profile = cls._get_rendered_profile(
args, profile_renderer, profile_name
)
profile = cls._get_rendered_profile(args, profile_renderer, profile_name)

# get a new renderer using our target information and render the
# project
@@ -218,7 +215,7 @@ class RuntimeConfig(Project, Profile, AdapterRequiredConfig):
return (project, profile)

@classmethod
def from_args(cls, args: Any) -> 'RuntimeConfig':
def from_args(cls, args: Any) -> "RuntimeConfig":
"""Given arguments, read in dbt_project.yml from the current directory,
read in packages.yml if it exists, and use them to find the profile to
load.
@@ -238,8 +235,7 @@ class RuntimeConfig(Project, Profile, AdapterRequiredConfig):

def get_metadata(self) -> ManifestMetadata:
return ManifestMetadata(
project_id=self.hashed_name(),
adapter_type=self.credentials.type
project_id=self.hashed_name(), adapter_type=self.credentials.type
)

def _get_v2_config_paths(
@@ -249,7 +245,7 @@ class RuntimeConfig(Project, Profile, AdapterRequiredConfig):
paths: MutableSet[FQNPath],
) -> PathSet:
for key, value in config.items():
if isinstance(value, dict) and not key.startswith('+'):
if isinstance(value, dict) and not key.startswith("+"):
self._get_v2_config_paths(value, path + (key,), paths)
else:
paths.add(path)
@@ -265,7 +261,7 @@ class RuntimeConfig(Project, Profile, AdapterRequiredConfig):
paths = set()

for key, value in config.items():
if isinstance(value, dict) and not key.startswith('+'):
if isinstance(value, dict) and not key.startswith("+"):
self._get_v2_config_paths(value, path + (key,), paths)
else:
paths.add(path)
@@ -277,10 +273,10 @@ class RuntimeConfig(Project, Profile, AdapterRequiredConfig):
a configured path in the resource.
"""
return {
'models': self._get_config_paths(self.models),
'seeds': self._get_config_paths(self.seeds),
'snapshots': self._get_config_paths(self.snapshots),
'sources': self._get_config_paths(self.sources),
"models": self._get_config_paths(self.models),
"seeds": self._get_config_paths(self.seeds),
"snapshots": self._get_config_paths(self.snapshots),
"sources": self._get_config_paths(self.sources),
}

def get_unused_resource_config_paths(
@@ -301,9 +297,7 @@ class RuntimeConfig(Project, Profile, AdapterRequiredConfig):

for config_path in config_paths:
if not _is_config_used(config_path, fqns):
unused_resource_config_paths.append(
(resource_type,) + config_path
)
unused_resource_config_paths.append((resource_type,) + config_path)
return unused_resource_config_paths

def warn_for_unused_resource_config_paths(
@@ -316,27 +310,25 @@ class RuntimeConfig(Project, Profile, AdapterRequiredConfig):
return

msg = UNUSED_RESOURCE_CONFIGURATION_PATH_MESSAGE.format(
len(unused),
'\n'.join('- {}'.format('.'.join(u)) for u in unused)
len(unused), "\n".join("- {}".format(".".join(u)) for u in unused)
)

warn_or_error(msg, log_fmt=warning_tag('{}'))
warn_or_error(msg, log_fmt=warning_tag("{}"))

def load_dependencies(self) -> Mapping[str, 'RuntimeConfig']:
def load_dependencies(self) -> Mapping[str, "RuntimeConfig"]:
if self.dependencies is None:
all_projects = {self.project_name: self}
internal_packages = get_include_paths(self.credentials.type)
project_paths = itertools.chain(
internal_packages,
self._get_project_directories()
internal_packages, self._get_project_directories()
)
for project_name, project in self.load_projects(project_paths):
if project_name in all_projects:
raise_compiler_error(
f'dbt found more than one package with the name '
f"dbt found more than one package with the name "
f'"{project_name}" included in this project. Package '
f'names must be unique in a project. Please rename '
f'one of these packages.'
f"names must be unique in a project. Please rename "
f"one of these packages."
)
all_projects[project_name] = project
self.dependencies = all_projects
@@ -347,14 +339,14 @@ class RuntimeConfig(Project, Profile, AdapterRequiredConfig):

def load_projects(
self, paths: Iterable[Path]
) -> Iterator[Tuple[str, 'RuntimeConfig']]:
) -> Iterator[Tuple[str, "RuntimeConfig"]]:
for path in paths:
try:
project = self.new_project(str(path))
except DbtProjectError as e:
raise DbtProjectError(
f'Failed to read package: {e}',
result_type='invalid_project',
f"Failed to read package: {e}",
result_type="invalid_project",
path=path,
) from e
else:
@@ -365,13 +357,13 @@ class RuntimeConfig(Project, Profile, AdapterRequiredConfig):

if root.exists():
for path in root.iterdir():
if path.is_dir() and not path.name.startswith('__'):
if path.is_dir() and not path.name.startswith("__"):
yield path


class UnsetCredentials(Credentials):
def __init__(self):
super().__init__('', '')
super().__init__("", "")

@property
def type(self):
@@ -387,9 +379,7 @@ class UnsetCredentials(Credentials):
class UnsetConfig(UserConfig):
def __getattribute__(self, name):
if name in {f.name for f in fields(UserConfig)}:
raise AttributeError(
f"'UnsetConfig' object has no attribute {name}"
)
raise AttributeError(f"'UnsetConfig' object has no attribute {name}")

def __post_serialize__(self, dct):
return {}
@@ -399,15 +389,15 @@ class UnsetProfile(Profile):
def __init__(self):
self.credentials = UnsetCredentials()
self.config = UnsetConfig()
self.profile_name = ''
self.target_name = ''
self.profile_name = ""
self.target_name = ""
self.threads = -1

def to_target_dict(self):
return {}

def __getattribute__(self, name):
if name in {'profile_name', 'target_name', 'threads'}:
if name in {"profile_name", "target_name", "threads"}:
raise RuntimeException(
f'Error: disallowed attribute "{name}" - no profile!'
)
@@ -431,7 +421,7 @@ class UnsetProfileConfig(RuntimeConfig):

def __getattribute__(self, name):
# Override __getattribute__ to check that the attribute isn't 'banned'.
if name in {'profile_name', 'target_name'}:
if name in {"profile_name", "target_name"}:
raise RuntimeException(
f'Error: disallowed attribute "{name}" - no profile!'
)
@@ -449,8 +439,8 @@ class UnsetProfileConfig(RuntimeConfig):
project: Project,
profile: Profile,
args: Any,
dependencies: Optional[Mapping[str, 'RuntimeConfig']] = None,
) -> 'RuntimeConfig':
dependencies: Optional[Mapping[str, "RuntimeConfig"]] = None,
) -> "RuntimeConfig":
"""Instantiate a RuntimeConfig from its components.

:param profile: Ignored.
@@ -458,7 +448,7 @@ class UnsetProfileConfig(RuntimeConfig):
:param args: The parsed command-line arguments.
:returns RuntimeConfig: The new configuration.
"""
cli_vars: Dict[str, Any] = parse_cli_vars(getattr(args, 'vars', '{}'))
cli_vars: Dict[str, Any] = parse_cli_vars(getattr(args, "vars", "{}"))

return cls(
project_name=project.project_name,
@@ -491,10 +481,10 @@ class UnsetProfileConfig(RuntimeConfig):
vars=project.vars,
config_version=project.config_version,
unrendered=project.unrendered,
profile_name='',
target_name='',
profile_name="",
target_name="",
config=UnsetConfig(),
threads=getattr(args, 'threads', 1),
threads=getattr(args, "threads", 1),
credentials=UnsetCredentials(),
args=args,
cli_vars=cli_vars,
@@ -509,16 +499,11 @@ class UnsetProfileConfig(RuntimeConfig):
profile_name: Optional[str],
) -> Profile:
try:
profile = Profile.render_from_args(
args, profile_renderer, profile_name
)
profile = Profile.render_from_args(args, profile_renderer, profile_name)
except (DbtProjectError, DbtProfileError) as exc:
logger.debug(
'Profile not loaded due to error: {}', exc, exc_info=True
)
logger.debug("Profile not loaded due to error: {}", exc, exc_info=True)
logger.info(
'No profile "{}" found, continuing with no target',
profile_name
'No profile "{}" found, continuing with no target', profile_name
)
# return the poisoned form
profile = UnsetProfile()
@@ -527,7 +512,7 @@ class UnsetProfileConfig(RuntimeConfig):
return profile

@classmethod
def from_args(cls: Type[RuntimeConfig], args: Any) -> 'RuntimeConfig':
def from_args(cls: Type[RuntimeConfig], args: Any) -> "RuntimeConfig":
"""Given arguments, read in dbt_project.yml from the current directory,
read in packages.yml if it exists, and use them to find the profile to
load.
@@ -542,11 +527,7 @@ class UnsetProfileConfig(RuntimeConfig):
# if it's a real profile, return a real config
cls = RuntimeConfig

return cls.from_parts(
project=project,
profile=profile,
args=args
)
return cls.from_parts(project=project, profile=profile, args=args)


UNUSED_RESOURCE_CONFIGURATION_PATH_MESSAGE = """\

@@ -1,8 +1,6 @@
from pathlib import Path
from typing import Dict, Any
from dbt.clients.yaml_helper import ( # noqa: F401
yaml, Loader, Dumper, load_yaml_text
)
from dbt.clients.yaml_helper import yaml, Loader, Dumper, load_yaml_text # noqa: F401
from dbt.dataclass_schema import ValidationError

from .renderer import SelectorRenderer
@@ -30,9 +28,8 @@ Validator Error:


class SelectorConfig(Dict[str, SelectionSpec]):

@classmethod
def selectors_from_dict(cls, data: Dict[str, Any]) -> 'SelectorConfig':
def selectors_from_dict(cls, data: Dict[str, Any]) -> "SelectorConfig":
try:
SelectorFile.validate(data)
selector_file = SelectorFile.from_dict(data)
@@ -45,12 +42,12 @@ class SelectorConfig(Dict[str, SelectionSpec]):
f"union, intersection, string, dictionary. No lists. "
f"\nhttps://docs.getdbt.com/reference/node-selection/"
f"yaml-selectors",
result_type='invalid_selector'
result_type="invalid_selector",
) from exc
except RuntimeException as exc:
raise DbtSelectorsError(
f'Could not read selector file data: {exc}',
result_type='invalid_selector',
f"Could not read selector file data: {exc}",
result_type="invalid_selector",
) from exc

return cls(selectors)
@@ -60,26 +57,28 @@ class SelectorConfig(Dict[str, SelectionSpec]):
cls,
|
||||
data: Dict[str, Any],
|
||||
renderer: SelectorRenderer,
|
||||
) -> 'SelectorConfig':
|
||||
) -> "SelectorConfig":
|
||||
try:
|
||||
rendered = renderer.render_data(data)
|
||||
except (ValidationError, RuntimeException) as exc:
|
||||
raise DbtSelectorsError(
|
||||
f'Could not render selector data: {exc}',
|
||||
result_type='invalid_selector',
|
||||
f"Could not render selector data: {exc}",
|
||||
result_type="invalid_selector",
|
||||
) from exc
|
||||
return cls.selectors_from_dict(rendered)
|
||||
|
||||
@classmethod
|
||||
def from_path(
|
||||
cls, path: Path, renderer: SelectorRenderer,
|
||||
) -> 'SelectorConfig':
|
||||
cls,
|
||||
path: Path,
|
||||
renderer: SelectorRenderer,
|
||||
) -> "SelectorConfig":
|
||||
try:
|
||||
data = load_yaml_text(load_file_contents(str(path)))
|
||||
except (ValidationError, RuntimeException) as exc:
|
||||
raise DbtSelectorsError(
|
||||
f'Could not read selector file: {exc}',
|
||||
result_type='invalid_selector',
|
||||
f"Could not read selector file: {exc}",
|
||||
result_type="invalid_selector",
|
||||
path=path,
|
||||
) from exc
|
||||
|
||||
@@ -91,9 +90,7 @@ class SelectorConfig(Dict[str, SelectionSpec]):
|
||||
|
||||
|
||||
def selector_data_from_root(project_root: str) -> Dict[str, Any]:
|
||||
selector_filepath = resolve_path_from_base(
|
||||
'selectors.yml', project_root
|
||||
)
|
||||
selector_filepath = resolve_path_from_base("selectors.yml", project_root)
|
||||
|
||||
if path_exists(selector_filepath):
|
||||
selectors_dict = load_yaml_text(load_file_contents(selector_filepath))
|
||||
@@ -102,18 +99,16 @@ def selector_data_from_root(project_root: str) -> Dict[str, Any]:
|
||||
return selectors_dict
|
||||
|
||||
|
||||
def selector_config_from_data(
|
||||
selectors_data: Dict[str, Any]
|
||||
) -> SelectorConfig:
|
||||
def selector_config_from_data(selectors_data: Dict[str, Any]) -> SelectorConfig:
|
||||
if not selectors_data:
|
||||
selectors_data = {'selectors': []}
|
||||
selectors_data = {"selectors": []}
|
||||
|
||||
try:
|
||||
selectors = SelectorConfig.selectors_from_dict(selectors_data)
|
||||
except ValidationError as e:
|
||||
raise DbtSelectorsError(
|
||||
MALFORMED_SELECTOR_ERROR.format(error=str(e.message)),
|
||||
result_type='invalid_selector',
|
||||
result_type="invalid_selector",
|
||||
) from e
|
||||
return selectors
|
||||
|
||||
@@ -125,7 +120,6 @@ def selector_config_from_data(
|
||||
# be necessary to make changes here. Ideally it would be
|
||||
# good to combine the two flows into one at some point.
|
||||
class SelectorDict:
|
||||
|
||||
@classmethod
|
||||
def parse_dict_definition(cls, definition):
|
||||
key = list(definition)[0]
|
||||
@@ -136,10 +130,10 @@ class SelectorDict:
|
||||
new_value = cls.parse_from_definition(sel_def)
|
||||
new_values.append(new_value)
|
||||
value = new_values
|
||||
if key == 'exclude':
|
||||
if key == "exclude":
|
||||
definition = {key: value}
|
||||
elif len(definition) == 1:
|
||||
definition = {'method': key, 'value': value}
|
||||
definition = {"method": key, "value": value}
|
||||
return definition
|
||||
|
||||
@classmethod
|
||||
@@ -161,10 +155,10 @@ class SelectorDict:
|
||||
def parse_from_definition(cls, definition):
|
||||
if isinstance(definition, str):
|
||||
definition = SelectionCriteria.dict_from_single_spec(definition)
|
||||
elif 'union' in definition:
|
||||
definition = cls.parse_a_definition('union', definition)
|
||||
elif 'intersection' in definition:
|
||||
definition = cls.parse_a_definition('intersection', definition)
|
||||
elif "union" in definition:
|
||||
definition = cls.parse_a_definition("union", definition)
|
||||
elif "intersection" in definition:
|
||||
definition = cls.parse_a_definition("intersection", definition)
|
||||
elif isinstance(definition, dict):
|
||||
definition = cls.parse_dict_definition(definition)
|
||||
return definition
|
||||
@@ -175,8 +169,8 @@ class SelectorDict:
|
||||
def parse_from_selectors_list(cls, selectors):
|
||||
selector_dict = {}
|
||||
for selector in selectors:
|
||||
sel_name = selector['name']
|
||||
sel_name = selector["name"]
|
||||
selector_dict[sel_name] = selector
|
||||
definition = cls.parse_from_definition(selector['definition'])
|
||||
selector_dict[sel_name]['definition'] = definition
|
||||
definition = cls.parse_from_definition(selector["definition"])
|
||||
selector_dict[sel_name]["definition"] = definition
|
||||
return selector_dict
|
||||
|
||||
@@ -15,9 +15,8 @@ def parse_cli_vars(var_string: str) -> Dict[str, Any]:
type_name = var_type.__name__
raise_compiler_error(
"The --vars argument must be a YAML dictionary, but was "
"of type '{}'".format(type_name))
except ValidationException:
logger.error(
"The YAML provided in the --vars argument is not valid.\n"
"of type '{}'".format(type_name)
)
except ValidationException:
logger.error("The YAML provided in the --vars argument is not valid.\n")
raise

@@ -1,14 +1,16 @@
import json
import os
from typing import (
Any, Dict, NoReturn, Optional, Mapping
)
from typing import Any, Dict, NoReturn, Optional, Mapping

from dbt import flags
from dbt import tracking
from dbt.clients.jinja import undefined_error, get_rendered
from dbt.clients.yaml_helper import ( # noqa: F401
yaml, safe_load, SafeLoader, Loader, Dumper
yaml,
safe_load,
SafeLoader,
Loader,
Dumper,
)
from dbt.contracts.graph.compiled import CompiledResource
|
||||
from dbt.exceptions import raise_compiler_error, MacroReturn
|
||||
@@ -25,38 +27,26 @@ import re
|
||||
def get_pytz_module_context() -> Dict[str, Any]:
|
||||
context_exports = pytz.__all__ # type: ignore
|
||||
|
||||
return {
|
||||
name: getattr(pytz, name) for name in context_exports
|
||||
}
|
||||
return {name: getattr(pytz, name) for name in context_exports}
|
||||
|
||||
|
||||
def get_datetime_module_context() -> Dict[str, Any]:
|
||||
context_exports = [
|
||||
'date',
|
||||
'datetime',
|
||||
'time',
|
||||
'timedelta',
|
||||
'tzinfo'
|
||||
]
|
||||
context_exports = ["date", "datetime", "time", "timedelta", "tzinfo"]
|
||||
|
||||
return {
|
||||
name: getattr(datetime, name) for name in context_exports
|
||||
}
|
||||
return {name: getattr(datetime, name) for name in context_exports}
|
||||
|
||||
|
||||
def get_re_module_context() -> Dict[str, Any]:
|
||||
context_exports = re.__all__
|
||||
|
||||
return {
|
||||
name: getattr(re, name) for name in context_exports
|
||||
}
|
||||
return {name: getattr(re, name) for name in context_exports}
|
||||
|
||||
|
||||
def get_context_modules() -> Dict[str, Dict[str, Any]]:
|
||||
return {
|
||||
'pytz': get_pytz_module_context(),
|
||||
'datetime': get_datetime_module_context(),
|
||||
're': get_re_module_context(),
|
||||
"pytz": get_pytz_module_context(),
|
||||
"datetime": get_datetime_module_context(),
|
||||
"re": get_re_module_context(),
|
||||
}
|
||||
|
||||
|
||||
@@ -90,8 +80,8 @@ class ContextMeta(type):
|
||||
new_dct = {}
|
||||
|
||||
for base in bases:
|
||||
context_members.update(getattr(base, '_context_members_', {}))
|
||||
context_attrs.update(getattr(base, '_context_attrs_', {}))
|
||||
context_members.update(getattr(base, "_context_members_", {}))
|
||||
context_attrs.update(getattr(base, "_context_attrs_", {}))
|
||||
|
||||
for key, value in dct.items():
|
||||
if isinstance(value, ContextMember):
|
||||
@@ -100,21 +90,22 @@ class ContextMeta(type):
|
||||
context_attrs[context_key] = key
|
||||
value = value.inner
|
||||
new_dct[key] = value
|
||||
new_dct['_context_members_'] = context_members
|
||||
new_dct['_context_attrs_'] = context_attrs
|
||||
new_dct["_context_members_"] = context_members
|
||||
new_dct["_context_attrs_"] = context_attrs
|
||||
return type.__new__(mcls, name, bases, new_dct)
|
||||
|
||||
|
||||
class Var:
|
||||
UndefinedVarError = "Required var '{}' not found in config:\nVars "\
|
||||
"supplied to {} = {}"
|
||||
UndefinedVarError = (
|
||||
"Required var '{}' not found in config:\nVars " "supplied to {} = {}"
|
||||
)
|
||||
_VAR_NOTSET = object()
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
context: Mapping[str, Any],
|
||||
cli_vars: Mapping[str, Any],
|
||||
node: Optional[CompiledResource] = None
|
||||
node: Optional[CompiledResource] = None,
|
||||
) -> None:
|
||||
self._context: Mapping[str, Any] = context
|
||||
self._cli_vars: Mapping[str, Any] = cli_vars
|
||||
@@ -129,14 +120,12 @@ class Var:
|
||||
if self._node is not None:
|
||||
return self._node.name
|
||||
else:
|
||||
return '<Configuration>'
|
||||
return "<Configuration>"
|
||||
|
||||
def get_missing_var(self, var_name):
|
||||
dct = {k: self._merged[k] for k in self._merged}
|
||||
pretty_vars = json.dumps(dct, sort_keys=True, indent=4)
|
||||
msg = self.UndefinedVarError.format(
|
||||
var_name, self.node_name, pretty_vars
|
||||
)
|
||||
msg = self.UndefinedVarError.format(var_name, self.node_name, pretty_vars)
|
||||
raise_compiler_error(msg, self._node)
|
||||
|
||||
def has_var(self, var_name: str):
|
||||
@@ -167,7 +156,7 @@ class BaseContext(metaclass=ContextMeta):
|
||||
def generate_builtins(self):
|
||||
builtins: Dict[str, Any] = {}
|
||||
for key, value in self._context_members_.items():
|
||||
if hasattr(value, '__get__'):
|
||||
if hasattr(value, "__get__"):
|
||||
# handle properties, bound methods, etc
|
||||
value = value.__get__(self)
|
||||
builtins[key] = value
|
||||
@@ -175,9 +164,9 @@ class BaseContext(metaclass=ContextMeta):
|
||||
|
||||
# no dbtClassMixin so this is not an actual override
|
||||
def to_dict(self):
|
||||
self._ctx['context'] = self._ctx
|
||||
self._ctx["context"] = self._ctx
|
||||
builtins = self.generate_builtins()
|
||||
self._ctx['builtins'] = builtins
|
||||
self._ctx["builtins"] = builtins
|
||||
self._ctx.update(builtins)
|
||||
return self._ctx
|
||||
|
||||
@@ -286,18 +275,20 @@ class BaseContext(metaclass=ContextMeta):
|
||||
msg = f"Env var required but not provided: '{var}'"
|
||||
undefined_error(msg)
|
||||
|
||||
if os.environ.get('DBT_MACRO_DEBUGGING'):
|
||||
if os.environ.get("DBT_MACRO_DEBUGGING"):
|
||||
|
||||
@contextmember
|
||||
@staticmethod
|
||||
def debug():
|
||||
"""Enter a debugger at this line in the compiled jinja code."""
|
||||
import sys
|
||||
import ipdb # type: ignore
|
||||
|
||||
frame = sys._getframe(3)
|
||||
ipdb.set_trace(frame)
|
||||
return ''
|
||||
return ""
|
||||
|
||||
@contextmember('return')
|
||||
@contextmember("return")
|
||||
@staticmethod
|
||||
def _return(data: Any) -> NoReturn:
|
||||
"""The `return` function can be used in macros to return data to the
|
||||
@@ -348,9 +339,7 @@ class BaseContext(metaclass=ContextMeta):
|
||||
|
||||
@contextmember
|
||||
@staticmethod
|
||||
def tojson(
|
||||
value: Any, default: Any = None, sort_keys: bool = False
|
||||
) -> Any:
|
||||
def tojson(value: Any, default: Any = None, sort_keys: bool = False) -> Any:
|
||||
"""The `tojson` context method can be used to serialize a Python
|
||||
object primitive, eg. a `dict` or `list` to a json string.
|
||||
|
||||
@@ -446,7 +435,7 @@ class BaseContext(metaclass=ContextMeta):
|
||||
logger.info(msg)
|
||||
else:
|
||||
logger.debug(msg)
|
||||
return ''
|
||||
return ""
|
||||
|
||||
@contextproperty
|
||||
def run_started_at(self) -> Optional[datetime.datetime]:
|
||||
@@ -4,16 +4,14 @@ from dbt.contracts.connection import AdapterRequiredConfig
|
||||
from dbt.node_types import NodeType
|
||||
from dbt.utils import MultiDict
|
||||
|
||||
from dbt.context.base import contextproperty, Var
|
||||
from dbt.context import contextproperty, Var
|
||||
from dbt.context.target import TargetContext
|
||||
|
||||
|
||||
class ConfiguredContext(TargetContext):
|
||||
config: AdapterRequiredConfig
|
||||
|
||||
def __init__(
|
||||
self, config: AdapterRequiredConfig
|
||||
) -> None:
|
||||
def __init__(self, config: AdapterRequiredConfig) -> None:
|
||||
super().__init__(config, config.cli_vars)
|
||||
|
||||
@contextproperty
|
||||
@@ -70,9 +68,7 @@ class SchemaYamlContext(ConfiguredContext):
|
||||
|
||||
@contextproperty
|
||||
def var(self) -> ConfiguredVar:
|
||||
return ConfiguredVar(
|
||||
self._ctx, self.config, self._project_name
|
||||
)
|
||||
return ConfiguredVar(self._ctx, self.config, self._project_name)
|
||||
|
||||
|
||||
def generate_schema_yml(
|
||||
|
||||
@@ -17,8 +17,8 @@ class ModelParts(IsFQNResource):
|
||||
package_name: str
|
||||
|
||||
|
||||
T = TypeVar('T') # any old type
|
||||
C = TypeVar('C', bound=BaseConfig)
|
||||
T = TypeVar("T") # any old type
|
||||
C = TypeVar("C", bound=BaseConfig)
|
||||
|
||||
|
||||
class ConfigSource:
|
||||
@@ -36,13 +36,13 @@ class UnrenderedConfig(ConfigSource):
|
||||
def get_config_dict(self, resource_type: NodeType) -> Dict[str, Any]:
|
||||
unrendered = self.project.unrendered.project_dict
|
||||
if resource_type == NodeType.Seed:
|
||||
model_configs = unrendered.get('seeds')
|
||||
model_configs = unrendered.get("seeds")
|
||||
elif resource_type == NodeType.Snapshot:
|
||||
model_configs = unrendered.get('snapshots')
|
||||
model_configs = unrendered.get("snapshots")
|
||||
elif resource_type == NodeType.Source:
|
||||
model_configs = unrendered.get('sources')
|
||||
model_configs = unrendered.get("sources")
|
||||
else:
|
||||
model_configs = unrendered.get('models')
|
||||
model_configs = unrendered.get("models")
|
||||
|
||||
if model_configs is None:
|
||||
return {}
|
||||
@@ -79,8 +79,8 @@ class BaseContextConfigGenerator(Generic[T]):
|
||||
dependencies = self._active_project.load_dependencies()
|
||||
if project_name not in dependencies:
|
||||
raise InternalException(
|
||||
f'Project name {project_name} not found in dependencies '
|
||||
f'(found {list(dependencies)})'
|
||||
f"Project name {project_name} not found in dependencies "
|
||||
f"(found {list(dependencies)})"
|
||||
)
|
||||
return dependencies[project_name]
|
||||
|
||||
@@ -92,7 +92,7 @@ class BaseContextConfigGenerator(Generic[T]):
|
||||
for level_config in fqn_search(model_configs, fqn):
|
||||
result = {}
|
||||
for key, value in level_config.items():
|
||||
if key.startswith('+'):
|
||||
if key.startswith("+"):
|
||||
result[key[1:]] = deepcopy(value)
|
||||
elif not isinstance(value, dict):
|
||||
result[key] = deepcopy(value)
|
||||
@@ -171,13 +171,9 @@ class ContextConfigGenerator(BaseContextConfigGenerator[C]):
|
||||
def _update_from_config(
|
||||
self, result: C, partial: Dict[str, Any], validate: bool = False
|
||||
) -> C:
|
||||
translated = self._active_project.credentials.translate_aliases(
|
||||
partial
|
||||
)
|
||||
translated = self._active_project.credentials.translate_aliases(partial)
|
||||
return result.update_from(
|
||||
translated,
|
||||
self._active_project.credentials.type,
|
||||
validate=validate
|
||||
translated, self._active_project.credentials.type, validate=validate
|
||||
)
|
||||
|
||||
def calculate_node_config_dict(
|
||||
@@ -219,11 +215,7 @@ class UnrenderedConfigGenerator(BaseContextConfigGenerator[Dict[str, Any]]):
|
||||
base=base,
|
||||
)
|
||||
|
||||
def initial_result(
|
||||
self,
|
||||
resource_type: NodeType,
|
||||
base: bool
|
||||
) -> Dict[str, Any]:
|
||||
def initial_result(self, resource_type: NodeType, base: bool) -> Dict[str, Any]:
|
||||
return {}
|
||||
|
||||
def _update_from_config(
|
||||
@@ -232,9 +224,7 @@ class UnrenderedConfigGenerator(BaseContextConfigGenerator[Dict[str, Any]]):
|
||||
partial: Dict[str, Any],
|
||||
validate: bool = False,
|
||||
) -> Dict[str, Any]:
|
||||
translated = self._active_project.credentials.translate_aliases(
|
||||
partial
|
||||
)
|
||||
translated = self._active_project.credentials.translate_aliases(partial)
|
||||
result.update(translated)
|
||||
return result
|
||||
|
||||
|
||||
@@ -1,6 +1,4 @@
|
||||
from typing import (
|
||||
Any, Dict, Union
|
||||
)
|
||||
from typing import Any, Dict, Union
|
||||
|
||||
from dbt.exceptions import (
|
||||
doc_invalid_args,
|
||||
@@ -11,7 +9,7 @@ from dbt.contracts.graph.compiled import CompileResultNode
|
||||
from dbt.contracts.graph.manifest import Manifest
|
||||
from dbt.contracts.graph.parsed import ParsedMacro
|
||||
|
||||
from dbt.context.base import contextmember
|
||||
from dbt.context import contextmember
|
||||
from dbt.context.configured import SchemaYamlContext
|
||||
|
||||
|
||||
|
||||
@@ -1,6 +1,4 @@
|
||||
from typing import (
|
||||
Dict, MutableMapping, Optional
|
||||
)
|
||||
from typing import Dict, MutableMapping, Optional
|
||||
from dbt.contracts.graph.parsed import ParsedMacro
|
||||
from dbt.exceptions import raise_duplicate_macro_name, raise_compiler_error
|
||||
from dbt.include.global_project import PROJECT_NAME as GLOBAL_PROJECT_NAME
|
||||
@@ -45,8 +43,7 @@ class MacroResolver:
|
||||
for pkg in reversed(self.internal_package_names):
|
||||
if pkg in self.internal_packages:
|
||||
# Turn the internal packages into a flat namespace
|
||||
self.internal_packages_namespace.update(
|
||||
self.internal_packages[pkg])
|
||||
self.internal_packages_namespace.update(self.internal_packages[pkg])
|
||||
|
||||
def _build_macros_by_name(self):
|
||||
macros_by_name = {}
|
||||
@@ -74,9 +71,7 @@ class MacroResolver:
|
||||
package_namespaces[macro.package_name] = namespace
|
||||
|
||||
if macro.name in namespace:
|
||||
raise_duplicate_macro_name(
|
||||
macro, macro, macro.package_name
|
||||
)
|
||||
raise_duplicate_macro_name(macro, macro, macro.package_name)
|
||||
package_namespaces[macro.package_name][macro.name] = macro
|
||||
|
||||
def add_macro(self, macro: ParsedMacro):
|
||||
@@ -99,8 +94,10 @@ class MacroResolver:
|
||||
|
||||
def get_macro_id(self, local_package, macro_name):
|
||||
local_package_macros = {}
|
||||
if (local_package not in self.internal_package_names and
|
||||
local_package in self.packages):
|
||||
if (
|
||||
local_package not in self.internal_package_names
|
||||
and local_package in self.packages
|
||||
):
|
||||
local_package_macros = self.packages[local_package]
|
||||
# First: search the local packages for this macro
|
||||
if macro_name in local_package_macros:
|
||||
@@ -117,9 +114,7 @@ class MacroResolver:
|
||||
# is that you can limit the number of macros provided to the
|
||||
# context dictionary in the 'to_dict' manifest method.
|
||||
class TestMacroNamespace:
|
||||
def __init__(
|
||||
self, macro_resolver, ctx, node, thread_ctx, depends_on_macros
|
||||
):
|
||||
def __init__(self, macro_resolver, ctx, node, thread_ctx, depends_on_macros):
|
||||
self.macro_resolver = macro_resolver
|
||||
self.ctx = ctx
|
||||
self.node = node
|
||||
@@ -129,7 +124,10 @@ class TestMacroNamespace:
|
||||
for macro_unique_id in depends_on_macros:
|
||||
macro = self.manifest.macros[macro_unique_id]
|
||||
local_namespace[macro.name] = MacroGenerator(
|
||||
macro, self.ctx, self.node, self.thread_ctx,
|
||||
macro,
|
||||
self.ctx,
|
||||
self.node,
|
||||
self.thread_ctx,
|
||||
)
|
||||
self.local_namespace = local_namespace
|
||||
|
||||
@@ -144,10 +142,6 @@ class TestMacroNamespace:
|
||||
elif package_name in self.resolver.packages:
|
||||
macro = self.macro_resolver.packages[package_name].get(name)
|
||||
else:
|
||||
raise_compiler_error(
|
||||
f"Could not find package '{package_name}'"
|
||||
)
|
||||
macro_func = MacroGenerator(
|
||||
macro, self.ctx, self.node, self.thread_ctx
|
||||
)
|
||||
raise_compiler_error(f"Could not find package '{package_name}'")
|
||||
macro_func = MacroGenerator(macro, self.ctx, self.node, self.thread_ctx)
|
||||
return macro_func
|
||||
|
||||
@@ -1,13 +1,9 @@
|
||||
from typing import (
|
||||
Any, Dict, Iterable, Union, Optional, List, Iterator, Mapping, Set
|
||||
)
|
||||
from typing import Any, Dict, Iterable, Union, Optional, List, Iterator, Mapping, Set
|
||||
|
||||
from dbt.clients.jinja import MacroGenerator, MacroStack
|
||||
from dbt.contracts.graph.parsed import ParsedMacro
|
||||
from dbt.include.global_project import PROJECT_NAME as GLOBAL_PROJECT_NAME
|
||||
from dbt.exceptions import (
|
||||
raise_duplicate_macro_name, raise_compiler_error
|
||||
)
|
||||
from dbt.exceptions import raise_duplicate_macro_name, raise_compiler_error
|
||||
|
||||
|
||||
FlatNamespace = Dict[str, MacroGenerator]
|
||||
@@ -75,9 +71,7 @@ class MacroNamespace(Mapping):
|
||||
elif package_name in self.packages:
|
||||
return self.packages[package_name].get(name)
|
||||
else:
|
||||
raise_compiler_error(
|
||||
f"Could not find package '{package_name}'"
|
||||
)
|
||||
raise_compiler_error(f"Could not find package '{package_name}'")
|
||||
|
||||
|
||||
# This class builds the MacroNamespace by adding macros to
|
||||
@@ -122,9 +116,7 @@ class MacroNamespaceBuilder:
|
||||
hierarchy[macro.package_name] = namespace
|
||||
|
||||
if macro.name in namespace:
|
||||
raise_duplicate_macro_name(
|
||||
macro_func.macro, macro, macro.package_name
|
||||
)
|
||||
raise_duplicate_macro_name(macro_func.macro, macro, macro.package_name)
|
||||
hierarchy[macro.package_name][macro.name] = macro_func
|
||||
|
||||
def add_macro(self, macro: ParsedMacro, ctx: Dict[str, Any]):
|
||||
|
||||
@@ -17,6 +17,7 @@ class ManifestContext(ConfiguredContext):
|
||||
The given macros can override any previous context values, which will be
|
||||
available as if they were accessed relative to the package name.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: AdapterRequiredConfig,
|
||||
@@ -37,13 +38,12 @@ class ManifestContext(ConfiguredContext):
|
||||
# this takes all the macros in the manifest and adds them
|
||||
# to the MacroNamespaceBuilder stored in self.namespace
|
||||
builder = self._get_namespace_builder()
|
||||
return builder.build_namespace(
|
||||
self.manifest.macros.values(), self._ctx
|
||||
)
|
||||
return builder.build_namespace(self.manifest.macros.values(), self._ctx)
|
||||
|
||||
def _get_namespace_builder(self) -> MacroNamespaceBuilder:
|
||||
# avoid an import loop
|
||||
from dbt.adapters.factory import get_adapter_package_names
|
||||
|
||||
internal_packages: List[str] = get_adapter_package_names(
|
||||
self.config.credentials.type
|
||||
)
|
||||
@@ -68,14 +68,10 @@ class ManifestContext(ConfiguredContext):
|
||||
|
||||
|
||||
class QueryHeaderContext(ManifestContext):
|
||||
def __init__(
|
||||
self, config: AdapterRequiredConfig, manifest: Manifest
|
||||
) -> None:
|
||||
def __init__(self, config: AdapterRequiredConfig, manifest: Manifest) -> None:
|
||||
super().__init__(config, manifest, config.project_name)
|
||||
|
||||
|
||||
def generate_query_header_context(
|
||||
config: AdapterRequiredConfig, manifest: Manifest
|
||||
):
|
||||
def generate_query_header_context(config: AdapterRequiredConfig, manifest: Manifest):
|
||||
ctx = QueryHeaderContext(config, manifest)
|
||||
return ctx.to_dict()
|
||||
|
||||
@@ -1,7 +1,15 @@
|
||||
import abc
|
||||
import os
|
||||
from typing import (
|
||||
Callable, Any, Dict, Optional, Union, List, TypeVar, Type, Iterable,
|
||||
Callable,
|
||||
Any,
|
||||
Dict,
|
||||
Optional,
|
||||
Union,
|
||||
List,
|
||||
TypeVar,
|
||||
Type,
|
||||
Iterable,
|
||||
Mapping,
|
||||
)
|
||||
from typing_extensions import Protocol
|
||||
@@ -12,16 +20,14 @@ from dbt.adapters.factory import get_adapter, get_adapter_package_names
|
||||
from dbt.clients import agate_helper
|
||||
from dbt.clients.jinja import get_rendered, MacroGenerator, MacroStack
|
||||
from dbt.config import RuntimeConfig, Project
|
||||
from .base import contextmember, contextproperty, Var
|
||||
from dbt.context import contextmember, contextproperty, Var
|
||||
from .configured import FQNLookup
|
||||
from .context_config import ContextConfig
|
||||
from dbt.context.macro_resolver import MacroResolver, TestMacroNamespace
|
||||
from .macros import MacroNamespaceBuilder, MacroNamespace
|
||||
from .manifest import ManifestContext
|
||||
from dbt.contracts.connection import AdapterResponse
|
||||
from dbt.contracts.graph.manifest import (
|
||||
Manifest, AnyManifest, Disabled, MacroManifest
|
||||
)
|
||||
from dbt.contracts.graph.manifest import Manifest, AnyManifest, Disabled, MacroManifest
|
||||
from dbt.contracts.graph.compiled import (
|
||||
CompiledResource,
|
||||
CompiledSeedNode,
|
||||
@@ -50,9 +56,7 @@ from dbt.config import IsFQNResource
|
||||
from dbt.logger import GLOBAL_LOGGER as logger # noqa
|
||||
from dbt.node_types import NodeType
|
||||
|
||||
from dbt.utils import (
|
||||
merge, AttrDict, MultiDict
|
||||
)
|
||||
from dbt.utils import merge, AttrDict, MultiDict
|
||||
|
||||
import agate
|
||||
|
||||
@@ -75,9 +79,8 @@ class RelationProxy:
|
||||
return self._relation_type.create_from_source(*args, **kwargs)
|
||||
|
||||
def create(self, *args, **kwargs):
|
||||
kwargs['quote_policy'] = merge(
|
||||
self._quoting_config,
|
||||
kwargs.pop('quote_policy', {})
|
||||
kwargs["quote_policy"] = merge(
|
||||
self._quoting_config, kwargs.pop("quote_policy", {})
|
||||
)
|
||||
return self._relation_type.create(*args, **kwargs)
|
||||
|
||||
@@ -94,7 +97,7 @@ class BaseDatabaseWrapper:
|
||||
self._namespace = namespace
|
||||
|
||||
def __getattr__(self, name):
|
||||
raise NotImplementedError('subclasses need to implement this')
|
||||
raise NotImplementedError("subclasses need to implement this")
|
||||
|
||||
@property
|
||||
def config(self):
|
||||
@@ -110,7 +113,7 @@ class BaseDatabaseWrapper:
|
||||
# a future version of this could have plugins automatically call fall
|
||||
# back to their dependencies' dependencies by using
|
||||
# `get_adapter_type_names` instead of `[self.config.credentials.type]`
|
||||
search_prefixes = [self._adapter.type(), 'default']
|
||||
search_prefixes = [self._adapter.type(), "default"]
|
||||
return search_prefixes
|
||||
|
||||
def dispatch(
|
||||
@@ -118,8 +121,8 @@ class BaseDatabaseWrapper:
|
||||
) -> MacroGenerator:
|
||||
search_packages: List[Optional[str]]
|
||||
|
||||
if '.' in macro_name:
|
||||
suggest_package, suggest_macro_name = macro_name.split('.', 1)
|
||||
if "." in macro_name:
|
||||
suggest_package, suggest_macro_name = macro_name.split(".", 1)
|
||||
msg = (
|
||||
f'In adapter.dispatch, got a macro name of "{macro_name}", '
|
||||
f'but "." is not a valid macro name component. Did you mean '
|
||||
@@ -132,7 +135,7 @@ class BaseDatabaseWrapper:
|
||||
search_packages = [None]
|
||||
elif isinstance(packages, str):
|
||||
raise CompilationException(
|
||||
f'In adapter.dispatch, got a string packages argument '
|
||||
f"In adapter.dispatch, got a string packages argument "
|
||||
f'("{packages}"), but packages should be None or a list.'
|
||||
)
|
||||
else:
|
||||
@@ -142,26 +145,24 @@ class BaseDatabaseWrapper:
|
||||
|
||||
for package_name in search_packages:
|
||||
for prefix in self._get_adapter_macro_prefixes():
|
||||
search_name = f'{prefix}__{macro_name}'
|
||||
search_name = f"{prefix}__{macro_name}"
|
||||
try:
|
||||
# this uses the namespace from the context
|
||||
macro = self._namespace.get_from_package(
|
||||
package_name, search_name
|
||||
)
|
||||
macro = self._namespace.get_from_package(package_name, search_name)
|
||||
except CompilationException as exc:
|
||||
raise CompilationException(
|
||||
f'In dispatch: {exc.msg}',
|
||||
f"In dispatch: {exc.msg}",
|
||||
) from exc
|
||||
|
||||
if package_name is None:
|
||||
attempts.append(search_name)
|
||||
else:
|
||||
attempts.append(f'{package_name}.{search_name}')
|
||||
attempts.append(f"{package_name}.{search_name}")
|
||||
|
||||
if macro is not None:
|
||||
return macro
|
||||
|
||||
searched = ', '.join(repr(a) for a in attempts)
|
||||
searched = ", ".join(repr(a) for a in attempts)
|
||||
msg = (
|
||||
f"In dispatch: No macro named '{macro_name}' found\n"
|
||||
f" Searched for: {searched}"
|
||||
@@ -191,14 +192,10 @@ class BaseResolver(metaclass=abc.ABCMeta):
|
||||
|
||||
class BaseRefResolver(BaseResolver):
|
||||
@abc.abstractmethod
|
||||
def resolve(
|
||||
self, name: str, package: Optional[str] = None
|
||||
) -> RelationProxy:
|
||||
def resolve(self, name: str, package: Optional[str] = None) -> RelationProxy:
|
||||
...
|
||||
|
||||
def _repack_args(
|
||||
self, name: str, package: Optional[str]
|
||||
) -> List[str]:
|
||||
def _repack_args(self, name: str, package: Optional[str]) -> List[str]:
|
||||
if package is None:
|
||||
return [name]
|
||||
else:
|
||||
@@ -207,14 +204,13 @@ class BaseRefResolver(BaseResolver):
|
||||
def validate_args(self, name: str, package: Optional[str]):
|
||||
if not isinstance(name, str):
|
||||
raise CompilationException(
|
||||
f'The name argument to ref() must be a string, got '
|
||||
f'{type(name)}'
|
||||
f"The name argument to ref() must be a string, got " f"{type(name)}"
|
||||
)
|
||||
|
||||
if package is not None and not isinstance(package, str):
|
||||
raise CompilationException(
|
||||
f'The package argument to ref() must be a string or None, got '
|
||||
f'{type(package)}'
|
||||
f"The package argument to ref() must be a string or None, got "
|
||||
f"{type(package)}"
|
||||
)
|
||||
|
||||
def __call__(self, *args: str) -> RelationProxy:
|
||||
@@ -239,20 +235,19 @@ class BaseSourceResolver(BaseResolver):
|
||||
def validate_args(self, source_name: str, table_name: str):
|
||||
if not isinstance(source_name, str):
|
||||
raise CompilationException(
|
||||
f'The source name (first) argument to source() must be a '
|
||||
f'string, got {type(source_name)}'
|
||||
f"The source name (first) argument to source() must be a "
|
||||
f"string, got {type(source_name)}"
|
||||
)
|
||||
if not isinstance(table_name, str):
|
||||
raise CompilationException(
|
||||
f'The table name (second) argument to source() must be a '
|
||||
f'string, got {type(table_name)}'
|
||||
f"The table name (second) argument to source() must be a "
|
||||
f"string, got {type(table_name)}"
|
||||
)
|
||||
|
||||
def __call__(self, *args: str) -> RelationProxy:
|
||||
if len(args) != 2:
|
||||
raise_compiler_error(
|
||||
f"source() takes exactly two arguments ({len(args)} given)",
|
||||
self.model
|
||||
f"source() takes exactly two arguments ({len(args)} given)", self.model
|
||||
)
|
||||
self.validate_args(args[0], args[1])
|
||||
return self.resolve(args[0], args[1])
|
||||
@@ -270,14 +265,15 @@ class ParseConfigObject(Config):
|
||||
self.context_config = context_config
|
||||
|
||||
def _transform_config(self, config):
|
||||
for oldkey in ('pre_hook', 'post_hook'):
|
||||
for oldkey in ("pre_hook", "post_hook"):
|
||||
if oldkey in config:
|
||||
newkey = oldkey.replace('_', '-')
|
||||
newkey = oldkey.replace("_", "-")
|
||||
if newkey in config:
|
||||
raise_compiler_error(
|
||||
'Invalid config, has conflicting keys "{}" and "{}"'
|
||||
.format(oldkey, newkey),
|
||||
self.model
|
||||
'Invalid config, has conflicting keys "{}" and "{}"'.format(
|
||||
oldkey, newkey
|
||||
),
|
||||
self.model,
|
||||
)
|
||||
config[newkey] = config.pop(oldkey)
|
||||
return config
|
||||
@@ -288,29 +284,25 @@ class ParseConfigObject(Config):
|
||||
elif len(args) == 0 and len(kwargs) > 0:
|
||||
opts = kwargs
|
||||
else:
|
||||
raise_compiler_error(
|
||||
"Invalid inline model config",
|
||||
self.model)
|
||||
raise_compiler_error("Invalid inline model config", self.model)
|
||||
|
||||
opts = self._transform_config(opts)
|
||||
|
||||
# it's ok to have a parse context with no context config, but you must
|
||||
# not call it!
|
||||
if self.context_config is None:
|
||||
raise RuntimeException(
|
||||
'At parse time, did not receive a context config'
|
||||
)
|
||||
raise RuntimeException("At parse time, did not receive a context config")
|
||||
self.context_config.update_in_model_config(opts)
|
||||
return ''
|
||||
return ""
|
||||
|
||||
def set(self, name, value):
|
||||
return self.__call__({name: value})
|
||||
|
||||
def require(self, name, validator=None):
|
||||
return ''
|
||||
return ""
|
||||
|
||||
def get(self, name, validator=None, default=None):
|
||||
return ''
|
||||
return ""
|
||||
|
||||
def persist_relation_docs(self) -> bool:
|
||||
return False
|
||||
@@ -320,14 +312,12 @@ class ParseConfigObject(Config):
|
||||
|
||||
|
||||
class RuntimeConfigObject(Config):
|
||||
def __init__(
|
||||
self, model, context_config: Optional[ContextConfig] = None
|
||||
):
|
||||
def __init__(self, model, context_config: Optional[ContextConfig] = None):
|
||||
self.model = model
|
||||
# we never use or get a config, only the parser cares
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
return ''
|
||||
return ""
|
||||
|
||||
def set(self, name, value):
|
||||
return self.__call__({name: value})
|
||||
@@ -337,7 +327,7 @@ class RuntimeConfigObject(Config):
|
||||
|
||||
def _lookup(self, name, default=_MISSING):
|
||||
# if this is a macro, there might be no `model.config`.
|
||||
if not hasattr(self.model, 'config'):
|
||||
if not hasattr(self.model, "config"):
|
||||
result = default
|
||||
else:
|
||||
result = self.model.config.get(name, default)
|
||||
@@ -362,22 +352,24 @@ class RuntimeConfigObject(Config):
|
||||
return to_return
|
||||
|
||||
def persist_relation_docs(self) -> bool:
|
||||
persist_docs = self.get('persist_docs', default={})
|
||||
persist_docs = self.get("persist_docs", default={})
|
||||
if not isinstance(persist_docs, dict):
|
||||
raise_compiler_error(
|
||||
f"Invalid value provided for 'persist_docs'. Expected dict "
|
||||
f"but received {type(persist_docs)}")
|
||||
f"but received {type(persist_docs)}"
|
||||
)
|
||||
|
||||
return persist_docs.get('relation', False)
|
||||
return persist_docs.get("relation", False)
|
||||
|
||||
def persist_column_docs(self) -> bool:
|
||||
persist_docs = self.get('persist_docs', default={})
|
||||
persist_docs = self.get("persist_docs", default={})
|
||||
if not isinstance(persist_docs, dict):
|
||||
raise_compiler_error(
|
||||
f"Invalid value provided for 'persist_docs'. Expected dict "
|
||||
f"but received {type(persist_docs)}")
|
||||
f"but received {type(persist_docs)}"
|
||||
)
|
||||
|
||||
return persist_docs.get('columns', False)
|
||||
return persist_docs.get("columns", False)
|
||||
|
||||
|
||||
# `adapter` implementations
|
||||
@@ -387,8 +379,10 @@ class ParseDatabaseWrapper(BaseDatabaseWrapper):
|
||||
"""
|
||||
|
||||
def __getattr__(self, name):
|
||||
override = (name in self._adapter._available_ and
|
||||
name in self._adapter._parse_replacements_)
|
||||
override = (
|
||||
name in self._adapter._available_
|
||||
and name in self._adapter._parse_replacements_
|
||||
)
|
||||
|
||||
if override:
|
||||
return self._adapter._parse_replacements_[name]
|
||||
@@ -420,9 +414,7 @@ class RuntimeDatabaseWrapper(BaseDatabaseWrapper):
|
||||
|
||||
# `ref` implementations
|
||||
class ParseRefResolver(BaseRefResolver):
|
||||
def resolve(
|
||||
self, name: str, package: Optional[str] = None
|
||||
) -> RelationProxy:
|
||||
def resolve(self, name: str, package: Optional[str] = None) -> RelationProxy:
|
||||
self.model.refs.append(self._repack_args(name, package))
|
||||
|
||||
return self.Relation.create_from(self.config, self.model)
|
||||
@@ -452,22 +444,15 @@ class RuntimeRefResolver(BaseRefResolver):
|
||||
self.validate(target_model, target_name, target_package)
|
||||
return self.create_relation(target_model, target_name)
|
||||
|
||||
def create_relation(
|
||||
self, target_model: ManifestNode, name: str
|
||||
) -> RelationProxy:
|
||||
def create_relation(self, target_model: ManifestNode, name: str) -> RelationProxy:
|
||||
if target_model.is_ephemeral_model:
|
||||
self.model.set_cte(target_model.unique_id, None)
|
||||
return self.Relation.create_ephemeral_from_node(
|
||||
self.config, target_model
|
||||
)
|
||||
return self.Relation.create_ephemeral_from_node(self.config, target_model)
|
||||
else:
|
||||
return self.Relation.create_from(self.config, target_model)
|
||||
|
||||
def validate(
|
||||
self,
|
||||
resolved: ManifestNode,
|
||||
target_name: str,
|
||||
target_package: Optional[str]
|
||||
self, resolved: ManifestNode, target_name: str, target_package: Optional[str]
|
||||
) -> None:
|
||||
if resolved.unique_id not in self.model.depends_on.nodes:
|
||||
args = self._repack_args(target_name, target_package)
|
||||
@@ -483,16 +468,15 @@ class OperationRefResolver(RuntimeRefResolver):
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
def create_relation(
|
||||
self, target_model: ManifestNode, name: str
|
||||
) -> RelationProxy:
|
||||
def create_relation(self, target_model: ManifestNode, name: str) -> RelationProxy:
|
||||
if target_model.is_ephemeral_model:
|
||||
# In operations, we can't ref() ephemeral nodes, because
|
||||
# ParsedMacros do not support set_cte
|
||||
raise_compiler_error(
|
||||
'Operations can not ref() ephemeral nodes, but {} is ephemeral'
|
||||
.format(target_model.name),
|
||||
self.model
|
||||
"Operations can not ref() ephemeral nodes, but {} is ephemeral".format(
|
||||
target_model.name
|
||||
),
|
||||
self.model,
|
||||
)
|
||||
else:
|
||||
return super().create_relation(target_model, name)
|
||||
@@ -544,8 +528,7 @@ class ModelConfiguredVar(Var):
|
||||
if package_name not in dependencies:
|
||||
# I don't think this is actually reachable
|
||||
raise_compiler_error(
|
||||
f'Node package named {package_name} not found!',
|
||||
self._node
|
||||
f"Node package named {package_name} not found!", self._node
|
||||
)
|
||||
yield dependencies[package_name]
|
||||
yield self._config
|
||||
@@ -617,7 +600,7 @@ class OperationProvider(RuntimeProvider):
|
||||
ref = OperationRefResolver
|
||||
|
||||
|
||||
T = TypeVar('T')
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
# Base context collection, used for parsing configs.
|
||||
@@ -631,9 +614,7 @@ class ProviderContext(ManifestContext):
|
||||
context_config: Optional[ContextConfig],
|
||||
) -> None:
|
||||
if provider is None:
|
||||
raise InternalException(
|
||||
f"Invalid provider given to context: {provider}"
|
||||
)
|
||||
raise InternalException(f"Invalid provider given to context: {provider}")
|
||||
# mypy appeasement - we know it'll be a RuntimeConfig
|
||||
self.config: RuntimeConfig
|
||||
self.model: Union[ParsedMacro, ManifestNode] = model
|
||||
@@ -643,16 +624,12 @@ class ProviderContext(ManifestContext):
|
||||
self.provider: Provider = provider
|
||||
self.adapter = get_adapter(self.config)
|
||||
# The macro namespace is used in creating the DatabaseWrapper
|
||||
self.db_wrapper = self.provider.DatabaseWrapper(
|
||||
self.adapter, self.namespace
|
||||
)
|
||||
self.db_wrapper = self.provider.DatabaseWrapper(self.adapter, self.namespace)
|
||||
|
||||
# This overrides the method in ManifestContext, and provides
|
||||
# a model, which the ManifestContext builder does not
|
||||
def _get_namespace_builder(self):
|
||||
internal_packages = get_adapter_package_names(
|
||||
self.config.credentials.type
|
||||
)
|
||||
internal_packages = get_adapter_package_names(self.config.credentials.type)
|
||||
return MacroNamespaceBuilder(
|
||||
self.config.project_name,
|
||||
self.search_package,
|
||||
@@ -671,19 +648,19 @@ class ProviderContext(ManifestContext):
|
||||
|
||||
@contextmember
|
||||
def store_result(
|
||||
self, name: str,
|
||||
response: Any,
|
||||
agate_table: Optional[agate.Table] = None
|
||||
self, name: str, response: Any, agate_table: Optional[agate.Table] = None
|
||||
) -> str:
|
||||
if agate_table is None:
|
||||
agate_table = agate_helper.empty_table()
|
||||
|
||||
self.sql_results[name] = AttrDict({
|
||||
'response': response,
|
||||
'data': agate_helper.as_matrix(agate_table),
|
||||
'table': agate_table
|
||||
})
|
||||
return ''
|
||||
self.sql_results[name] = AttrDict(
|
||||
{
|
||||
"response": response,
|
||||
"data": agate_helper.as_matrix(agate_table),
|
||||
"table": agate_table,
|
||||
}
|
||||
)
|
||||
return ""
|
||||
|
||||
@contextmember
|
||||
def store_raw_result(
|
||||
@@ -692,10 +669,11 @@ class ProviderContext(ManifestContext):
|
||||
message=Optional[str],
|
||||
code=Optional[str],
|
||||
rows_affected=Optional[str],
|
||||
agate_table: Optional[agate.Table] = None
|
||||
agate_table: Optional[agate.Table] = None,
|
||||
) -> str:
|
||||
response = AdapterResponse(
|
||||
_message=message, code=code, rows_affected=rows_affected)
|
||||
_message=message, code=code, rows_affected=rows_affected
|
||||
)
|
||||
return self.store_result(name, response, agate_table)
|
||||
|
||||
@contextproperty
|
||||
@@ -708,25 +686,28 @@ class ProviderContext(ManifestContext):
|
||||
elif value == arg:
|
||||
return
|
||||
raise ValidationException(
|
||||
'Expected value "{}" to be one of {}'
|
||||
.format(value, ','.join(map(str, args))))
|
||||
'Expected value "{}" to be one of {}'.format(
|
||||
value, ",".join(map(str, args))
|
||||
)
|
||||
)
|
||||
|
||||
return inner
|
||||
|
||||
return AttrDict({
|
||||
'any': validate_any,
|
||||
})
|
||||
return AttrDict(
|
||||
{
|
||||
"any": validate_any,
|
||||
}
|
||||
)
|
||||
|
||||
@contextmember
|
||||
def write(self, payload: str) -> str:
|
||||
# macros/source defs aren't 'writeable'.
|
||||
if isinstance(self.model, (ParsedMacro, ParsedSourceDefinition)):
|
||||
raise_compiler_error(
|
||||
'cannot "write" macros or sources'
|
||||
)
|
||||
raise_compiler_error('cannot "write" macros or sources')
|
||||
self.model.build_path = self.model.write_node(
|
||||
self.config.target_path, 'run', payload
|
||||
self.config.target_path, "run", payload
|
||||
)
|
||||
return ''
|
||||
return ""
|
||||
|
||||
@contextmember
|
||||
def render(self, string: str) -> str:
|
||||
@@ -739,20 +720,17 @@ class ProviderContext(ManifestContext):
|
||||
try:
|
||||
return func(*args, **kwargs)
|
||||
except Exception:
|
||||
raise_compiler_error(
|
||||
message_if_exception, self.model
|
||||
)
|
||||
raise_compiler_error(message_if_exception, self.model)
|
||||
|
||||
@contextmember
|
||||
def load_agate_table(self) -> agate.Table:
|
||||
if not isinstance(self.model, (ParsedSeedNode, CompiledSeedNode)):
|
||||
raise_compiler_error(
|
||||
'can only load_agate_table for seeds (got a {})'
|
||||
.format(self.model.resource_type)
|
||||
"can only load_agate_table for seeds (got a {})".format(
|
||||
self.model.resource_type
|
||||
)
|
||||
path = os.path.join(
|
||||
self.model.root_path, self.model.original_file_path
|
||||
)
|
||||
path = os.path.join(self.model.root_path, self.model.original_file_path)
|
||||
column_types = self.model.config.column_types
|
||||
try:
|
||||
table = agate_helper.from_csv(path, text_columns=column_types)
|
||||
@@ -810,7 +788,7 @@ class ProviderContext(ManifestContext):
|
||||
self.db_wrapper, self.model, self.config, self.manifest
|
||||
)
|
||||
|
||||
@contextproperty('config')
|
||||
@contextproperty("config")
|
||||
def ctx_config(self) -> Config:
|
||||
"""The `config` variable exists to handle end-user configuration for
|
||||
custom materializations. Configs like `unique_key` can be implemented
|
||||
@@ -982,7 +960,7 @@ class ProviderContext(ManifestContext):
|
||||
node=self.model,
|
||||
)
|
||||
|
||||
@contextproperty('adapter')
|
||||
@contextproperty("adapter")
|
||||
def ctx_adapter(self) -> BaseDatabaseWrapper:
|
||||
"""`adapter` is a wrapper around the internal database adapter used by
|
||||
dbt. It allows users to make calls to the database in their dbt models.
|
||||
@@ -994,8 +972,8 @@ class ProviderContext(ManifestContext):
|
||||
@contextproperty
|
||||
def api(self) -> Dict[str, Any]:
|
||||
return {
|
||||
'Relation': self.db_wrapper.Relation,
|
||||
'Column': self.adapter.Column,
|
||||
"Relation": self.db_wrapper.Relation,
|
||||
"Column": self.adapter.Column,
|
||||
}
|
||||
|
||||
@contextproperty
|
||||
@@ -1113,7 +1091,7 @@ class ProviderContext(ManifestContext):
|
||||
""" # noqa
|
||||
return self.manifest.flat_graph
|
||||
|
||||
@contextproperty('model')
|
||||
@contextproperty("model")
|
||||
def ctx_model(self) -> Dict[str, Any]:
|
||||
return self.model.to_dict(omit_none=True)
|
||||
|
||||
@@ -1177,22 +1155,20 @@ class ProviderContext(ManifestContext):
|
||||
...
|
||||
{%- endmacro %}
|
||||
"""
|
||||
deprecations.warn('adapter-macro', macro_name=name)
|
||||
deprecations.warn("adapter-macro", macro_name=name)
|
||||
original_name = name
|
||||
package_names: Optional[List[str]] = None
|
||||
if '.' in name:
|
||||
package_name, name = name.split('.', 1)
|
||||
if "." in name:
|
||||
package_name, name = name.split(".", 1)
|
||||
package_names = [package_name]
|
||||
|
||||
try:
|
||||
macro = self.db_wrapper.dispatch(
|
||||
macro_name=name, packages=package_names
|
||||
)
|
||||
macro = self.db_wrapper.dispatch(macro_name=name, packages=package_names)
|
||||
except CompilationException as exc:
|
||||
raise CompilationException(
|
||||
f'In adapter_macro: {exc.msg}\n'
|
||||
f"In adapter_macro: {exc.msg}\n"
|
||||
f" Original name: '{original_name}'",
|
||||
node=self.model
|
||||
node=self.model,
|
||||
) from exc
|
||||
return macro(*args, **kwargs)
|
||||
|
||||
@@ -1230,35 +1206,27 @@ class ModelContext(ProviderContext):
|
||||
def pre_hooks(self) -> List[Dict[str, Any]]:
|
||||
if isinstance(self.model, ParsedSourceDefinition):
|
||||
return []
|
||||
return [
|
||||
h.to_dict(omit_none=True) for h in self.model.config.pre_hook
|
||||
]
|
||||
return [h.to_dict(omit_none=True) for h in self.model.config.pre_hook]
|
||||
|
||||
@contextproperty
|
||||
def post_hooks(self) -> List[Dict[str, Any]]:
|
||||
if isinstance(self.model, ParsedSourceDefinition):
|
||||
return []
|
||||
return [
|
||||
h.to_dict(omit_none=True) for h in self.model.config.post_hook
|
||||
]
|
||||
return [h.to_dict(omit_none=True) for h in self.model.config.post_hook]
|
||||
|
||||
@contextproperty
|
||||
def sql(self) -> Optional[str]:
|
||||
if getattr(self.model, 'extra_ctes_injected', None):
|
||||
if getattr(self.model, "extra_ctes_injected", None):
|
||||
return self.model.compiled_sql
|
||||
return None
|
||||
|
||||
@contextproperty
|
||||
def database(self) -> str:
|
||||
return getattr(
|
||||
self.model, 'database', self.config.credentials.database
|
||||
)
|
||||
return getattr(self.model, "database", self.config.credentials.database)
|
||||
|
||||
@contextproperty
|
||||
def schema(self) -> str:
|
||||
return getattr(
|
||||
self.model, 'schema', self.config.credentials.schema
|
||||
)
|
||||
return getattr(self.model, "schema", self.config.credentials.schema)
|
||||
|
||||
@contextproperty
|
||||
def this(self) -> Optional[RelationProxy]:
|
||||
@@ -1306,9 +1274,7 @@ def generate_parser_model(
|
||||
# The __init__ method of ModelContext also initializes
|
||||
# a ManifestContext object which creates a MacroNamespaceBuilder
|
||||
# which adds every macro in the Manifest.
|
||||
ctx = ModelContext(
|
||||
model, config, manifest, ParseProvider(), context_config
|
||||
)
|
||||
ctx = ModelContext(model, config, manifest, ParseProvider(), context_config)
|
||||
# The 'to_dict' method in ManifestContext moves all of the macro names
|
||||
# in the macro 'namespace' up to top level keys
|
||||
return ctx.to_dict()
|
||||
@@ -1319,9 +1285,7 @@ def generate_generate_component_name_macro(
|
||||
config: RuntimeConfig,
|
||||
manifest: MacroManifest,
|
||||
) -> Dict[str, Any]:
|
||||
ctx = MacroContext(
|
||||
macro, config, manifest, GenerateNameProvider(), None
|
||||
)
|
||||
ctx = MacroContext(macro, config, manifest, GenerateNameProvider(), None)
|
||||
return ctx.to_dict()
|
||||
|
||||
|
||||
@@ -1330,9 +1294,7 @@ def generate_runtime_model(
|
||||
config: RuntimeConfig,
|
||||
manifest: Manifest,
|
||||
) -> Dict[str, Any]:
|
||||
ctx = ModelContext(
|
||||
model, config, manifest, RuntimeProvider(), None
|
||||
)
|
||||
ctx = ModelContext(model, config, manifest, RuntimeProvider(), None)
|
||||
return ctx.to_dict()
|
||||
|
||||
|
||||
@@ -1342,9 +1304,7 @@ def generate_runtime_macro(
|
||||
manifest: Manifest,
|
||||
package_name: Optional[str],
|
||||
) -> Dict[str, Any]:
|
||||
ctx = MacroContext(
|
||||
macro, config, manifest, OperationProvider(), package_name
|
||||
)
|
||||
ctx = MacroContext(macro, config, manifest, OperationProvider(), package_name)
|
||||
return ctx.to_dict()
|
||||
|
||||
|
||||
@@ -1353,18 +1313,17 @@ class ExposureRefResolver(BaseResolver):
|
||||
if len(args) not in (1, 2):
|
||||
ref_invalid_args(self.model, args)
|
||||
self.model.refs.append(list(args))
|
||||
return ''
|
||||
return ""
|
||||
|
||||
|
||||
class ExposureSourceResolver(BaseResolver):
|
||||
def __call__(self, *args) -> str:
|
||||
if len(args) != 2:
|
||||
raise_compiler_error(
|
||||
f"source() takes exactly two arguments ({len(args)} given)",
|
||||
self.model
|
||||
f"source() takes exactly two arguments ({len(args)} given)", self.model
|
||||
)
|
||||
self.model.sources.append(list(args))
|
||||
return ''
|
||||
return ""
|
||||
|
||||
|
||||
def generate_parse_exposure(
|
||||
@@ -1375,18 +1334,18 @@ def generate_parse_exposure(
|
||||
) -> Dict[str, Any]:
|
||||
project = config.load_dependencies()[package_name]
|
||||
return {
|
||||
'ref': ExposureRefResolver(
|
||||
"ref": ExposureRefResolver(
|
||||
None,
|
||||
exposure,
|
||||
project,
|
||||
manifest,
|
||||
),
|
||||
'source': ExposureSourceResolver(
|
||||
"source": ExposureSourceResolver(
|
||||
None,
|
||||
exposure,
|
||||
project,
|
||||
manifest,
|
||||
)
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
@@ -1422,8 +1381,7 @@ class TestContext(ProviderContext):
|
||||
if self.model.depends_on and self.model.depends_on.macros:
|
||||
depends_on_macros = self.model.depends_on.macros
|
||||
macro_namespace = TestMacroNamespace(
|
||||
self.macro_resolver, self.ctx, self.node, self.thread_ctx,
|
||||
depends_on_macros
|
||||
self.macro_resolver, self.ctx, self.node, self.thread_ctx, depends_on_macros
|
||||
)
|
||||
self._namespace = macro_namespace
|
||||
|
||||
@@ -1433,11 +1391,10 @@ def generate_test_context(
|
||||
config: RuntimeConfig,
|
||||
manifest: Manifest,
|
||||
context_config: ContextConfig,
|
||||
macro_resolver: MacroResolver
|
||||
macro_resolver: MacroResolver,
|
||||
) -> Dict[str, Any]:
|
||||
ctx = TestContext(
|
||||
model, config, manifest, ParseProvider(), context_config,
|
||||
macro_resolver
|
||||
model, config, manifest, ParseProvider(), context_config, macro_resolver
|
||||
)
|
||||
# The 'to_dict' method in ManifestContext moves all of the macro names
|
||||
# in the macro 'namespace' up to top level keys
|
||||
|
||||
@@ -2,9 +2,7 @@ from typing import Any, Dict
|
||||
|
||||
from dbt.contracts.connection import HasCredentials
|
||||
|
||||
from dbt.context.base import (
|
||||
BaseContext, contextproperty
|
||||
)
|
||||
from dbt.context import BaseContext, contextproperty
|
||||
|
||||
|
||||
class TargetContext(BaseContext):
|
||||
|
||||
@@ -2,25 +2,35 @@ import abc
|
||||
import itertools
|
||||
from dataclasses import dataclass, field
|
||||
from typing import (
|
||||
Any, ClassVar, Dict, Tuple, Iterable, Optional, List, Callable,
|
||||
Any,
|
||||
ClassVar,
|
||||
Dict,
|
||||
Tuple,
|
||||
Iterable,
|
||||
Optional,
|
||||
List,
|
||||
Callable,
|
||||
)
|
||||
from dbt.exceptions import InternalException
|
||||
from dbt.utils import translate_aliases
|
||||
from dbt.logger import GLOBAL_LOGGER as logger
|
||||
from typing_extensions import Protocol
|
||||
from dbt.dataclass_schema import (
|
||||
dbtClassMixin, StrEnum, ExtensibleDbtClassMixin,
|
||||
ValidatedStringMixin, register_pattern
|
||||
dbtClassMixin,
|
||||
StrEnum,
|
||||
ExtensibleDbtClassMixin,
|
||||
ValidatedStringMixin,
|
||||
register_pattern,
|
||||
)
|
||||
from dbt.contracts.util import Replaceable
|
||||
|
||||
|
||||
class Identifier(ValidatedStringMixin):
|
||||
ValidationRegex = r'^[A-Za-z_][A-Za-z0-9_]+$'
|
||||
ValidationRegex = r"^[A-Za-z_][A-Za-z0-9_]+$"
|
||||
|
||||
|
||||
# we need register_pattern for jsonschema validation
|
||||
register_pattern(Identifier, r'^[A-Za-z_][A-Za-z0-9_]+$')
|
||||
register_pattern(Identifier, r"^[A-Za-z_][A-Za-z0-9_]+$")
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -34,10 +44,10 @@ class AdapterResponse(dbtClassMixin):
|
||||
|
||||
|
||||
class ConnectionState(StrEnum):
|
||||
INIT = 'init'
|
||||
OPEN = 'open'
|
||||
CLOSED = 'closed'
|
||||
FAIL = 'fail'
|
||||
INIT = "init"
|
||||
OPEN = "open"
|
||||
CLOSED = "closed"
|
||||
FAIL = "fail"
|
||||
|
||||
|
||||
@dataclass(init=False)
|
||||
@@ -81,8 +91,7 @@ class Connection(ExtensibleDbtClassMixin, Replaceable):
|
||||
self._handle.resolve(self)
|
||||
except RecursionError as exc:
|
||||
raise InternalException(
|
||||
"A connection's open() method attempted to read the "
|
||||
"handle value"
|
||||
"A connection's open() method attempted to read the " "handle value"
|
||||
) from exc
|
||||
return self._handle
|
||||
|
||||
@@ -101,8 +110,7 @@ class LazyHandle:
|
||||
|
||||
def resolve(self, connection: Connection) -> Connection:
|
||||
logger.debug(
|
||||
'Opening a new connection, currently in state {}'
|
||||
.format(connection.state)
|
||||
"Opening a new connection, currently in state {}".format(connection.state)
|
||||
)
|
||||
return self.opener(connection)
|
||||
|
||||
@@ -112,33 +120,24 @@ class LazyHandle:
|
||||
# for why we have type: ignore. Maybe someday dataclasses + abstract classes
|
||||
# will work.
|
||||
@dataclass # type: ignore
|
||||
class Credentials(
|
||||
ExtensibleDbtClassMixin,
|
||||
Replaceable,
|
||||
metaclass=abc.ABCMeta
|
||||
):
|
||||
class Credentials(ExtensibleDbtClassMixin, Replaceable, metaclass=abc.ABCMeta):
|
||||
database: str
|
||||
schema: str
|
||||
_ALIASES: ClassVar[Dict[str, str]] = field(default={}, init=False)
|
||||
|
||||
@abc.abstractproperty
|
||||
def type(self) -> str:
|
||||
raise NotImplementedError(
|
||||
'type not implemented for base credentials class'
|
||||
)
|
||||
raise NotImplementedError("type not implemented for base credentials class")
|
||||
|
||||
def connection_info(
|
||||
self, *, with_aliases: bool = False
|
||||
) -> Iterable[Tuple[str, Any]]:
|
||||
"""Return an ordered iterator of key/value pairs for pretty-printing.
|
||||
"""
|
||||
"""Return an ordered iterator of key/value pairs for pretty-printing."""
|
||||
as_dict = self.to_dict(omit_none=False)
|
||||
connection_keys = set(self._connection_keys())
|
||||
aliases: List[str] = []
|
||||
if with_aliases:
|
||||
aliases = [
|
||||
k for k, v in self._ALIASES.items() if v in connection_keys
|
||||
]
|
||||
aliases = [k for k, v in self._ALIASES.items() if v in connection_keys]
|
||||
for key in itertools.chain(self._connection_keys(), aliases):
|
||||
if key in as_dict:
|
||||
yield key, as_dict[key]
|
||||
@@ -162,11 +161,13 @@ class Credentials(
|
||||
def __post_serialize__(self, dct):
|
||||
# no super() -- do we need it?
|
||||
if self._ALIASES:
|
||||
dct.update({
|
||||
dct.update(
|
||||
{
|
||||
new_name: dct[canonical_name]
|
||||
for new_name, canonical_name in self._ALIASES.items()
|
||||
if canonical_name in dct
|
||||
})
|
||||
}
|
||||
)
|
||||
return dct
|
||||
|
||||
|
||||
@@ -188,10 +189,10 @@ class HasCredentials(Protocol):
|
||||
threads: int
|
||||
|
||||
def to_target_dict(self):
|
||||
raise NotImplementedError('to_target_dict not implemented')
|
||||
raise NotImplementedError("to_target_dict not implemented")
|
||||
|
||||
|
||||
DEFAULT_QUERY_COMMENT = '''
DEFAULT_QUERY_COMMENT = """
{%- set comment_dict = {} -%}
{%- do comment_dict.update(
app='dbt',
@@ -208,7 +209,7 @@ DEFAULT_QUERY_COMMENT = '''
{%- do comment_dict.update(connection_name=connection_name) -%}
{%- endif -%}
{{ return(tojson(comment_dict)) }}
'''
"""


@dataclass

@@ -11,7 +11,7 @@ from .util import MacroKey, SourceKey


MAXIMUM_SEED_SIZE = 1 * 1024 * 1024
MAXIMUM_SEED_SIZE_NAME = '1MB'
MAXIMUM_SEED_SIZE_NAME = "1MB"


@dataclass
@@ -28,9 +28,7 @@ class FilePath(dbtClassMixin):
@property
def full_path(self) -> str:
# useful for symlink preservation
return os.path.join(
self.project_root, self.searched_path, self.relative_path
)
return os.path.join(self.project_root, self.searched_path, self.relative_path)

@property
def absolute_path(self) -> str:
@@ -40,13 +38,10 @@ class FilePath(dbtClassMixin):
def original_file_path(self) -> str:
# this is mostly used for reporting errors. It doesn't show the project
# name, should it?
return os.path.join(
self.searched_path, self.relative_path
)
return os.path.join(self.searched_path, self.relative_path)

def seed_too_large(self) -> bool:
"""Return whether the file this represents is over the seed size limit
"""
"""Return whether the file this represents is over the seed size limit"""
return os.stat(self.full_path).st_size > MAXIMUM_SEED_SIZE


@@ -57,35 +52,35 @@ class FileHash(dbtClassMixin):

@classmethod
def empty(cls):
return FileHash(name='none', checksum='')
return FileHash(name="none", checksum="")

@classmethod
def path(cls, path: str):
return FileHash(name='path', checksum=path)
return FileHash(name="path", checksum=path)

def __eq__(self, other):
if not isinstance(other, FileHash):
return NotImplemented

if self.name == 'none' or self.name != other.name:
if self.name == "none" or self.name != other.name:
return False

return self.checksum == other.checksum

def compare(self, contents: str) -> bool:
"""Compare the file contents with the given hash"""
if self.name == 'none':
if self.name == "none":
return False

return self.from_contents(contents, name=self.name) == self.checksum

@classmethod
def from_contents(cls, contents: str, name='sha256') -> 'FileHash':
def from_contents(cls, contents: str, name="sha256") -> "FileHash":
"""Create a file hash from the given file contents. The hash is always
the utf-8 encoding of the contents given, because dbt only reads files
as utf-8.
"""
data = contents.encode('utf-8')
data = contents.encode("utf-8")
checksum = hashlib.new(name, data).hexdigest()
return cls(name=name, checksum=checksum)

@@ -94,24 +89,25 @@ class FileHash(dbtClassMixin):
class RemoteFile(dbtClassMixin):
@property
def searched_path(self) -> str:
return 'from remote system'
return "from remote system"

@property
def relative_path(self) -> str:
return 'from remote system'
return "from remote system"

@property
def absolute_path(self) -> str:
return 'from remote system'
return "from remote system"

@property
def original_file_path(self):
return 'from remote system'
return "from remote system"


@dataclass
class SourceFile(dbtClassMixin):
"""Define a source file in dbt"""

path: Union[FilePath, RemoteFile] # the path information
checksum: FileHash
# we don't want to serialize this
@@ -133,14 +129,14 @@ class SourceFile(dbtClassMixin):
def search_key(self) -> Optional[str]:
if isinstance(self.path, RemoteFile):
return None
if self.checksum.name == 'none':
if self.checksum.name == "none":
return None
return self.path.search_key

@property
def contents(self) -> str:
if self._contents is None:
raise InternalException('SourceFile has no contents!')
raise InternalException("SourceFile has no contents!")
return self._contents

@contents.setter
@@ -148,20 +144,20 @@ class SourceFile(dbtClassMixin):
self._contents = value

@classmethod
def empty(cls, path: FilePath) -> 'SourceFile':
def empty(cls, path: FilePath) -> "SourceFile":
self = cls(path=path, checksum=FileHash.empty())
self.contents = ''
self.contents = ""
return self

@classmethod
def big_seed(cls, path: FilePath) -> 'SourceFile':
def big_seed(cls, path: FilePath) -> "SourceFile":
"""Parse seeds over the size limit with just the path"""
self = cls(path=path, checksum=FileHash.path(path.original_file_path))
self.contents = ''
self.contents = ""
return self

@classmethod
def remote(cls, contents: str) -> 'SourceFile':
def remote(cls, contents: str) -> "SourceFile":
self = cls(path=RemoteFile(), checksum=FileHash.empty())
self.contents = contents
return self

@@ -58,31 +58,29 @@ class CompiledNode(ParsedNode, CompiledNodeMixin):

@dataclass
class CompiledAnalysisNode(CompiledNode):
resource_type: NodeType = field(metadata={'restrict': [NodeType.Analysis]})
resource_type: NodeType = field(metadata={"restrict": [NodeType.Analysis]})


@dataclass
class CompiledHookNode(CompiledNode):
resource_type: NodeType = field(
metadata={'restrict': [NodeType.Operation]}
)
resource_type: NodeType = field(metadata={"restrict": [NodeType.Operation]})
index: Optional[int] = None


@dataclass
class CompiledModelNode(CompiledNode):
resource_type: NodeType = field(metadata={'restrict': [NodeType.Model]})
resource_type: NodeType = field(metadata={"restrict": [NodeType.Model]})


@dataclass
class CompiledRPCNode(CompiledNode):
resource_type: NodeType = field(metadata={'restrict': [NodeType.RPCCall]})
resource_type: NodeType = field(metadata={"restrict": [NodeType.RPCCall]})


@dataclass
class CompiledSeedNode(CompiledNode):
# keep this in sync with ParsedSeedNode!
resource_type: NodeType = field(metadata={'restrict': [NodeType.Seed]})
resource_type: NodeType = field(metadata={"restrict": [NodeType.Seed]})
config: SeedConfig = field(default_factory=SeedConfig)

@property
@@ -96,26 +94,25 @@ class CompiledSeedNode(CompiledNode):

@dataclass
class CompiledSnapshotNode(CompiledNode):
resource_type: NodeType = field(metadata={'restrict': [NodeType.Snapshot]})
resource_type: NodeType = field(metadata={"restrict": [NodeType.Snapshot]})


@dataclass
class CompiledDataTestNode(CompiledNode):
resource_type: NodeType = field(metadata={'restrict': [NodeType.Test]})
resource_type: NodeType = field(metadata={"restrict": [NodeType.Test]})
config: TestConfig = field(default_factory=TestConfig)


@dataclass
class CompiledSchemaTestNode(CompiledNode, HasTestMetadata):
# keep this in sync with ParsedSchemaTestNode!
resource_type: NodeType = field(metadata={'restrict': [NodeType.Test]})
resource_type: NodeType = field(metadata={"restrict": [NodeType.Test]})
column_name: Optional[str] = None
config: TestConfig = field(default_factory=TestConfig)

def same_config(self, other) -> bool:
return (
self.unrendered_config.get('severity') ==
other.unrendered_config.get('severity')
return self.unrendered_config.get("severity") == other.unrendered_config.get(
"severity"
)

def same_column_name(self, other) -> bool:
@@ -125,11 +122,7 @@ class CompiledSchemaTestNode(CompiledNode, HasTestMetadata):
if other is None:
return False

return (
self.same_config(other) and
self.same_fqn(other) and
True
)
return self.same_config(other) and self.same_fqn(other) and True


CompiledTestNode = Union[CompiledDataTestNode, CompiledSchemaTestNode]
@@ -175,8 +168,7 @@ def parsed_instance_for(compiled: CompiledNode) -> ParsedResource:
cls = PARSED_TYPES.get(type(compiled))
if cls is None:
# how???
raise ValueError('invalid resource_type: {}'
.format(compiled.resource_type))
raise ValueError("invalid resource_type: {}".format(compiled.resource_type))

return cls.from_dict(compiled.to_dict(omit_none=True))


@@ -4,25 +4,51 @@ from dataclasses import dataclass, field
from itertools import chain, islice
from multiprocessing.synchronize import Lock
from typing import (
Dict, List, Optional, Union, Mapping, MutableMapping, Any, Set, Tuple,
TypeVar, Callable, Iterable, Generic, cast, AbstractSet
Dict,
List,
Optional,
Union,
Mapping,
MutableMapping,
Any,
Set,
Tuple,
TypeVar,
Callable,
Iterable,
Generic,
cast,
AbstractSet,
)
from typing_extensions import Protocol
from uuid import UUID

from dbt.contracts.graph.compiled import (
CompileResultNode, ManifestNode, NonSourceCompiledNode, GraphMemberNode
CompileResultNode,
ManifestNode,
NonSourceCompiledNode,
GraphMemberNode,
)
from dbt.contracts.graph.parsed import (
ParsedMacro, ParsedDocumentation, ParsedNodePatch, ParsedMacroPatch,
ParsedSourceDefinition, ParsedExposure
ParsedMacro,
ParsedDocumentation,
ParsedNodePatch,
ParsedMacroPatch,
ParsedSourceDefinition,
ParsedExposure,
)
from dbt.contracts.files import SourceFile
from dbt.contracts.util import (
BaseArtifactMetadata, MacroKey, SourceKey, ArtifactMixin, schema_version
BaseArtifactMetadata,
MacroKey,
SourceKey,
ArtifactMixin,
schema_version,
)
from dbt.exceptions import (
raise_duplicate_resource_name, raise_compiler_error, warn_or_error,
raise_duplicate_resource_name,
raise_compiler_error,
warn_or_error,
raise_invalid_patch,
)
from dbt.helper_types import PathSet
@@ -40,12 +66,12 @@ RefName = str
UniqueID = str


K_T = TypeVar('K_T')
V_T = TypeVar('V_T')
K_T = TypeVar("K_T")
V_T = TypeVar("V_T")


class PackageAwareCache(Generic[K_T, V_T]):
def __init__(self, manifest: 'Manifest'):
def __init__(self, manifest: "Manifest"):
self.storage: Dict[K_T, Dict[PackageName, UniqueID]] = {}
self._manifest = manifest
self.populate()
@@ -95,12 +121,10 @@ class DocCache(PackageAwareCache[DocName, ParsedDocumentation]):
for doc in self._manifest.docs.values():
self.add_doc(doc)

def perform_lookup(
self, unique_id: UniqueID
) -> ParsedDocumentation:
def perform_lookup(self, unique_id: UniqueID) -> ParsedDocumentation:
if unique_id not in self._manifest.docs:
raise dbt.exceptions.InternalException(
f'Doc {unique_id} found in cache but not found in manifest'
f"Doc {unique_id} found in cache but not found in manifest"
)
return self._manifest.docs[unique_id]

@@ -117,12 +141,10 @@ class SourceCache(PackageAwareCache[SourceKey, ParsedSourceDefinition]):
for source in self._manifest.sources.values():
self.add_source(source)

def perform_lookup(
self, unique_id: UniqueID
) -> ParsedSourceDefinition:
def perform_lookup(self, unique_id: UniqueID) -> ParsedSourceDefinition:
if unique_id not in self._manifest.sources:
raise dbt.exceptions.InternalException(
f'Source {unique_id} found in cache but not found in manifest'
f"Source {unique_id} found in cache but not found in manifest"
)
return self._manifest.sources[unique_id]

@@ -131,7 +153,7 @@ class RefableCache(PackageAwareCache[RefName, ManifestNode]):
# refables are actually unique, so the Dict[PackageName, UniqueID] will
# only ever have exactly one value, but doing 3 dict lookups instead of 1
# is not a big deal at all and retains consistency
def __init__(self, manifest: 'Manifest'):
def __init__(self, manifest: "Manifest"):
self._cached_types = set(NodeType.refable())
super().__init__(manifest)

@@ -145,12 +167,10 @@ class RefableCache(PackageAwareCache[RefName, ManifestNode]):
for node in self._manifest.nodes.values():
self.add_node(node)

def perform_lookup(
self, unique_id: UniqueID
) -> ManifestNode:
def perform_lookup(self, unique_id: UniqueID) -> ManifestNode:
if unique_id not in self._manifest.nodes:
raise dbt.exceptions.InternalException(
f'Node {unique_id} found in cache but not found in manifest'
f"Node {unique_id} found in cache but not found in manifest"
)
return self._manifest.nodes[unique_id]

@@ -171,30 +191,31 @@ def _search_packages(
@dataclass
class ManifestMetadata(BaseArtifactMetadata):
"""Metadata for the manifest."""

dbt_schema_version: str = field(
default_factory=lambda: str(WritableManifest.dbt_schema_version)
)
project_id: Optional[str] = field(
default=None,
metadata={
'description': 'A unique identifier for the project',
"description": "A unique identifier for the project",
},
)
user_id: Optional[UUID] = field(
default=None,
metadata={
'description': 'A unique identifier for the user',
"description": "A unique identifier for the user",
},
)
send_anonymous_usage_stats: Optional[bool] = field(
default=None,
metadata=dict(description=(
'Whether dbt is configured to send anonymous usage statistics'
)),
metadata=dict(
description=("Whether dbt is configured to send anonymous usage statistics")
),
)
adapter_type: Optional[str] = field(
default=None,
metadata=dict(description='The type name of the adapter'),
metadata=dict(description="The type name of the adapter"),
)

def __post_init__(self):
@@ -205,9 +226,7 @@ class ManifestMetadata(BaseArtifactMetadata):
self.user_id = tracking.active_user.id

if self.send_anonymous_usage_stats is None:
self.send_anonymous_usage_stats = (
not tracking.active_user.do_not_track
)
self.send_anonymous_usage_stats = not tracking.active_user.do_not_track

@classmethod
def default(cls):
@@ -281,7 +300,7 @@ class MaterializationCandidate(MacroCandidate):
@classmethod
def from_macro(
cls, candidate: MacroCandidate, specificity: Specificity
) -> 'MaterializationCandidate':
) -> "MaterializationCandidate":
return cls(
locality=candidate.locality,
macro=candidate.macro,
@@ -292,15 +311,14 @@ class MaterializationCandidate(MacroCandidate):
if not isinstance(other, MaterializationCandidate):
return NotImplemented
equal = (
self.specificity == other.specificity and
self.locality == other.locality
self.specificity == other.specificity and self.locality == other.locality
)
if equal:
raise_compiler_error(
'Found two materializations with the name {} (packages {} and '
'{}). dbt cannot resolve this ambiguity'
.format(self.macro.name, self.macro.package_name,
other.macro.package_name)
"Found two materializations with the name {} (packages {} and "
"{}). dbt cannot resolve this ambiguity".format(
self.macro.name, self.macro.package_name, other.macro.package_name
)
)

return equal
@@ -319,7 +337,7 @@ class MaterializationCandidate(MacroCandidate):
return False


M = TypeVar('M', bound=MacroCandidate)
M = TypeVar("M", bound=MacroCandidate)


class CandidateList(List[M]):
@@ -347,10 +365,10 @@ class Searchable(Protocol):

@property
def search_name(self) -> str:
raise NotImplementedError('search_name not implemented')
raise NotImplementedError("search_name not implemented")


N = TypeVar('N', bound=Searchable)
N = TypeVar("N", bound=Searchable)


@dataclass
@@ -382,7 +400,7 @@ class NameSearcher(Generic[N]):
return None


D = TypeVar('D')
D = TypeVar("D")


@dataclass
@@ -393,19 +411,18 @@ class Disabled(Generic[D]):
MaybeDocumentation = Optional[ParsedDocumentation]


MaybeParsedSource = Optional[Union[
MaybeParsedSource = Optional[
Union[
ParsedSourceDefinition,
Disabled[ParsedSourceDefinition],
]]
]
]


MaybeNonSource = Optional[Union[
ManifestNode,
Disabled[ManifestNode]
]]
MaybeNonSource = Optional[Union[ManifestNode, Disabled[ManifestNode]]]


T = TypeVar('T', bound=GraphMemberNode)
T = TypeVar("T", bound=GraphMemberNode)


def _update_into(dest: MutableMapping[str, T], new_item: T):
@@ -416,14 +433,13 @@ def _update_into(dest: MutableMapping[str, T], new_item: T):
unique_id = new_item.unique_id
if unique_id not in dest:
raise dbt.exceptions.RuntimeException(
f'got an update_{new_item.resource_type} call with an '
f'unrecognized {new_item.resource_type}: {new_item.unique_id}'
f"got an update_{new_item.resource_type} call with an "
f"unrecognized {new_item.resource_type}: {new_item.unique_id}"
)
existing = dest[unique_id]
if new_item.original_file_path != existing.original_file_path:
raise dbt.exceptions.RuntimeException(
f'cannot update a {new_item.resource_type} to have a new file '
f'path!'
f"cannot update a {new_item.resource_type} to have a new file " f"path!"
)
dest[unique_id] = new_item

@@ -447,6 +463,7 @@ class MacroMethods:
"""
filter: Optional[Callable[[MacroCandidate], bool]] = None
if package is not None:

def filter(candidate: MacroCandidate) -> bool:
return package == candidate.macro.package_name

@@ -469,11 +486,12 @@ class MacroMethods:
- return the `generate_{component}_name` macro from the 'dbt'
internal project
"""

def filter(candidate: MacroCandidate) -> bool:
return candidate.locality != Locality.Imported

candidates: CandidateList = self._find_macros_by_name(
name=f'generate_{component}_name',
name=f"generate_{component}_name",
root_project_name=root_project_name,
# filter out imported packages
filter=filter,
@@ -484,12 +502,12 @@ class MacroMethods:
self,
name: str,
root_project_name: str,
filter: Optional[Callable[[MacroCandidate], bool]] = None
filter: Optional[Callable[[MacroCandidate], bool]] = None,
) -> CandidateList:
"""Find macros by their name.
"""
"""Find macros by their name."""
# avoid an import cycle
from dbt.adapters.factory import get_adapter_package_names

candidates: CandidateList = CandidateList()
packages = set(get_adapter_package_names(self.metadata.adapter_type))
for unique_id, macro in self.macros.items():
@@ -507,8 +525,8 @@ class MacroMethods:

@dataclass
class Manifest(MacroMethods):
"""The manifest for the full graph, after parsing and during compilation.
"""
"""The manifest for the full graph, after parsing and during compilation."""

# These attributes are both positional and by keyword. If an attribute
# is added it must all be added in the __reduce_ex__ method in the
# args tuple in the right position.
@@ -541,7 +559,7 @@ class Manifest(MacroMethods):
"""
with self._lock:
existing = self.nodes[new_node.unique_id]
if getattr(existing, 'compiled', False):
if getattr(existing, "compiled", False):
# already compiled -> must be a NonSourceCompiledNode
return cast(NonSourceCompiledNode, existing)
_update_into(self.nodes, new_node)
@@ -563,39 +581,30 @@ class Manifest(MacroMethods):
manifest!
"""
self.flat_graph = {
'nodes': {
k: v.to_dict(omit_none=False)
for k, v in self.nodes.items()
},
'sources': {
k: v.to_dict(omit_none=False)
for k, v in self.sources.items()
}
"nodes": {k: v.to_dict(omit_none=False) for k, v in self.nodes.items()},
"sources": {k: v.to_dict(omit_none=False) for k, v in self.sources.items()},
}

def find_disabled_by_name(
self, name: str, package: Optional[str] = None
) -> Optional[ManifestNode]:
searcher: NameSearcher = NameSearcher(
name, package, NodeType.refable()
)
searcher: NameSearcher = NameSearcher(name, package, NodeType.refable())
result = searcher.search(self.disabled)
return result

def find_disabled_source_by_name(
self, source_name: str, table_name: str, package: Optional[str] = None
) -> Optional[ParsedSourceDefinition]:
search_name = f'{source_name}.{table_name}'
searcher: NameSearcher = NameSearcher(
search_name, package, [NodeType.Source]
)
search_name = f"{source_name}.{table_name}"
searcher: NameSearcher = NameSearcher(search_name, package, [NodeType.Source])
result = searcher.search(self.disabled)
if result is not None:
assert isinstance(result, ParsedSourceDefinition)
return result

def _materialization_candidates_for(
self, project_name: str,
self,
project_name: str,
materialization_name: str,
adapter_type: Optional[str],
) -> CandidateList:
@@ -618,13 +627,16 @@ class Manifest(MacroMethods):
def find_materialization_macro_by_name(
self, project_name: str, materialization_name: str, adapter_type: str
) -> Optional[ParsedMacro]:
candidates: CandidateList = CandidateList(chain.from_iterable(
candidates: CandidateList = CandidateList(
chain.from_iterable(
self._materialization_candidates_for(
project_name=project_name,
materialization_name=materialization_name,
adapter_type=atype,
) for atype in (adapter_type, None)
))
)
for atype in (adapter_type, None)
)
)
return candidates.last()

def get_resource_fqns(self) -> Mapping[str, PathSet]:
@@ -648,9 +660,7 @@ class Manifest(MacroMethods):
if node.resource_type in NodeType.refable():
self._refs_cache.add_node(node)

def patch_macros(
self, patches: MutableMapping[MacroKey, ParsedMacroPatch]
) -> None:
def patch_macros(self, patches: MutableMapping[MacroKey, ParsedMacroPatch]) -> None:
for macro in self.macros.values():
key = (macro.package_name, macro.name)
patch = patches.pop(key, None)
@@ -662,12 +672,10 @@ class Manifest(MacroMethods):
for patch in patches.values():
warn_or_error(
f'WARNING: Found documentation for macro "{patch.name}" '
f'which was not found'
f"which was not found"
)

def patch_nodes(
self, patches: MutableMapping[str, ParsedNodePatch]
) -> None:
def patch_nodes(self, patches: MutableMapping[str, ParsedNodePatch]) -> None:
"""Patch nodes with the given dict of patches. Note that this consumes
the input!
This relies on the fact that all nodes have unique _name_ fields, not
@@ -684,15 +692,15 @@ class Manifest(MacroMethods):

expected_key = node.resource_type.pluralize()
if expected_key != patch.yaml_key:
if patch.yaml_key == 'models':
if patch.yaml_key == "models":
deprecations.warn(
'models-key-mismatch',
patch=patch, node=node, expected_key=expected_key
"models-key-mismatch",
patch=patch,
node=node,
expected_key=expected_key,
)
else:
raise_invalid_patch(
node, patch.yaml_key, patch.original_file_path
)
raise_invalid_patch(node, patch.yaml_key, patch.original_file_path)

node.patch(patch)

@@ -701,22 +709,25 @@ class Manifest(MacroMethods):
for patch in patches.values():
# since patches aren't nodes, we can't use the existing
# target_not_found warning
logger.debug((
logger.debug(
(
'WARNING: Found documentation for resource "{}" which was '
'not found or is disabled').format(patch.name)
"not found or is disabled"
).format(patch.name)
)

def get_used_schemas(self, resource_types=None):
return frozenset({
(node.database, node.schema) for node in
chain(self.nodes.values(), self.sources.values())
return frozenset(
{
(node.database, node.schema)
for node in chain(self.nodes.values(), self.sources.values())
if not resource_types or node.resource_type in resource_types
})
}
)

def get_used_databases(self):
return frozenset(
x.database for x in
chain(self.nodes.values(), self.sources.values())
x.database for x in chain(self.nodes.values(), self.sources.values())
)

def deepcopy(self):
@@ -733,11 +744,13 @@ class Manifest(MacroMethods):
)

def writable_manifest(self):
edge_members = list(chain(
edge_members = list(
chain(
self.nodes.values(),
self.sources.values(),
self.exposures.values(),
))
)
)
forward_edges, backward_edges = build_edges(edge_members)

return WritableManifest(
@@ -771,7 +784,7 @@ class Manifest(MacroMethods):
else:
# something terrible has happened
raise dbt.exceptions.InternalException(
'Expected node {} not found in manifest'.format(unique_id)
"Expected node {} not found in manifest".format(unique_id)
)

@property
@@ -820,9 +833,7 @@ class Manifest(MacroMethods):

# it's possible that the node is disabled
if disabled is None:
disabled = self.find_disabled_by_name(
target_model_name, pkg
)
disabled = self.find_disabled_by_name(target_model_name, pkg)

if disabled is not None:
return Disabled(disabled)
@@ -833,7 +844,7 @@ class Manifest(MacroMethods):
target_source_name: str,
target_table_name: str,
current_project: str,
node_package: str
node_package: str,
) -> MaybeParsedSource:
key = (target_source_name, target_table_name)
candidates = _search_packages(current_project, node_package)
@@ -866,9 +877,7 @@ class Manifest(MacroMethods):
resolve_ref except the is_enabled checks are unnecessary as docs are
always enabled.
"""
candidates = _search_packages(
current_project, node_package, package
)
candidates = _search_packages(current_project, node_package, package)

for pkg in candidates:
result = self.docs_cache.find_cached_value(name, pkg)
@@ -879,7 +888,7 @@ class Manifest(MacroMethods):
def merge_from_artifact(
self,
adapter,
other: 'WritableManifest',
other: "WritableManifest",
selected: AbstractSet[UniqueID],
) -> None:
"""Given the selected unique IDs and a writable manifest, update this
@@ -892,10 +901,10 @@ class Manifest(MacroMethods):
for unique_id, node in other.nodes.items():
current = self.nodes.get(unique_id)
if current and (
node.resource_type in refables and
not node.is_ephemeral and
unique_id not in selected and
not adapter.get_relation(
node.resource_type in refables
and not node.is_ephemeral
and unique_id not in selected
and not adapter.get_relation(
current.database, current.schema, current.identifier
)
):
@@ -904,9 +913,7 @@ class Manifest(MacroMethods):

# log up to 5 items
sample = list(islice(merged, 5))
logger.debug(
f'Merged {len(merged)} items from state (sample: {sample})'
)
logger.debug(f"Merged {len(merged)} items from state (sample: {sample})")

# Provide support for copy.deepcopy() - we just need to avoid the lock!
# pickle and deepcopy use this. It returns a callable object used to
@@ -948,47 +955,53 @@ AnyManifest = Union[Manifest, MacroManifest]


@dataclass
@schema_version('manifest', 1)
@schema_version("manifest", 1)
class WritableManifest(ArtifactMixin):
nodes: Mapping[UniqueID, ManifestNode] = field(
metadata=dict(description=(
'The nodes defined in the dbt project and its dependencies'
))
metadata=dict(
description=("The nodes defined in the dbt project and its dependencies")
)
)
sources: Mapping[UniqueID, ParsedSourceDefinition] = field(
metadata=dict(description=(
'The sources defined in the dbt project and its dependencies'
))
metadata=dict(
description=("The sources defined in the dbt project and its dependencies")
)
)
macros: Mapping[UniqueID, ParsedMacro] = field(
metadata=dict(description=(
'The macros defined in the dbt project and its dependencies'
))
metadata=dict(
description=("The macros defined in the dbt project and its dependencies")
)
)
docs: Mapping[UniqueID, ParsedDocumentation] = field(
metadata=dict(description=(
'The docs defined in the dbt project and its dependencies'
))
metadata=dict(
description=("The docs defined in the dbt project and its dependencies")
)
)
exposures: Mapping[UniqueID, ParsedExposure] = field(
metadata=dict(description=(
'The exposures defined in the dbt project and its dependencies'
))
metadata=dict(
description=(
"The exposures defined in the dbt project and its dependencies"
)
)
)
selectors: Mapping[UniqueID, Any] = field(
metadata=dict(description=(
'The selectors defined in selectors.yml'
))
metadata=dict(description=("The selectors defined in selectors.yml"))
)
disabled: Optional[List[CompileResultNode]] = field(
metadata=dict(description="A list of the disabled nodes in the target")
)
parent_map: Optional[NodeEdgeMap] = field(
metadata=dict(
description="A mapping from child nodes to their dependencies",
)
)
child_map: Optional[NodeEdgeMap] = field(
metadata=dict(
description="A mapping from parent nodes to their dependents",
)
)
metadata: ManifestMetadata = field(
metadata=dict(
description="Metadata about the manifest",
)
)
disabled: Optional[List[CompileResultNode]] = field(metadata=dict(
description='A list of the disabled nodes in the target'
))
parent_map: Optional[NodeEdgeMap] = field(metadata=dict(
description='A mapping from child nodes to their dependencies',
))
child_map: Optional[NodeEdgeMap] = field(metadata=dict(
description='A mapping from parent nodes to their dependents',
))
metadata: ManifestMetadata = field(metadata=dict(
description='Metadata about the manifest',
))

@@ -2,11 +2,20 @@ from dataclasses import field, Field, dataclass
from enum import Enum
from itertools import chain
from typing import (
Any, List, Optional, Dict, MutableMapping, Union, Type,
TypeVar, Callable,
Any,
List,
Optional,
Dict,
MutableMapping,
Union,
Type,
TypeVar,
Callable,
)
from dbt.dataclass_schema import (
dbtClassMixin, ValidationError, register_pattern,
dbtClassMixin,
ValidationError,
register_pattern,
)
from dbt.contracts.graph.unparsed import AdditionalPropertiesAllowed
from dbt.exceptions import CompilationException, InternalException
@@ -15,7 +24,7 @@ from dbt import hooks
from dbt.node_types import NodeType


M = TypeVar('M', bound='Metadata')
M = TypeVar("M", bound="Metadata")


def _get_meta_value(cls: Type[M], fld: Field, key: str, default: Any) -> M:
@@ -30,9 +39,7 @@ def _get_meta_value(cls: Type[M], fld: Field, key: str, default: Any) -> M:
try:
return cls(value)
except ValueError as exc:
raise InternalException(
f'Invalid {cls} value: {value}'
) from exc
raise InternalException(f"Invalid {cls} value: {value}") from exc


def _set_meta_value(
@@ -54,19 +61,17 @@ class Metadata(Enum):

return _get_meta_value(cls, fld, key, default)

def meta(
self, existing: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
def meta(self, existing: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
key = self.metadata_key()
return _set_meta_value(self, key, existing)

@classmethod
def default_field(cls) -> 'Metadata':
raise NotImplementedError('Not implemented')
def default_field(cls) -> "Metadata":
raise NotImplementedError("Not implemented")

@classmethod
def metadata_key(cls) -> str:
raise NotImplementedError('Not implemented')
raise NotImplementedError("Not implemented")


class MergeBehavior(Metadata):
@@ -75,12 +80,12 @@ class MergeBehavior(Metadata):
Clobber = 3

@classmethod
def default_field(cls) -> 'MergeBehavior':
def default_field(cls) -> "MergeBehavior":
return cls.Clobber

@classmethod
def metadata_key(cls) -> str:
return 'merge'
return "merge"


class ShowBehavior(Metadata):
@@ -88,12 +93,12 @@ class ShowBehavior(Metadata):
Hide = 2

@classmethod
def default_field(cls) -> 'ShowBehavior':
def default_field(cls) -> "ShowBehavior":
return cls.Show

@classmethod
def metadata_key(cls) -> str:
return 'show_hide'
return "show_hide"

@classmethod
def should_show(cls, fld: Field) -> bool:
@@ -105,12 +110,12 @@ class CompareBehavior(Metadata):
Exclude = 2

@classmethod
def default_field(cls) -> 'CompareBehavior':
def default_field(cls) -> "CompareBehavior":
return cls.Include

@classmethod
def metadata_key(cls) -> str:
return 'compare'
return "compare"

@classmethod
def should_include(cls, fld: Field) -> bool:
@@ -142,32 +147,30 @@ def _merge_field_value(
return _listify(self_value) + _listify(other_value)
elif merge_behavior == MergeBehavior.Update:
if not isinstance(self_value, dict):
raise InternalException(f'expected dict, got {self_value}')
raise InternalException(f"expected dict, got {self_value}")
if not isinstance(other_value, dict):
raise InternalException(f'expected dict, got {other_value}')
raise InternalException(f"expected dict, got {other_value}")
value = self_value.copy()
value.update(other_value)
return value
else:
raise InternalException(
f'Got an invalid merge_behavior: {merge_behavior}'
)
raise InternalException(f"Got an invalid merge_behavior: {merge_behavior}")


def insensitive_patterns(*patterns: str):
lowercased = []
for pattern in patterns:
lowercased.append(
''.join('[{}{}]'.format(s.upper(), s.lower()) for s in pattern)
"".join("[{}{}]".format(s.upper(), s.lower()) for s in pattern)
)
return '^({})$'.format('|'.join(lowercased))
return "^({})$".format("|".join(lowercased))


class Severity(str):
pass


register_pattern(Severity, insensitive_patterns('warn', 'error'))
register_pattern(Severity, insensitive_patterns("warn", "error"))


@dataclass
@@ -177,13 +180,11 @@ class Hook(dbtClassMixin, Replaceable):
index: Optional[int] = None


T = TypeVar('T', bound='BaseConfig')
T = TypeVar("T", bound="BaseConfig")


@dataclass
class BaseConfig(
AdditionalPropertiesAllowed, Replaceable, MutableMapping[str, Any]
):
class BaseConfig(AdditionalPropertiesAllowed, Replaceable, MutableMapping[str, Any]):
# Implement MutableMapping so this config will behave as some macros expect
# during parsing (notably, syntax like `{{ node.config['schema'] }}`)
def __getitem__(self, key):
@@ -204,8 +205,7 @@ class BaseConfig(
def __delitem__(self, key):
if hasattr(self, key):
msg = (
'Error, tried to delete config key "{}": Cannot delete '
'built-in keys'
'Error, tried to delete config key "{}": Cannot delete ' "built-in keys"
).format(key)
raise CompilationException(msg)
else:
@@ -245,9 +245,7 @@ class BaseConfig(
return unrendered[key] == other[key]

@classmethod
def same_contents(
cls, unrendered: Dict[str, Any], other: Dict[str, Any]
) -> bool:
def same_contents(cls, unrendered: Dict[str, Any], other: Dict[str, Any]) -> bool:
"""This is like __eq__, except it ignores some fields."""
seen = set()
for fld, target_name in cls._get_fields():
@@ -265,9 +263,7 @@ class BaseConfig(
return True

@classmethod
def _extract_dict(
cls, src: Dict[str, Any], data: Dict[str, Any]
) -> Dict[str, Any]:
def _extract_dict(cls, src: Dict[str, Any], data: Dict[str, Any]) -> Dict[str, Any]:
"""Find all the items in data that match a target_field on this class,
and merge them with the data found in `src` for target_field, using the
field's specified merge behavior. Matching items will be removed from
@@ -307,6 +303,7 @@ class BaseConfig(
"""
# sadly, this is a circular import
from dbt.adapters.factory import get_config_class_by_name

dct = self.to_dict(omit_none=False)

adapter_config_cls = get_config_class_by_name(adapter_type)
@@ -348,7 +345,7 @@ class SourceConfig(BaseConfig):
@dataclass
class NodeConfig(BaseConfig):
enabled: bool = True
materialized: str = 'view'
materialized: str = "view"
persist_docs: Dict[str, Any] = field(default_factory=dict)
post_hook: List[Hook] = field(
default_factory=list,
@@ -389,16 +386,16 @@ class NodeConfig(BaseConfig):
)
tags: Union[List[str], str] = field(
default_factory=list_str,
metadata=metas(ShowBehavior.Hide,
MergeBehavior.Append,
CompareBehavior.Exclude),
metadata=metas(
ShowBehavior.Hide, MergeBehavior.Append, CompareBehavior.Exclude
),
)
full_refresh: Optional[bool] = None

@classmethod
def __pre_deserialize__(cls, data):
data = super().__pre_deserialize__(data)
field_map = {'post-hook': 'post_hook', 'pre-hook': 'pre_hook'}
field_map = {"post-hook": "post_hook", "pre-hook": "pre_hook"}
# create a new dict because otherwise it gets overwritten in
# tests
new_dict = {}
@@ -416,7 +413,7 @@ class NodeConfig(BaseConfig):

def __post_serialize__(self, dct):
dct = super().__post_serialize__(dct)
field_map = {'post_hook': 'post-hook', 'pre_hook': 'pre-hook'}
field_map = {"post_hook": "post-hook", "pre_hook": "pre-hook"}
for field_name in field_map:
if field_name in dct:
dct[field_map[field_name]] = dct.pop(field_name)
@@ -425,24 +422,24 @@ class NodeConfig(BaseConfig):
# this is still used by jsonschema validation
@classmethod
def field_mapping(cls):
return {'post_hook': 'post-hook', 'pre_hook': 'pre-hook'}
return {"post_hook": "post-hook", "pre_hook": "pre-hook"}


@dataclass
class SeedConfig(NodeConfig):
materialized: str = 'seed'
materialized: str = "seed"
quote_columns: Optional[bool] = None


@dataclass
class TestConfig(NodeConfig):
materialized: str = 'test'
severity: Severity = Severity('ERROR')
materialized: str = "test"
severity: Severity = Severity("ERROR")


@dataclass
class EmptySnapshotConfig(NodeConfig):
materialized: str = 'snapshot'
materialized: str = "snapshot"


@dataclass
@@ -457,25 +454,28 @@ class SnapshotConfig(EmptySnapshotConfig):
@classmethod
def validate(cls, data):
super().validate(data)
if data.get('strategy') == 'check':
if not data.get('check_cols'):
if data.get("strategy") == "check":
if not data.get("check_cols"):
raise ValidationError(
"A snapshot configured with the check strategy must "
"specify a check_cols configuration.")
if (isinstance(data['check_cols'], str) and
data['check_cols'] != 'all'):
"specify a check_cols configuration."
)
if isinstance(data["check_cols"], str) and data["check_cols"] != "all":
raise ValidationError(
f"Invalid value for 'check_cols': {data['check_cols']}. "
"Expected 'all' or a list of strings.")
"Expected 'all' or a list of strings."
)

elif data.get('strategy') == 'timestamp':
if not data.get('updated_at'):
elif data.get("strategy") == "timestamp":
if not data.get("updated_at"):
raise ValidationError(
"A snapshot configured with the timestamp strategy "
"must specify an updated_at configuration.")
if data.get('check_cols'):
"must specify an updated_at configuration."
)
if data.get("check_cols"):
raise ValidationError(
"A 'timestamp' snapshot should not have 'check_cols'")
"A 'timestamp' snapshot should not have 'check_cols'"
)
# If the strategy is not 'check' or 'timestamp' it's a custom strategy,
# formerly supported with GenericSnapshotConfig

@@ -497,9 +497,7 @@ RESOURCE_TYPES: Dict[NodeType, Type[BaseConfig]] = {
# base resource types are like resource types, except nothing has mandatory
# configs.
BASE_RESOURCE_TYPES: Dict[NodeType, Type[BaseConfig]] = RESOURCE_TYPES.copy()
BASE_RESOURCE_TYPES.update({
NodeType.Snapshot: EmptySnapshotConfig
})
BASE_RESOURCE_TYPES.update({NodeType.Snapshot: EmptySnapshotConfig})


def get_config_for(resource_type: NodeType, base=False) -> Type[BaseConfig]:

@@ -13,18 +13,27 @@ from typing import (
TypeVar,
)

from dbt.dataclass_schema import (
dbtClassMixin, ExtensibleDbtClassMixin
)
from dbt.dataclass_schema import dbtClassMixin, ExtensibleDbtClassMixin

from dbt.clients.system import write_file
from dbt.contracts.files import FileHash, MAXIMUM_SEED_SIZE_NAME
from dbt.contracts.graph.unparsed import (
UnparsedNode, UnparsedDocumentation, Quoting, Docs,
UnparsedBaseNode, FreshnessThreshold, ExternalTable,
HasYamlMetadata, MacroArgument, UnparsedSourceDefinition,
UnparsedSourceTableDefinition, UnparsedColumn, TestDef,
ExposureOwner, ExposureType, MaturityType
UnparsedNode,
UnparsedDocumentation,
Quoting,
Docs,
UnparsedBaseNode,
FreshnessThreshold,
ExternalTable,
HasYamlMetadata,
MacroArgument,
UnparsedSourceDefinition,
UnparsedSourceTableDefinition,
UnparsedColumn,
TestDef,
ExposureOwner,
ExposureType,
MaturityType,
)
from dbt.contracts.util import Replaceable, AdditionalPropertiesMixin
from dbt.exceptions import warn_or_error
@@ -44,13 +53,9 @@ from .model_config import (


@dataclass
class ColumnInfo(
AdditionalPropertiesMixin,
ExtensibleDbtClassMixin,
Replaceable
):
class ColumnInfo(AdditionalPropertiesMixin, ExtensibleDbtClassMixin, Replaceable):
name: str
description: str = ''
description: str = ""
meta: Dict[str, Any] = field(default_factory=dict)
data_type: Optional[str] = None
quote: Optional[bool] = None
@@ -62,7 +67,7 @@ class ColumnInfo(
class HasFqn(dbtClassMixin, Replaceable):
fqn: List[str]

def same_fqn(self, other: 'HasFqn') -> bool:
def same_fqn(self, other: "HasFqn") -> bool:
return self.fqn == other.fqn


@@ -101,8 +106,8 @@ class HasRelationMetadata(dbtClassMixin, Replaceable):
@classmethod
def __pre_deserialize__(cls, data):
data = super().__pre_deserialize__(data)
if 'database' not in data:
data['database'] = None
if "database" not in data:
data["database"] = None
return data


@@ -117,7 +122,7 @@ class ParsedNodeMixins(dbtClassMixin):

@property
def is_ephemeral(self):
return self.config.materialized == 'ephemeral'
return self.config.materialized == "ephemeral"

@property
def is_ephemeral_model(self):
@@ -127,7 +132,7 @@ class ParsedNodeMixins(dbtClassMixin):
def depends_on_nodes(self):
return self.depends_on.nodes

def patch(self, patch: 'ParsedNodePatch'):
def patch(self, patch: "ParsedNodePatch"):
"""Given a ParsedNodePatch, add the new information to the node."""
# explicitly pick out the parts to update so we don't inadvertently
# step on the model name or anything
@@ -153,11 +158,7 @@ class ParsedNodeMixins(dbtClassMixin):

@dataclass
class ParsedNodeMandatory(
UnparsedNode,
HasUniqueID,
HasFqn,
HasRelationMetadata,
Replaceable
UnparsedNode, HasUniqueID, HasFqn, HasRelationMetadata, Replaceable
):
alias: str
checksum: FileHash
@@ -174,7 +175,7 @@ class ParsedNodeDefaults(ParsedNodeMandatory):
refs: List[List[str]] = field(default_factory=list)
sources: List[List[Any]] = field(default_factory=list)
depends_on: DependsOn = field(default_factory=DependsOn)
description: str = field(default='')
description: str = field(default="")
columns: Dict[str, ColumnInfo] = field(default_factory=dict)
meta: Dict[str, Any] = field(default_factory=dict)
docs: Docs = field(default_factory=Docs)
@@ -184,31 +185,28 @@ class ParsedNodeDefaults(ParsedNodeMandatory):
unrendered_config: Dict[str, Any] = field(default_factory=dict)

def write_node(self, target_path: str, subdirectory: str, payload: str):
if (os.path.basename(self.path) ==
os.path.basename(self.original_file_path)):
if os.path.basename(self.path) == os.path.basename(self.original_file_path):
# One-to-one relationship of nodes to files.
path = self.original_file_path
else:
# Many-to-one relationship of nodes to files.
path = os.path.join(self.original_file_path, self.path)
full_path = os.path.join(
target_path, subdirectory, self.package_name, path
)
full_path = os.path.join(target_path, subdirectory, self.package_name, path)

write_file(full_path, payload)
return full_path


T = TypeVar('T', bound='ParsedNode')
T = TypeVar("T", bound="ParsedNode")


@dataclass
class ParsedNode(ParsedNodeDefaults, ParsedNodeMixins):
def _persist_column_docs(self) -> bool:
return bool(self.config.persist_docs.get('columns'))
return bool(self.config.persist_docs.get("columns"))

def _persist_relation_docs(self) -> bool:
return bool(self.config.persist_docs.get('relation'))
return bool(self.config.persist_docs.get("relation"))

def same_body(self: T, other: T) -> bool:
return self.raw_sql == other.raw_sql
@@ -223,9 +221,7 @@ class ParsedNode(ParsedNodeDefaults, ParsedNodeMixins):

if self._persist_column_docs():
# assert other._persist_column_docs()
column_descriptions = {
k: v.description for k, v in self.columns.items()
}
column_descriptions = {k: v.description for k, v in self.columns.items()}
other_column_descriptions = {
k: v.description for k, v in other.columns.items()
}
@@ -239,7 +235,7 @@ class ParsedNode(ParsedNodeDefaults, ParsedNodeMixins):
# compares the configured value, rather than the ultimate value (so
# generate_*_name and unset values derived from the target are
# ignored)
keys = ('database', 'schema', 'alias')
keys = ("database", "schema", "alias")
for key in keys:
mine = self.unrendered_config.get(key)
others = other.unrendered_config.get(key)
@@ -258,36 +254,34 @@ class ParsedNode(ParsedNodeDefaults, ParsedNodeMixins):
return False

return (
self.same_body(old) and
self.same_config(old) and
self.same_persisted_description(old) and
self.same_fqn(old) and
self.same_database_representation(old) and
True
self.same_body(old)
and self.same_config(old)
and self.same_persisted_description(old)
and self.same_fqn(old)
and self.same_database_representation(old)
and True
)


@dataclass
class ParsedAnalysisNode(ParsedNode):
resource_type: NodeType = field(metadata={'restrict': [NodeType.Analysis]})
resource_type: NodeType = field(metadata={"restrict": [NodeType.Analysis]})


@dataclass
class ParsedHookNode(ParsedNode):
resource_type: NodeType = field(
metadata={'restrict': [NodeType.Operation]}
)
resource_type: NodeType = field(metadata={"restrict": [NodeType.Operation]})
index: Optional[int] = None


@dataclass
class ParsedModelNode(ParsedNode):
resource_type: NodeType = field(metadata={'restrict': [NodeType.Model]})
resource_type: NodeType = field(metadata={"restrict": [NodeType.Model]})


@dataclass
class ParsedRPCNode(ParsedNode):
resource_type: NodeType = field(metadata={'restrict': [NodeType.RPCCall]})
resource_type: NodeType = field(metadata={"restrict": [NodeType.RPCCall]})


def same_seeds(first: ParsedNode, second: ParsedNode) -> bool:
@@ -297,31 +291,31 @@ def same_seeds(first: ParsedNode, second: ParsedNode) -> bool:
# if the current checksum is a path, we want to log a warning.
result = first.checksum == second.checksum

if first.checksum.name == 'path':
if first.checksum.name == "path":
msg: str
if second.checksum.name != 'path':
if second.checksum.name != "path":
msg = (
f'Found a seed ({first.package_name}.{first.name}) '
f'>{MAXIMUM_SEED_SIZE_NAME} in size. The previous file was '
f'<={MAXIMUM_SEED_SIZE_NAME}, so it has changed'
f"Found a seed ({first.package_name}.{first.name}) "
f">{MAXIMUM_SEED_SIZE_NAME} in size. The previous file was "
f"<={MAXIMUM_SEED_SIZE_NAME}, so it has changed"
)
elif result:
msg = (
f'Found a seed ({first.package_name}.{first.name}) '
f'>{MAXIMUM_SEED_SIZE_NAME} in size at the same path, dbt '
f'cannot tell if it has changed: assuming they are the same'
f"Found a seed ({first.package_name}.{first.name}) "
f">{MAXIMUM_SEED_SIZE_NAME} in size at the same path, dbt "
f"cannot tell if it has changed: assuming they are the same"
)
elif not result:
msg = (
f'Found a seed ({first.package_name}.{first.name}) '
f'>{MAXIMUM_SEED_SIZE_NAME} in size. The previous file was in '
f'a different location, assuming it has changed'
f"Found a seed ({first.package_name}.{first.name}) "
f">{MAXIMUM_SEED_SIZE_NAME} in size. The previous file was in "
f"a different location, assuming it has changed"
)
else:
msg = (
f'Found a seed ({first.package_name}.{first.name}) '
f'>{MAXIMUM_SEED_SIZE_NAME} in size. The previous file had a '
f'checksum type of {second.checksum.name}, so it has changed'
f"Found a seed ({first.package_name}.{first.name}) "
f">{MAXIMUM_SEED_SIZE_NAME} in size. The previous file had a "
f"checksum type of {second.checksum.name}, so it has changed"
)
warn_or_error(msg, node=first)

@@ -331,7 +325,7 @@ def same_seeds(first: ParsedNode, second: ParsedNode) -> bool:
@dataclass
class ParsedSeedNode(ParsedNode):
# keep this in sync with CompiledSeedNode!
resource_type: NodeType = field(metadata={'restrict': [NodeType.Seed]})
resource_type: NodeType = field(metadata={"restrict": [NodeType.Seed]})
config: SeedConfig = field(default_factory=SeedConfig)

@property
@@ -357,21 +351,20 @@ class HasTestMetadata(dbtClassMixin):

@dataclass
class ParsedDataTestNode(ParsedNode):
resource_type: NodeType = field(metadata={'restrict': [NodeType.Test]})
resource_type: NodeType = field(metadata={"restrict": [NodeType.Test]})
config: TestConfig = field(default_factory=TestConfig)


@dataclass
class ParsedSchemaTestNode(ParsedNode, HasTestMetadata):
# keep this in sync with CompiledSchemaTestNode!
resource_type: NodeType = field(metadata={'restrict': [NodeType.Test]})
resource_type: NodeType = field(metadata={"restrict": [NodeType.Test]})
column_name: Optional[str] = None
config: TestConfig = field(default_factory=TestConfig)

def same_config(self, other) -> bool:
return (
self.unrendered_config.get('severity') ==
other.unrendered_config.get('severity')
return self.unrendered_config.get("severity") == other.unrendered_config.get(
"severity"
)

def same_column_name(self, other) -> bool:
@@ -381,11 +374,7 @@ class ParsedSchemaTestNode(ParsedNode, HasTestMetadata):
if other is None:
return False

return (
self.same_config(other) and
self.same_fqn(other) and
True
)
return self.same_config(other) and self.same_fqn(other) and True


@dataclass
@@ -396,13 +385,13 @@ class IntermediateSnapshotNode(ParsedNode):
# defined in config blocks. To fix that, we have an intermediate type that
# uses a regular node config, which the snapshot parser will then convert
# into a full ParsedSnapshotNode after rendering.
resource_type: NodeType = field(metadata={'restrict': [NodeType.Snapshot]})
resource_type: NodeType = field(metadata={"restrict": [NodeType.Snapshot]})
config: EmptySnapshotConfig = field(default_factory=EmptySnapshotConfig)


@dataclass
class ParsedSnapshotNode(ParsedNode):
resource_type: NodeType = field(metadata={'restrict': [NodeType.Snapshot]})
resource_type: NodeType = field(metadata={"restrict": [NodeType.Snapshot]})
config: SnapshotConfig


@@ -431,12 +420,12 @@ class ParsedMacroPatch(ParsedPatch):
class ParsedMacro(UnparsedBaseNode, HasUniqueID):
name: str
macro_sql: str
resource_type: NodeType = field(metadata={'restrict': [NodeType.Macro]})
resource_type: NodeType = field(metadata={"restrict": [NodeType.Macro]})
# TODO: can macros even have tags?
tags: List[str] = field(default_factory=list)
# TODO: is this ever populated?
depends_on: MacroDependsOn = field(default_factory=MacroDependsOn)
description: str = ''
description: str = ""
meta: Dict[str, Any] = field(default_factory=dict)
docs: Docs = field(default_factory=Docs)
patch_path: Optional[str] = None
@@ -457,7 +446,7 @@ class ParsedMacro(UnparsedBaseNode, HasUniqueID):
dct = self.to_dict(omit_none=False)
self.validate(dct)

def same_contents(self, other: Optional['ParsedMacro']) -> bool:
def same_contents(self, other: Optional["ParsedMacro"]) -> bool:
if other is None:
return False
# the only thing that makes one macro different from another with the
@@ -474,7 +463,7 @@ class ParsedDocumentation(UnparsedDocumentation, HasUniqueID):
def search_name(self):
return self.name

def same_contents(self, other: Optional['ParsedDocumentation']) -> bool:
def same_contents(self, other: Optional["ParsedDocumentation"]) -> bool:
if other is None:
return False
# the only thing that makes one doc different from another with the
@@ -493,11 +482,11 @@ def normalize_test(testdef: TestDef) -> Dict[str, Any]:
class UnpatchedSourceDefinition(UnparsedBaseNode, HasUniqueID, HasFqn):
source: UnparsedSourceDefinition
table: UnparsedSourceTableDefinition
resource_type: NodeType = field(metadata={'restrict': [NodeType.Source]})
resource_type: NodeType = field(metadata={"restrict": [NodeType.Source]})
patch_path: Optional[Path] = None

def get_full_source_name(self):
return f'{self.source.name}_{self.table.name}'
return f"{self.source.name}_{self.table.name}"

def get_source_representation(self):
return f'source("{self.source.name}", "{self.table.name}")'
@@ -522,9 +511,7 @@ class UnpatchedSourceDefinition(UnparsedBaseNode, HasUniqueID, HasFqn):
else:
return self.table.columns

def get_tests(
self
) -> Iterator[Tuple[Dict[str, Any], Optional[UnparsedColumn]]]:
def get_tests(self) -> Iterator[Tuple[Dict[str, Any], Optional[UnparsedColumn]]]:
for test in self.tests:
yield normalize_test(test), None

@@ -543,22 +530,19 @@ class UnpatchedSourceDefinition(UnparsedBaseNode, HasUniqueID, HasFqn):

@dataclass
class ParsedSourceDefinition(
UnparsedBaseNode,
HasUniqueID,
HasRelationMetadata,
HasFqn
UnparsedBaseNode, HasUniqueID, HasRelationMetadata, HasFqn
):
name: str
source_name: str
source_description: str
loader: str
identifier: str
resource_type: NodeType = field(metadata={'restrict': [NodeType.Source]})
resource_type: NodeType = field(metadata={"restrict": [NodeType.Source]})
|
||||
quoting: Quoting = field(default_factory=Quoting)
|
||||
loaded_at_field: Optional[str] = None
|
||||
freshness: Optional[FreshnessThreshold] = None
|
||||
external: Optional[ExternalTable] = None
|
||||
description: str = ''
|
||||
description: str = ""
|
||||
columns: Dict[str, ColumnInfo] = field(default_factory=dict)
|
||||
meta: Dict[str, Any] = field(default_factory=dict)
|
||||
source_meta: Dict[str, Any] = field(default_factory=dict)
|
||||
@@ -568,36 +552,34 @@ class ParsedSourceDefinition(
|
||||
unrendered_config: Dict[str, Any] = field(default_factory=dict)
|
||||
relation_name: Optional[str] = None
|
||||
|
||||
def same_database_representation(
|
||||
self, other: 'ParsedSourceDefinition'
|
||||
) -> bool:
|
||||
def same_database_representation(self, other: "ParsedSourceDefinition") -> bool:
|
||||
return (
|
||||
self.database == other.database and
|
||||
self.schema == other.schema and
|
||||
self.identifier == other.identifier and
|
||||
True
|
||||
self.database == other.database
|
||||
and self.schema == other.schema
|
||||
and self.identifier == other.identifier
|
||||
and True
|
||||
)
|
||||
|
||||
def same_quoting(self, other: 'ParsedSourceDefinition') -> bool:
|
||||
def same_quoting(self, other: "ParsedSourceDefinition") -> bool:
|
||||
return self.quoting == other.quoting
|
||||
|
||||
def same_freshness(self, other: 'ParsedSourceDefinition') -> bool:
|
||||
def same_freshness(self, other: "ParsedSourceDefinition") -> bool:
|
||||
return (
|
||||
self.freshness == other.freshness and
|
||||
self.loaded_at_field == other.loaded_at_field and
|
||||
True
|
||||
self.freshness == other.freshness
|
||||
and self.loaded_at_field == other.loaded_at_field
|
||||
and True
|
||||
)
|
||||
|
||||
def same_external(self, other: 'ParsedSourceDefinition') -> bool:
|
||||
def same_external(self, other: "ParsedSourceDefinition") -> bool:
|
||||
return self.external == other.external
|
||||
|
||||
def same_config(self, old: 'ParsedSourceDefinition') -> bool:
|
||||
def same_config(self, old: "ParsedSourceDefinition") -> bool:
|
||||
return self.config.same_contents(
|
||||
self.unrendered_config,
|
||||
old.unrendered_config,
|
||||
)
|
||||
|
||||
def same_contents(self, old: Optional['ParsedSourceDefinition']) -> bool:
|
||||
def same_contents(self, old: Optional["ParsedSourceDefinition"]) -> bool:
|
||||
# existing when it didn't before is a change!
|
||||
if old is None:
|
||||
return True
|
||||
@@ -611,17 +593,17 @@ class ParsedSourceDefinition(
|
||||
# metadata/tags changes are not "changes"
|
||||
# patching/description changes are not "changes"
|
||||
return (
|
||||
self.same_database_representation(old) and
|
||||
self.same_fqn(old) and
|
||||
self.same_config(old) and
|
||||
self.same_quoting(old) and
|
||||
self.same_freshness(old) and
|
||||
self.same_external(old) and
|
||||
True
|
||||
self.same_database_representation(old)
|
||||
and self.same_fqn(old)
|
||||
and self.same_config(old)
|
||||
and self.same_quoting(old)
|
||||
and self.same_freshness(old)
|
||||
and self.same_external(old)
|
||||
and True
|
||||
)
|
||||
|
||||
def get_full_source_name(self):
|
||||
return f'{self.source_name}_{self.name}'
|
||||
return f"{self.source_name}_{self.name}"
|
||||
|
||||
def get_source_representation(self):
|
||||
return f'source("{self.source.name}", "{self.table.name}")'
|
||||
@@ -656,7 +638,7 @@ class ParsedSourceDefinition(
|
||||
|
||||
@property
|
||||
def search_name(self):
|
||||
return f'{self.source_name}.{self.name}'
|
||||
return f"{self.source_name}.{self.name}"
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -665,7 +647,7 @@ class ParsedExposure(UnparsedBaseNode, HasUniqueID, HasFqn):
|
||||
type: ExposureType
|
||||
owner: ExposureOwner
|
||||
resource_type: NodeType = NodeType.Exposure
|
||||
description: str = ''
|
||||
description: str = ""
|
||||
maturity: Optional[MaturityType] = None
|
||||
url: Optional[str] = None
|
||||
depends_on: DependsOn = field(default_factory=DependsOn)
|
||||
@@ -685,38 +667,38 @@ class ParsedExposure(UnparsedBaseNode, HasUniqueID, HasFqn):
|
||||
def tags(self):
|
||||
return []
|
||||
|
||||
def same_depends_on(self, old: 'ParsedExposure') -> bool:
|
||||
def same_depends_on(self, old: "ParsedExposure") -> bool:
|
||||
return set(self.depends_on.nodes) == set(old.depends_on.nodes)
|
||||
|
||||
def same_description(self, old: 'ParsedExposure') -> bool:
|
||||
def same_description(self, old: "ParsedExposure") -> bool:
|
||||
return self.description == old.description
|
||||
|
||||
def same_maturity(self, old: 'ParsedExposure') -> bool:
|
||||
def same_maturity(self, old: "ParsedExposure") -> bool:
|
||||
return self.maturity == old.maturity
|
||||
|
||||
def same_owner(self, old: 'ParsedExposure') -> bool:
|
||||
def same_owner(self, old: "ParsedExposure") -> bool:
|
||||
return self.owner == old.owner
|
||||
|
||||
def same_exposure_type(self, old: 'ParsedExposure') -> bool:
|
||||
def same_exposure_type(self, old: "ParsedExposure") -> bool:
|
||||
return self.type == old.type
|
||||
|
||||
def same_url(self, old: 'ParsedExposure') -> bool:
|
||||
def same_url(self, old: "ParsedExposure") -> bool:
|
||||
return self.url == old.url
|
||||
|
||||
def same_contents(self, old: Optional['ParsedExposure']) -> bool:
|
||||
def same_contents(self, old: Optional["ParsedExposure"]) -> bool:
|
||||
# existing when it didn't before is a change!
|
||||
if old is None:
|
||||
return True
|
||||
|
||||
return (
|
||||
self.same_fqn(old) and
|
||||
self.same_exposure_type(old) and
|
||||
self.same_owner(old) and
|
||||
self.same_maturity(old) and
|
||||
self.same_url(old) and
|
||||
self.same_description(old) and
|
||||
self.same_depends_on(old) and
|
||||
True
|
||||
self.same_fqn(old)
|
||||
and self.same_exposure_type(old)
|
||||
and self.same_owner(old)
|
||||
and self.same_maturity(old)
|
||||
and self.same_url(old)
|
||||
and self.same_description(old)
|
||||
and self.same_depends_on(old)
|
||||
and True
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -4,13 +4,12 @@ from dbt.contracts.util import (
|
||||
Mergeable,
|
||||
Replaceable,
|
||||
)
|
||||
|
||||
# trigger the PathEncoder
|
||||
import dbt.helper_types # noqa:F401
|
||||
from dbt.exceptions import CompilationException
|
||||
|
||||
from dbt.dataclass_schema import (
|
||||
dbtClassMixin, StrEnum, ExtensibleDbtClassMixin
|
||||
)
|
||||
from dbt.dataclass_schema import dbtClassMixin, StrEnum, ExtensibleDbtClassMixin
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import timedelta
|
||||
@@ -37,13 +36,15 @@ class HasSQL:
|
||||
|
||||
@dataclass
|
||||
class UnparsedMacro(UnparsedBaseNode, HasSQL):
|
||||
resource_type: NodeType = field(metadata={'restrict': [NodeType.Macro]})
|
||||
resource_type: NodeType = field(metadata={"restrict": [NodeType.Macro]})
|
||||
|
||||
|
||||
@dataclass
|
||||
class UnparsedNode(UnparsedBaseNode, HasSQL):
|
||||
name: str
|
||||
resource_type: NodeType = field(metadata={'restrict': [
|
||||
resource_type: NodeType = field(
|
||||
metadata={
|
||||
"restrict": [
|
||||
NodeType.Model,
|
||||
NodeType.Analysis,
|
||||
NodeType.Test,
|
||||
@@ -51,7 +52,9 @@ class UnparsedNode(UnparsedBaseNode, HasSQL):
|
||||
NodeType.Operation,
|
||||
NodeType.Seed,
|
||||
NodeType.RPCCall,
|
||||
]})
|
||||
]
|
||||
}
|
||||
)
|
||||
|
||||
@property
|
||||
def search_name(self):
|
||||
@@ -60,9 +63,7 @@ class UnparsedNode(UnparsedBaseNode, HasSQL):
|
||||
|
||||
@dataclass
|
||||
class UnparsedRunHook(UnparsedNode):
|
||||
resource_type: NodeType = field(
|
||||
metadata={'restrict': [NodeType.Operation]}
|
||||
)
|
||||
resource_type: NodeType = field(metadata={"restrict": [NodeType.Operation]})
|
||||
index: Optional[int] = None
|
||||
|
||||
|
||||
@@ -72,10 +73,9 @@ class Docs(dbtClassMixin, Replaceable):
|
||||
|
||||
|
||||
@dataclass
|
||||
class HasDocs(AdditionalPropertiesMixin, ExtensibleDbtClassMixin,
|
||||
Replaceable):
|
||||
class HasDocs(AdditionalPropertiesMixin, ExtensibleDbtClassMixin, Replaceable):
|
||||
name: str
|
||||
description: str = ''
|
||||
description: str = ""
|
||||
meta: Dict[str, Any] = field(default_factory=dict)
|
||||
data_type: Optional[str] = None
|
||||
docs: Docs = field(default_factory=Docs)
|
||||
@@ -131,7 +131,7 @@ class UnparsedNodeUpdate(HasColumnTests, HasTests, HasYamlMetadata):
|
||||
class MacroArgument(dbtClassMixin):
|
||||
name: str
|
||||
type: Optional[str] = None
|
||||
description: str = ''
|
||||
description: str = ""
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -140,12 +140,12 @@ class UnparsedMacroUpdate(HasDocs, HasYamlMetadata):
|
||||
|
||||
|
||||
class TimePeriod(StrEnum):
|
||||
minute = 'minute'
|
||||
hour = 'hour'
|
||||
day = 'day'
|
||||
minute = "minute"
|
||||
hour = "hour"
|
||||
day = "day"
|
||||
|
||||
def plural(self) -> str:
|
||||
return str(self) + 's'
|
||||
return str(self) + "s"
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -167,6 +167,7 @@ class FreshnessThreshold(dbtClassMixin, Mergeable):
|
||||
|
||||
def status(self, age: float) -> "dbt.contracts.results.FreshnessStatus":
|
||||
from dbt.contracts.results import FreshnessStatus
|
||||
|
||||
if self.error_after and self.error_after.exceeded(age):
|
||||
return FreshnessStatus.Error
|
||||
elif self.warn_after and self.warn_after.exceeded(age):
|
||||
@@ -179,24 +180,21 @@ class FreshnessThreshold(dbtClassMixin, Mergeable):
|
||||
|
||||
|
||||
@dataclass
|
||||
class AdditionalPropertiesAllowed(
|
||||
AdditionalPropertiesMixin,
|
||||
ExtensibleDbtClassMixin
|
||||
):
|
||||
class AdditionalPropertiesAllowed(AdditionalPropertiesMixin, ExtensibleDbtClassMixin):
|
||||
_extra: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExternalPartition(AdditionalPropertiesAllowed, Replaceable):
|
||||
name: str = ''
|
||||
description: str = ''
|
||||
data_type: str = ''
|
||||
name: str = ""
|
||||
description: str = ""
|
||||
data_type: str = ""
|
||||
meta: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
def __post_init__(self):
|
||||
if self.name == '' or self.data_type == '':
|
||||
if self.name == "" or self.data_type == "":
|
||||
raise CompilationException(
|
||||
'External partition columns must have names and data types'
|
||||
"External partition columns must have names and data types"
|
||||
)
|
||||
|
||||
|
||||
@@ -225,43 +223,39 @@ class UnparsedSourceTableDefinition(HasColumnTests, HasTests):
|
||||
loaded_at_field: Optional[str] = None
|
||||
identifier: Optional[str] = None
|
||||
quoting: Quoting = field(default_factory=Quoting)
|
||||
freshness: Optional[FreshnessThreshold] = field(
|
||||
default_factory=FreshnessThreshold
|
||||
)
|
||||
freshness: Optional[FreshnessThreshold] = field(default_factory=FreshnessThreshold)
|
||||
external: Optional[ExternalTable] = None
|
||||
tags: List[str] = field(default_factory=list)
|
||||
|
||||
def __post_serialize__(self, dct):
|
||||
dct = super().__post_serialize__(dct)
|
||||
if 'freshness' not in dct and self.freshness is None:
|
||||
dct['freshness'] = None
|
||||
if "freshness" not in dct and self.freshness is None:
|
||||
dct["freshness"] = None
|
||||
return dct
|
||||
|
||||
|
||||
@dataclass
|
||||
class UnparsedSourceDefinition(dbtClassMixin, Replaceable):
|
||||
name: str
|
||||
description: str = ''
|
||||
description: str = ""
|
||||
meta: Dict[str, Any] = field(default_factory=dict)
|
||||
database: Optional[str] = None
|
||||
schema: Optional[str] = None
|
||||
loader: str = ''
|
||||
loader: str = ""
|
||||
quoting: Quoting = field(default_factory=Quoting)
|
||||
freshness: Optional[FreshnessThreshold] = field(
|
||||
default_factory=FreshnessThreshold
|
||||
)
|
||||
freshness: Optional[FreshnessThreshold] = field(default_factory=FreshnessThreshold)
|
||||
loaded_at_field: Optional[str] = None
|
||||
tables: List[UnparsedSourceTableDefinition] = field(default_factory=list)
|
||||
tags: List[str] = field(default_factory=list)
|
||||
|
||||
@property
|
||||
def yaml_key(self) -> 'str':
|
||||
return 'sources'
|
||||
def yaml_key(self) -> "str":
|
||||
return "sources"
|
||||
|
||||
def __post_serialize__(self, dct):
|
||||
dct = super().__post_serialize__(dct)
|
||||
if 'freshnewss' not in dct and self.freshness is None:
|
||||
dct['freshness'] = None
|
||||
if "freshnewss" not in dct and self.freshness is None:
|
||||
dct["freshness"] = None
|
||||
return dct
|
||||
|
||||
|
||||
@@ -275,9 +269,7 @@ class SourceTablePatch(dbtClassMixin):
|
||||
loaded_at_field: Optional[str] = None
|
||||
identifier: Optional[str] = None
|
||||
quoting: Quoting = field(default_factory=Quoting)
|
||||
freshness: Optional[FreshnessThreshold] = field(
|
||||
default_factory=FreshnessThreshold
|
||||
)
|
||||
freshness: Optional[FreshnessThreshold] = field(default_factory=FreshnessThreshold)
|
||||
external: Optional[ExternalTable] = None
|
||||
tags: Optional[List[str]] = None
|
||||
tests: Optional[List[TestDef]] = None
|
||||
@@ -285,13 +277,13 @@ class SourceTablePatch(dbtClassMixin):
|
||||
|
||||
def to_patch_dict(self) -> Dict[str, Any]:
|
||||
dct = self.to_dict(omit_none=True)
|
||||
remove_keys = ('name')
|
||||
remove_keys = "name"
|
||||
for key in remove_keys:
|
||||
if key in dct:
|
||||
del dct[key]
|
||||
|
||||
if self.freshness is None:
|
||||
dct['freshness'] = None
|
||||
dct["freshness"] = None
|
||||
|
||||
return dct
|
||||
|
||||
@@ -299,13 +291,13 @@ class SourceTablePatch(dbtClassMixin):
|
||||
@dataclass
|
||||
class SourcePatch(dbtClassMixin, Replaceable):
|
||||
name: str = field(
|
||||
metadata=dict(description='The name of the source to override'),
|
||||
metadata=dict(description="The name of the source to override"),
|
||||
)
|
||||
overrides: str = field(
|
||||
metadata=dict(description='The package of the source to override'),
|
||||
metadata=dict(description="The package of the source to override"),
|
||||
)
|
||||
path: Path = field(
|
||||
metadata=dict(description='The path to the patch-defining yml file'),
|
||||
metadata=dict(description="The path to the patch-defining yml file"),
|
||||
)
|
||||
description: Optional[str] = None
|
||||
meta: Optional[Dict[str, Any]] = None
|
||||
@@ -322,13 +314,13 @@ class SourcePatch(dbtClassMixin, Replaceable):
|
||||
|
||||
def to_patch_dict(self) -> Dict[str, Any]:
|
||||
dct = self.to_dict(omit_none=True)
|
||||
remove_keys = ('name', 'overrides', 'tables', 'path')
|
||||
remove_keys = ("name", "overrides", "tables", "path")
|
||||
for key in remove_keys:
|
||||
if key in dct:
|
||||
del dct[key]
|
||||
|
||||
if self.freshness is None:
|
||||
dct['freshness'] = None
|
||||
dct["freshness"] = None
|
||||
|
||||
return dct
|
||||
|
||||
@@ -360,9 +352,9 @@ class UnparsedDocumentationFile(UnparsedDocumentation):
|
||||
# can't use total_ordering decorator here, as str provides an ordering already
|
||||
# and it's not the one we want.
|
||||
class Maturity(StrEnum):
|
||||
low = 'low'
|
||||
medium = 'medium'
|
||||
high = 'high'
|
||||
low = "low"
|
||||
medium = "medium"
|
||||
high = "high"
|
||||
|
||||
def __lt__(self, other):
|
||||
if not isinstance(other, Maturity):
|
||||
@@ -387,17 +379,17 @@ class Maturity(StrEnum):
|
||||
|
||||
|
||||
class ExposureType(StrEnum):
|
||||
Dashboard = 'dashboard'
|
||||
Notebook = 'notebook'
|
||||
Analysis = 'analysis'
|
||||
ML = 'ml'
|
||||
Application = 'application'
|
||||
Dashboard = "dashboard"
|
||||
Notebook = "notebook"
|
||||
Analysis = "analysis"
|
||||
ML = "ml"
|
||||
Application = "application"
|
||||
|
||||
|
||||
class MaturityType(StrEnum):
|
||||
Low = 'low'
|
||||
Medium = 'medium'
|
||||
High = 'high'
|
||||
Low = "low"
|
||||
Medium = "medium"
|
||||
High = "high"
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -411,7 +403,7 @@ class UnparsedExposure(dbtClassMixin, Replaceable):
|
||||
name: str
|
||||
type: ExposureType
|
||||
owner: ExposureOwner
|
||||
description: str = ''
|
||||
description: str = ""
|
||||
maturity: Optional[MaturityType] = None
|
||||
url: Optional[str] = None
|
||||
depends_on: List[str] = field(default_factory=list)
|
||||
|
||||
@@ -5,24 +5,26 @@ from dbt.logger import GLOBAL_LOGGER as logger # noqa
|
||||
from dbt import tracking
|
||||
from dbt import ui
|
||||
from dbt.dataclass_schema import (
|
||||
dbtClassMixin, ValidationError,
|
||||
dbtClassMixin,
|
||||
ValidationError,
|
||||
HyphenatedDbtClassMixin,
|
||||
ExtensibleDbtClassMixin,
|
||||
register_pattern, ValidatedStringMixin
|
||||
register_pattern,
|
||||
ValidatedStringMixin,
|
||||
)
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional, List, Dict, Union, Any
|
||||
from mashumaro.types import SerializableType
|
||||
|
||||
PIN_PACKAGE_URL = 'https://docs.getdbt.com/docs/package-management#section-specifying-package-versions' # noqa
|
||||
PIN_PACKAGE_URL = "https://docs.getdbt.com/docs/package-management#section-specifying-package-versions" # noqa
|
||||
DEFAULT_SEND_ANONYMOUS_USAGE_STATS = True
|
||||
|
||||
|
||||
class Name(ValidatedStringMixin):
|
||||
ValidationRegex = r'^[^\d\W]\w*$'
|
||||
ValidationRegex = r"^[^\d\W]\w*$"
|
||||
|
||||
|
||||
register_pattern(Name, r'^[^\d\W]\w*$')
|
||||
register_pattern(Name, r"^[^\d\W]\w*$")
|
||||
|
||||
|
||||
class SemverString(str, SerializableType):
|
||||
@@ -30,7 +32,7 @@ class SemverString(str, SerializableType):
|
||||
return self
|
||||
|
||||
@classmethod
|
||||
def _deserialize(cls, value: str) -> 'SemverString':
|
||||
def _deserialize(cls, value: str) -> "SemverString":
|
||||
return SemverString(value)
|
||||
|
||||
|
||||
@@ -39,7 +41,7 @@ class SemverString(str, SerializableType):
|
||||
# 'semver lite'.
|
||||
register_pattern(
|
||||
SemverString,
|
||||
r'^(?:0|[1-9]\d*)\.(?:0|[1-9]\d*)(\.(?:0|[1-9]\d*))?$',
|
||||
r"^(?:0|[1-9]\d*)\.(?:0|[1-9]\d*)(\.(?:0|[1-9]\d*))?$",
|
||||
)
|
||||
|
||||
|
||||
@@ -105,8 +107,7 @@ class ProjectPackageMetadata:
|
||||
|
||||
@classmethod
|
||||
def from_project(cls, project):
|
||||
return cls(name=project.project_name,
|
||||
packages=project.packages.packages)
|
||||
return cls(name=project.project_name, packages=project.packages.packages)
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -124,46 +125,46 @@ class RegistryPackageMetadata(
|
||||
|
||||
# A list of all the reserved words that packages may not have as names.
|
||||
BANNED_PROJECT_NAMES = {
|
||||
'_sql_results',
|
||||
'adapter',
|
||||
'api',
|
||||
'column',
|
||||
'config',
|
||||
'context',
|
||||
'database',
|
||||
'env',
|
||||
'env_var',
|
||||
'exceptions',
|
||||
'execute',
|
||||
'flags',
|
||||
'fromjson',
|
||||
'fromyaml',
|
||||
'graph',
|
||||
'invocation_id',
|
||||
'load_agate_table',
|
||||
'load_result',
|
||||
'log',
|
||||
'model',
|
||||
'modules',
|
||||
'post_hooks',
|
||||
'pre_hooks',
|
||||
'ref',
|
||||
'render',
|
||||
'return',
|
||||
'run_started_at',
|
||||
'schema',
|
||||
'source',
|
||||
'sql',
|
||||
'sql_now',
|
||||
'store_result',
|
||||
'store_raw_result',
|
||||
'target',
|
||||
'this',
|
||||
'tojson',
|
||||
'toyaml',
|
||||
'try_or_compiler_error',
|
||||
'var',
|
||||
'write',
|
||||
"_sql_results",
|
||||
"adapter",
|
||||
"api",
|
||||
"column",
|
||||
"config",
|
||||
"context",
|
||||
"database",
|
||||
"env",
|
||||
"env_var",
|
||||
"exceptions",
|
||||
"execute",
|
||||
"flags",
|
||||
"fromjson",
|
||||
"fromyaml",
|
||||
"graph",
|
||||
"invocation_id",
|
||||
"load_agate_table",
|
||||
"load_result",
|
||||
"log",
|
||||
"model",
|
||||
"modules",
|
||||
"post_hooks",
|
||||
"pre_hooks",
|
||||
"ref",
|
||||
"render",
|
||||
"return",
|
||||
"run_started_at",
|
||||
"schema",
|
||||
"source",
|
||||
"sql",
|
||||
"sql_now",
|
||||
"store_result",
|
||||
"store_raw_result",
|
||||
"target",
|
||||
"this",
|
||||
"tojson",
|
||||
"toyaml",
|
||||
"try_or_compiler_error",
|
||||
"var",
|
||||
"write",
|
||||
}
|
||||
|
||||
|
||||
@@ -198,7 +199,7 @@ class Project(HyphenatedDbtClassMixin, Replaceable):
|
||||
vars: Optional[Dict[str, Any]] = field(
|
||||
default=None,
|
||||
metadata=dict(
|
||||
description='map project names to their vars override dicts',
|
||||
description="map project names to their vars override dicts",
|
||||
),
|
||||
)
|
||||
packages: List[PackageSpec] = field(default_factory=list)
|
||||
@@ -207,7 +208,7 @@ class Project(HyphenatedDbtClassMixin, Replaceable):
|
||||
@classmethod
|
||||
def validate(cls, data):
|
||||
super().validate(data)
|
||||
if data['name'] in BANNED_PROJECT_NAMES:
|
||||
if data["name"] in BANNED_PROJECT_NAMES:
|
||||
raise ValidationError(
|
||||
f"Invalid project name: {data['name']} is a reserved word"
|
||||
)
|
||||
@@ -235,8 +236,8 @@ class UserConfig(ExtensibleDbtClassMixin, Replaceable, UserConfigContract):
|
||||
|
||||
@dataclass
|
||||
class ProfileConfig(HyphenatedDbtClassMixin, Replaceable):
|
||||
profile_name: str = field(metadata={'preserve_underscore': True})
|
||||
target_name: str = field(metadata={'preserve_underscore': True})
|
||||
profile_name: str = field(metadata={"preserve_underscore": True})
|
||||
target_name: str = field(metadata={"preserve_underscore": True})
|
||||
config: UserConfig
|
||||
threads: int
|
||||
# TODO: make this a dynamic union of some kind?
|
||||
@@ -255,7 +256,7 @@ class ConfiguredQuoting(Quoting, Replaceable):
|
||||
class Configuration(Project, ProfileConfig):
|
||||
cli_vars: Dict[str, Any] = field(
|
||||
default_factory=dict,
|
||||
metadata={'preserve_underscore': True},
|
||||
metadata={"preserve_underscore": True},
|
||||
)
|
||||
quoting: Optional[ConfiguredQuoting] = None
|
||||
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
from collections.abc import Mapping
|
||||
from dataclasses import dataclass, fields
|
||||
from typing import (
|
||||
Optional, Dict,
|
||||
Optional,
|
||||
Dict,
|
||||
)
|
||||
from typing_extensions import Protocol
|
||||
|
||||
@@ -14,17 +15,17 @@ from dbt.utils import deep_merge
|
||||
|
||||
|
||||
class RelationType(StrEnum):
|
||||
Table = 'table'
|
||||
View = 'view'
|
||||
CTE = 'cte'
|
||||
MaterializedView = 'materializedview'
|
||||
External = 'external'
|
||||
Table = "table"
|
||||
View = "view"
|
||||
CTE = "cte"
|
||||
MaterializedView = "materializedview"
|
||||
External = "external"
|
||||
|
||||
|
||||
class ComponentName(StrEnum):
|
||||
Database = 'database'
|
||||
Schema = 'schema'
|
||||
Identifier = 'identifier'
|
||||
Database = "database"
|
||||
Schema = "schema"
|
||||
Identifier = "identifier"
|
||||
|
||||
|
||||
class HasQuoting(Protocol):
|
||||
@@ -43,12 +44,12 @@ class FakeAPIObject(dbtClassMixin, Replaceable, Mapping):
|
||||
raise KeyError(key) from None
|
||||
|
||||
def __iter__(self):
|
||||
deprecations.warn('not-a-dictionary', obj=self)
|
||||
deprecations.warn("not-a-dictionary", obj=self)
|
||||
for _, name in self._get_fields():
|
||||
yield name
|
||||
|
||||
def __len__(self):
|
||||
deprecations.warn('not-a-dictionary', obj=self)
|
||||
deprecations.warn("not-a-dictionary", obj=self)
|
||||
return len(fields(self.__class__))
|
||||
|
||||
def incorporate(self, **kwargs):
|
||||
@@ -72,8 +73,7 @@ class Policy(FakeAPIObject):
|
||||
return self.identifier
|
||||
else:
|
||||
raise ValueError(
|
||||
'Got a key of {}, expected one of {}'
|
||||
.format(key, list(ComponentName))
|
||||
"Got a key of {}, expected one of {}".format(key, list(ComponentName))
|
||||
)
|
||||
|
||||
def replace_dict(self, dct: Dict[ComponentName, bool]):
|
||||
@@ -93,15 +93,15 @@ class Path(FakeAPIObject):
|
||||
# handle pesky jinja2.Undefined sneaking in here and messing up rende
|
||||
if not isinstance(self.database, (type(None), str)):
|
||||
raise CompilationException(
|
||||
'Got an invalid path database: {}'.format(self.database)
|
||||
"Got an invalid path database: {}".format(self.database)
|
||||
)
|
||||
if not isinstance(self.schema, (type(None), str)):
|
||||
raise CompilationException(
|
||||
'Got an invalid path schema: {}'.format(self.schema)
|
||||
"Got an invalid path schema: {}".format(self.schema)
|
||||
)
|
||||
if not isinstance(self.identifier, (type(None), str)):
|
||||
raise CompilationException(
|
||||
'Got an invalid path identifier: {}'.format(self.identifier)
|
||||
"Got an invalid path identifier: {}".format(self.identifier)
|
||||
)
|
||||
|
||||
def get_lowered_part(self, key: ComponentName) -> Optional[str]:
|
||||
@@ -119,8 +119,7 @@ class Path(FakeAPIObject):
|
||||
return self.identifier
|
||||
else:
|
||||
raise ValueError(
|
||||
'Got a key of {}, expected one of {}'
|
||||
.format(key, list(ComponentName))
|
||||
"Got a key of {}, expected one of {}".format(key, list(ComponentName))
|
||||
)
|
||||
|
||||
def replace_dict(self, dct: Dict[ComponentName, str]):
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
from dbt.contracts.graph.manifest import CompileResultNode
|
||||
from dbt.contracts.graph.unparsed import (
|
||||
FreshnessThreshold
|
||||
)
|
||||
from dbt.contracts.graph.unparsed import FreshnessThreshold
|
||||
from dbt.contracts.graph.parsed import ParsedSourceDefinition
|
||||
from dbt.contracts.util import (
|
||||
BaseArtifactMetadata,
|
||||
@@ -24,7 +22,13 @@ import agate
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from typing import (
|
||||
Union, Dict, List, Optional, Any, NamedTuple, Sequence,
|
||||
Union,
|
||||
Dict,
|
||||
List,
|
||||
Optional,
|
||||
Any,
|
||||
NamedTuple,
|
||||
Sequence,
|
||||
)
|
||||
|
||||
from dbt.clients.system import write_json
|
||||
@@ -54,7 +58,7 @@ class collect_timing_info:
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
self.timing_info.end()
|
||||
with JsonOnly(), TimingProcessor(self.timing_info):
|
||||
logger.debug('finished collecting timing info')
|
||||
logger.debug("finished collecting timing info")
|
||||
|
||||
|
||||
class NodeStatus(StrEnum):
|
||||
@@ -99,8 +103,8 @@ class BaseResult(dbtClassMixin):
|
||||
@classmethod
|
||||
def __pre_deserialize__(cls, data):
|
||||
data = super().__pre_deserialize__(data)
|
||||
if 'message' not in data:
|
||||
data['message'] = None
|
||||
if "message" not in data:
|
||||
data["message"] = None
|
||||
return data
|
||||
|
||||
|
||||
@@ -112,9 +116,8 @@ class NodeResult(BaseResult):
|
||||
@dataclass
|
||||
class RunResult(NodeResult):
|
||||
agate_table: Optional[agate.Table] = field(
|
||||
default=None, metadata={
|
||||
'serialize': lambda x: None, 'deserialize': lambda x: None
|
||||
}
|
||||
default=None,
|
||||
metadata={"serialize": lambda x: None, "deserialize": lambda x: None},
|
||||
)
|
||||
|
||||
@property
|
||||
@@ -157,7 +160,7 @@ def process_run_result(result: RunResult) -> RunResultOutput:
|
||||
thread_id=result.thread_id,
|
||||
execution_time=result.execution_time,
|
||||
message=result.message,
|
||||
adapter_response=result.adapter_response
|
||||
adapter_response=result.adapter_response,
|
||||
)
|
||||
|
||||
|
||||
@@ -180,7 +183,7 @@ class RunExecutionResult(
|
||||
|
||||
|
||||
@dataclass
|
||||
@schema_version('run-results', 1)
|
||||
@schema_version("run-results", 1)
|
||||
class RunResultsArtifact(ExecutionResult, ArtifactMixin):
|
||||
results: Sequence[RunResultOutput]
|
||||
args: Dict[str, Any] = field(default_factory=dict)
|
||||
@@ -202,7 +205,7 @@ class RunResultsArtifact(ExecutionResult, ArtifactMixin):
|
||||
metadata=meta,
|
||||
results=processed_results,
|
||||
elapsed_time=elapsed_time,
|
||||
args=args
|
||||
args=args,
|
||||
)
|
||||
|
||||
def write(self, path: str):
|
||||
@@ -216,15 +219,14 @@ class RunOperationResult(ExecutionResult):
|
||||
|
||||
@dataclass
|
||||
class RunOperationResultMetadata(BaseArtifactMetadata):
|
||||
dbt_schema_version: str = field(default_factory=lambda: str(
|
||||
RunOperationResultsArtifact.dbt_schema_version
|
||||
))
|
||||
dbt_schema_version: str = field(
|
||||
default_factory=lambda: str(RunOperationResultsArtifact.dbt_schema_version)
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
@schema_version('run-operation-result', 1)
|
||||
@schema_version("run-operation-result", 1)
|
||||
class RunOperationResultsArtifact(RunOperationResult, ArtifactMixin):
|
||||
|
||||
@classmethod
|
||||
def from_success(
|
||||
cls,
|
||||
@@ -243,6 +245,7 @@ class RunOperationResultsArtifact(RunOperationResult, ArtifactMixin):
|
||||
success=success,
|
||||
)
|
||||
|
||||
|
||||
# due to issues with typing.Union collapsing subclasses, this can't subclass
|
||||
# PartialResult
|
||||
|
||||
@@ -261,7 +264,7 @@ class SourceFreshnessResult(NodeResult):
|
||||
|
||||
|
||||
class FreshnessErrorEnum(StrEnum):
|
||||
runtime_error = 'runtime error'
|
||||
runtime_error = "runtime error"
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -291,14 +294,11 @@ class PartialSourceFreshnessResult(NodeResult):
|
||||
return False
|
||||
|
||||
|
||||
FreshnessNodeResult = Union[PartialSourceFreshnessResult,
|
||||
SourceFreshnessResult]
|
||||
FreshnessNodeResult = Union[PartialSourceFreshnessResult, SourceFreshnessResult]
|
||||
FreshnessNodeOutput = Union[SourceFreshnessRuntimeError, SourceFreshnessOutput]
|
||||
|
||||
|
||||
def process_freshness_result(
|
||||
result: FreshnessNodeResult
|
||||
) -> FreshnessNodeOutput:
|
||||
def process_freshness_result(result: FreshnessNodeResult) -> FreshnessNodeOutput:
|
||||
unique_id = result.node.unique_id
|
||||
if result.status == FreshnessStatus.RuntimeErr:
|
||||
return SourceFreshnessRuntimeError(
|
||||
@@ -310,16 +310,15 @@ def process_freshness_result(
|
||||
# we know that this must be a SourceFreshnessResult
|
||||
if not isinstance(result, SourceFreshnessResult):
|
||||
raise InternalException(
|
||||
'Got {} instead of a SourceFreshnessResult for a '
|
||||
'non-error result in freshness execution!'
|
||||
.format(type(result))
|
||||
"Got {} instead of a SourceFreshnessResult for a "
|
||||
"non-error result in freshness execution!".format(type(result))
|
||||
)
|
||||
# if we're here, we must have a non-None freshness threshold
|
||||
criteria = result.node.freshness
|
||||
if criteria is None:
|
||||
raise InternalException(
|
||||
'Somehow evaluated a freshness result for a source '
|
||||
'that has no freshness criteria!'
|
||||
"Somehow evaluated a freshness result for a source "
|
||||
"that has no freshness criteria!"
|
||||
)
|
||||
return SourceFreshnessOutput(
|
||||
unique_id=unique_id,
|
||||
@@ -328,16 +327,14 @@ def process_freshness_result(
|
||||
max_loaded_at_time_ago_in_s=result.age,
|
||||
status=result.status,
|
||||
criteria=criteria,
|
||||
adapter_response=result.adapter_response
|
||||
adapter_response=result.adapter_response,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class FreshnessMetadata(BaseArtifactMetadata):
|
||||
dbt_schema_version: str = field(
|
||||
default_factory=lambda: str(
|
||||
FreshnessExecutionResultArtifact.dbt_schema_version
|
||||
)
|
||||
default_factory=lambda: str(FreshnessExecutionResultArtifact.dbt_schema_version)
|
||||
)
|
||||
|
||||
|
||||
@@ -358,7 +355,7 @@ class FreshnessResult(ExecutionResult):
|
||||
|
||||
|
||||
@dataclass
|
||||
@schema_version('sources', 1)
|
||||
@schema_version("sources", 1)
|
||||
class FreshnessExecutionResultArtifact(
|
||||
ArtifactMixin,
|
||||
VersionedSchema,
|
||||
@@ -380,8 +377,7 @@ class FreshnessExecutionResultArtifact(
|
||||
Primitive = Union[bool, str, float, None]
|
||||
|
||||
CatalogKey = NamedTuple(
|
||||
'CatalogKey',
|
||||
[('database', Optional[str]), ('schema', str), ('name', str)]
|
||||
"CatalogKey", [("database", Optional[str]), ("schema", str), ("name", str)]
|
||||
)
|
||||
|
||||
|
||||
@@ -450,13 +446,13 @@ class CatalogResults(dbtClassMixin):
|
||||
|
||||
def __post_serialize__(self, dct):
|
||||
dct = super().__post_serialize__(dct)
|
||||
if '_compile_results' in dct:
|
||||
del dct['_compile_results']
|
||||
if "_compile_results" in dct:
|
||||
del dct["_compile_results"]
|
||||
return dct
|
||||
|
||||
|
||||
@dataclass
|
||||
@schema_version('catalog', 1)
|
||||
@schema_version("catalog", 1)
|
||||
class CatalogArtifact(CatalogResults, ArtifactMixin):
|
||||
metadata: CatalogMetadata
|
||||
|
||||
@@ -467,8 +463,8 @@ class CatalogArtifact(CatalogResults, ArtifactMixin):
|
||||
nodes: Dict[str, CatalogTable],
|
||||
sources: Dict[str, CatalogTable],
|
||||
compile_results: Optional[Any],
|
||||
errors: Optional[List[str]]
|
||||
) -> 'CatalogArtifact':
|
||||
errors: Optional[List[str]],
|
||||
) -> "CatalogArtifact":
|
||||
meta = CatalogMetadata(generated_at=generated_at)
|
||||
return cls(
|
||||
metadata=meta,
|
||||
|
||||
@@ -10,7 +10,9 @@ from dbt.dataclass_schema import dbtClassMixin, StrEnum
|
||||
from dbt.contracts.graph.compiled import CompileResultNode
|
||||
from dbt.contracts.graph.manifest import WritableManifest
|
||||
from dbt.contracts.results import (
|
||||
RunResult, RunResultsArtifact, TimingInfo,
|
||||
RunResult,
|
||||
RunResultsArtifact,
|
||||
TimingInfo,
|
||||
CatalogArtifact,
|
||||
CatalogResults,
|
||||
ExecutionResult,
|
||||
@@ -40,10 +42,10 @@ class RPCParameters(dbtClassMixin):
|
||||
@classmethod
|
||||
def __pre_deserialize__(cls, data, omit_none=True):
|
||||
data = super().__pre_deserialize__(data)
|
||||
if 'timeout' not in data:
|
||||
data['timeout'] = None
|
||||
if 'task_tags' not in data:
|
||||
data['task_tags'] = None
|
||||
if "timeout" not in data:
|
||||
data["timeout"] = None
|
||||
if "task_tags" not in data:
|
||||
data["task_tags"] = None
|
||||
return data
|
||||
|
||||
|
||||
@@ -161,6 +163,7 @@ class GCParameters(RPCParameters):
|
||||
will be applied to the task manager before GC starts. By default the
|
||||
existing gc settings remain.
|
||||
"""
|
||||
|
||||
task_ids: Optional[List[TaskID]] = None
|
||||
before: Optional[datetime] = None
|
||||
settings: Optional[GCSettings] = None
|
||||
@@ -182,6 +185,7 @@ class RPCSourceFreshnessParameters(RPCParameters):
|
||||
class GetManifestParameters(RPCParameters):
|
||||
pass
|
||||
|
||||
|
||||
# Outputs
|
||||
|
||||
|
||||
@@ -191,13 +195,13 @@ class RemoteResult(VersionedSchema):
|
||||
|
||||
|
||||
@dataclass
|
||||
@schema_version('remote-deps-result', 1)
|
||||
@schema_version("remote-deps-result", 1)
|
||||
class RemoteDepsResult(RemoteResult):
|
||||
generated_at: datetime = field(default_factory=datetime.utcnow)
|
||||
|
||||
|
||||
@dataclass
|
||||
@schema_version('remote-catalog-result', 1)
|
||||
@schema_version("remote-catalog-result", 1)
|
||||
class RemoteCatalogResults(CatalogResults, RemoteResult):
|
||||
generated_at: datetime = field(default_factory=datetime.utcnow)
|
||||
|
||||
@@ -221,7 +225,7 @@ class RemoteCompileResultMixin(RemoteResult):
|
||||
|
||||
|
||||
@dataclass
|
||||
@schema_version('remote-compile-result', 1)
|
||||
@schema_version("remote-compile-result", 1)
|
||||
class RemoteCompileResult(RemoteCompileResultMixin):
|
||||
generated_at: datetime = field(default_factory=datetime.utcnow)
|
||||
|
||||
@@ -231,7 +235,7 @@ class RemoteCompileResult(RemoteCompileResultMixin):
|
||||
|
||||
|
||||
@dataclass
|
||||
@schema_version('remote-execution-result', 1)
|
||||
@schema_version("remote-execution-result", 1)
|
||||
class RemoteExecutionResult(ExecutionResult, RemoteResult):
|
||||
results: Sequence[RunResult]
|
||||
args: Dict[str, Any] = field(default_factory=dict)
|
||||
@@ -251,7 +255,7 @@ class RemoteExecutionResult(ExecutionResult, RemoteResult):
|
||||
cls,
|
||||
base: RunExecutionResult,
|
||||
logs: List[LogMessage],
|
||||
) -> 'RemoteExecutionResult':
|
||||
) -> "RemoteExecutionResult":
|
||||
return cls(
|
||||
generated_at=base.generated_at,
|
||||
results=base.results,
|
||||
@@ -268,7 +272,7 @@ class ResultTable(dbtClassMixin):
|
||||
|
||||
|
||||
@dataclass
|
||||
@schema_version('remote-run-operation-result', 1)
|
||||
@schema_version("remote-run-operation-result", 1)
|
||||
class RemoteRunOperationResult(RunOperationResult, RemoteResult):
|
||||
generated_at: datetime = field(default_factory=datetime.utcnow)
|
||||
|
||||
@@ -277,7 +281,7 @@ class RemoteRunOperationResult(RunOperationResult, RemoteResult):
|
||||
cls,
|
||||
base: RunOperationResultsArtifact,
|
||||
logs: List[LogMessage],
|
||||
) -> 'RemoteRunOperationResult':
|
||||
) -> "RemoteRunOperationResult":
|
||||
return cls(
|
||||
generated_at=base.metadata.generated_at,
|
||||
results=base.results,
|
||||
@@ -296,15 +300,14 @@ class RemoteRunOperationResult(RunOperationResult, RemoteResult):
|
||||
|
||||
|
||||
@dataclass
|
||||
@schema_version('remote-freshness-result', 1)
|
||||
@schema_version("remote-freshness-result", 1)
|
||||
class RemoteFreshnessResult(FreshnessResult, RemoteResult):
|
||||
|
||||
@classmethod
|
||||
def from_local_result(
|
||||
cls,
|
||||
base: FreshnessResult,
|
||||
logs: List[LogMessage],
|
||||
) -> 'RemoteFreshnessResult':
|
||||
) -> "RemoteFreshnessResult":
|
||||
return cls(
|
||||
metadata=base.metadata,
|
||||
results=base.results,
|
||||
@@ -318,7 +321,7 @@ class RemoteFreshnessResult(FreshnessResult, RemoteResult):
|
||||
|
||||
|
||||
@dataclass
|
||||
@schema_version('remote-run-result', 1)
|
||||
@schema_version("remote-run-result", 1)
|
||||
class RemoteRunResult(RemoteCompileResultMixin):
|
||||
table: ResultTable
|
||||
generated_at: datetime = field(default_factory=datetime.utcnow)
|
||||
@@ -336,14 +339,15 @@ RPCResult = Union[
|
||||
|
||||
# GC types
|
||||
|
||||
|
||||
class GCResultState(StrEnum):
|
||||
Deleted = 'deleted' # successful GC
|
||||
Missing = 'missing' # nothing to GC
|
||||
Running = 'running' # can't GC
|
||||
Deleted = "deleted" # successful GC
|
||||
Missing = "missing" # nothing to GC
|
||||
Running = "running" # can't GC
|
||||
|
||||
|
||||
@dataclass
|
||||
@schema_version('remote-gc-result', 1)
|
||||
@schema_version("remote-gc-result", 1)
|
||||
class GCResult(RemoteResult):
|
||||
logs: List[LogMessage] = field(default_factory=list)
|
||||
deleted: List[TaskID] = field(default_factory=list)
|
||||
@@ -358,21 +362,20 @@ class GCResult(RemoteResult):
|
||||
elif state == GCResultState.Deleted:
|
||||
self.deleted.append(task_id)
|
||||
else:
|
||||
raise InternalException(
|
||||
f'Got invalid state in add_result: {state}'
|
||||
)
|
||||
raise InternalException(f"Got invalid state in add_result: {state}")
|
||||
|
||||
|
||||
# Task management types
|
||||
|
||||
|
||||
class TaskHandlerState(StrEnum):
|
||||
NotStarted = 'not started'
|
||||
Initializing = 'initializing'
|
||||
Running = 'running'
|
||||
Success = 'success'
|
||||
Error = 'error'
|
||||
Killed = 'killed'
|
||||
Failed = 'failed'
|
||||
NotStarted = "not started"
|
||||
Initializing = "initializing"
|
||||
Running = "running"
|
||||
Success = "success"
|
||||
Error = "error"
|
||||
Killed = "killed"
|
||||
Failed = "failed"
|
||||
|
||||
def __lt__(self, other) -> bool:
|
||||
"""A logical ordering for TaskHandlerState:
|
||||
@@ -380,7 +383,7 @@ class TaskHandlerState(StrEnum):
|
||||
NotStarted < Initializing < Running < (Success, Error, Killed, Failed)
|
||||
"""
|
||||
if not isinstance(other, TaskHandlerState):
|
||||
raise TypeError('cannot compare to non-TaskHandlerState')
|
||||
raise TypeError("cannot compare to non-TaskHandlerState")
|
||||
order = (self.NotStarted, self.Initializing, self.Running)
|
||||
smaller = set()
|
||||
for value in order:
|
||||
@@ -392,13 +395,11 @@ class TaskHandlerState(StrEnum):
|
||||
|
||||
def __le__(self, other) -> bool:
|
||||
# so that ((Success <= Error) is True)
|
||||
return ((self < other) or
|
||||
(self == other) or
|
||||
(self.finished and other.finished))
|
||||
return (self < other) or (self == other) or (self.finished and other.finished)
|
||||
|
||||
def __gt__(self, other) -> bool:
|
||||
if not isinstance(other, TaskHandlerState):
|
||||
raise TypeError('cannot compare to non-TaskHandlerState')
|
||||
raise TypeError("cannot compare to non-TaskHandlerState")
|
||||
order = (self.NotStarted, self.Initializing, self.Running)
|
||||
smaller = set()
|
||||
for value in order:
|
||||
@@ -409,9 +410,7 @@ class TaskHandlerState(StrEnum):
|
||||
|
||||
def __ge__(self, other) -> bool:
|
||||
# so that ((Success <= Error) is True)
|
||||
return ((self > other) or
|
||||
(self == other) or
|
||||
(self.finished and other.finished))
|
||||
return (self > other) or (self == other) or (self.finished and other.finished)
|
||||
|
||||
@property
|
||||
def finished(self) -> bool:
|
||||
@@ -430,7 +429,7 @@ class TaskTiming(dbtClassMixin):
|
||||
@classmethod
|
||||
def __pre_deserialize__(cls, data):
|
||||
data = super().__pre_deserialize__(data)
|
||||
for field_name in ('start', 'end', 'elapsed'):
|
||||
for field_name in ("start", "end", "elapsed"):
|
||||
if field_name not in data:
|
||||
data[field_name] = None
|
||||
return data
|
||||
@@ -447,27 +446,27 @@ class TaskRow(TaskTiming):
|
||||
|
||||
|
||||
@dataclass
|
||||
@schema_version('remote-ps-result', 1)
|
||||
@schema_version("remote-ps-result", 1)
|
||||
class PSResult(RemoteResult):
|
||||
rows: List[TaskRow]
|
||||
|
||||
|
||||
class KillResultStatus(StrEnum):
|
||||
Missing = 'missing'
|
||||
NotStarted = 'not_started'
|
||||
Killed = 'killed'
|
||||
Finished = 'finished'
|
||||
Missing = "missing"
|
||||
NotStarted = "not_started"
|
||||
Killed = "killed"
|
||||
Finished = "finished"
|
||||
|
||||
|
||||
@dataclass
|
||||
@schema_version('remote-kill-result', 1)
|
||||
@schema_version("remote-kill-result", 1)
|
||||
class KillResult(RemoteResult):
|
||||
state: KillResultStatus = KillResultStatus.Missing
|
||||
logs: List[LogMessage] = field(default_factory=list)
|
||||
|
||||
|
||||
@dataclass
|
||||
@schema_version('remote-manifest-result', 1)
|
||||
@schema_version("remote-manifest-result", 1)
|
||||
class GetManifestResult(RemoteResult):
|
||||
manifest: Optional[WritableManifest] = None
|
||||
|
||||
@@ -498,29 +497,28 @@ class PollResult(RemoteResult, TaskTiming):
|
||||
@classmethod
|
||||
def __pre_deserialize__(cls, data):
|
||||
data = super().__pre_deserialize__(data)
|
||||
for field_name in ('start', 'end', 'elapsed'):
|
||||
for field_name in ("start", "end", "elapsed"):
|
||||
if field_name not in data:
|
||||
data[field_name] = None
|
||||
return data
|
||||
|
||||
|
||||
@dataclass
|
||||
@schema_version('poll-remote-deps-result', 1)
|
||||
@schema_version("poll-remote-deps-result", 1)
|
||||
class PollRemoteEmptyCompleteResult(PollResult, RemoteResult):
|
||||
state: TaskHandlerState = field(
|
||||
metadata=restrict_to(TaskHandlerState.Success,
|
||||
TaskHandlerState.Failed),
|
||||
metadata=restrict_to(TaskHandlerState.Success, TaskHandlerState.Failed),
|
||||
)
|
||||
generated_at: datetime = field(default_factory=datetime.utcnow)
|
||||
|
||||
@classmethod
|
||||
def from_result(
|
||||
cls: Type['PollRemoteEmptyCompleteResult'],
|
||||
cls: Type["PollRemoteEmptyCompleteResult"],
|
||||
base: RemoteDepsResult,
|
||||
tags: TaskTags,
|
||||
timing: TaskTiming,
|
||||
logs: List[LogMessage],
|
||||
) -> 'PollRemoteEmptyCompleteResult':
|
||||
) -> "PollRemoteEmptyCompleteResult":
|
||||
return cls(
|
||||
logs=logs,
|
||||
tags=tags,
|
||||
@@ -528,12 +526,12 @@ class PollRemoteEmptyCompleteResult(PollResult, RemoteResult):
|
||||
start=timing.start,
|
||||
end=timing.end,
|
||||
elapsed=timing.elapsed,
|
||||
generated_at=base.generated_at
|
||||
generated_at=base.generated_at,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
@schema_version('poll-remote-killed-result', 1)
|
||||
@schema_version("poll-remote-killed-result", 1)
|
||||
class PollKilledResult(PollResult):
|
||||
state: TaskHandlerState = field(
|
||||
metadata=restrict_to(TaskHandlerState.Killed),
|
||||
@@ -541,24 +539,23 @@ class PollKilledResult(PollResult):
|
||||
|
||||
|
||||
@dataclass
|
||||
@schema_version('poll-remote-execution-result', 1)
|
||||
@schema_version("poll-remote-execution-result", 1)
|
||||
class PollExecuteCompleteResult(
|
||||
RemoteExecutionResult,
|
||||
PollResult,
|
||||
):
|
||||
state: TaskHandlerState = field(
|
||||
metadata=restrict_to(TaskHandlerState.Success,
|
||||
TaskHandlerState.Failed),
|
||||
metadata=restrict_to(TaskHandlerState.Success, TaskHandlerState.Failed),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_result(
|
||||
cls: Type['PollExecuteCompleteResult'],
|
||||
cls: Type["PollExecuteCompleteResult"],
|
||||
base: RemoteExecutionResult,
|
||||
tags: TaskTags,
|
||||
timing: TaskTiming,
|
||||
logs: List[LogMessage],
|
||||
) -> 'PollExecuteCompleteResult':
|
||||
) -> "PollExecuteCompleteResult":
|
||||
return cls(
|
||||
results=base.results,
|
||||
elapsed_time=base.elapsed_time,
|
||||
@@ -573,24 +570,23 @@ class PollExecuteCompleteResult(
|
||||
|
||||
|
||||
@dataclass
|
||||
@schema_version('poll-remote-compile-result', 1)
|
||||
@schema_version("poll-remote-compile-result", 1)
|
||||
class PollCompileCompleteResult(
|
||||
RemoteCompileResult,
|
||||
PollResult,
|
||||
):
|
||||
state: TaskHandlerState = field(
|
||||
metadata=restrict_to(TaskHandlerState.Success,
|
||||
TaskHandlerState.Failed),
|
||||
metadata=restrict_to(TaskHandlerState.Success, TaskHandlerState.Failed),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_result(
|
||||
cls: Type['PollCompileCompleteResult'],
|
||||
cls: Type["PollCompileCompleteResult"],
|
||||
base: RemoteCompileResult,
|
||||
tags: TaskTags,
|
||||
timing: TaskTiming,
|
||||
logs: List[LogMessage],
|
||||
) -> 'PollCompileCompleteResult':
|
||||
) -> "PollCompileCompleteResult":
|
||||
return cls(
|
||||
raw_sql=base.raw_sql,
|
||||
compiled_sql=base.compiled_sql,
|
||||
@@ -602,29 +598,28 @@ class PollCompileCompleteResult(
|
||||
start=timing.start,
|
||||
end=timing.end,
|
||||
elapsed=timing.elapsed,
|
||||
generated_at=base.generated_at
|
||||
generated_at=base.generated_at,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
@schema_version('poll-remote-run-result', 1)
|
||||
@schema_version("poll-remote-run-result", 1)
|
||||
class PollRunCompleteResult(
|
||||
RemoteRunResult,
|
||||
PollResult,
|
||||
):
|
||||
state: TaskHandlerState = field(
|
||||
metadata=restrict_to(TaskHandlerState.Success,
|
||||
TaskHandlerState.Failed),
|
||||
metadata=restrict_to(TaskHandlerState.Success, TaskHandlerState.Failed),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_result(
|
||||
cls: Type['PollRunCompleteResult'],
|
||||
cls: Type["PollRunCompleteResult"],
|
||||
base: RemoteRunResult,
|
||||
tags: TaskTags,
|
||||
timing: TaskTiming,
|
||||
logs: List[LogMessage],
|
||||
) -> 'PollRunCompleteResult':
|
||||
) -> "PollRunCompleteResult":
|
||||
return cls(
|
||||
raw_sql=base.raw_sql,
|
||||
compiled_sql=base.compiled_sql,
|
||||
@@ -637,29 +632,28 @@ class PollRunCompleteResult(
|
||||
start=timing.start,
|
||||
end=timing.end,
|
||||
elapsed=timing.elapsed,
|
||||
generated_at=base.generated_at
|
||||
generated_at=base.generated_at,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
@schema_version('poll-remote-run-operation-result', 1)
|
||||
@schema_version("poll-remote-run-operation-result", 1)
|
||||
class PollRunOperationCompleteResult(
|
||||
RemoteRunOperationResult,
|
||||
PollResult,
|
||||
):
|
||||
state: TaskHandlerState = field(
|
||||
metadata=restrict_to(TaskHandlerState.Success,
|
||||
TaskHandlerState.Failed),
|
||||
metadata=restrict_to(TaskHandlerState.Success, TaskHandlerState.Failed),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_result(
|
||||
cls: Type['PollRunOperationCompleteResult'],
|
||||
cls: Type["PollRunOperationCompleteResult"],
|
||||
base: RemoteRunOperationResult,
|
||||
tags: TaskTags,
|
||||
timing: TaskTiming,
|
||||
logs: List[LogMessage],
|
||||
) -> 'PollRunOperationCompleteResult':
|
||||
) -> "PollRunOperationCompleteResult":
|
||||
return cls(
|
||||
success=base.success,
|
||||
results=base.results,
|
||||
@@ -675,21 +669,20 @@ class PollRunOperationCompleteResult(
|
||||
|
||||
|
||||
@dataclass
|
||||
@schema_version('poll-remote-catalog-result', 1)
|
||||
@schema_version("poll-remote-catalog-result", 1)
|
||||
class PollCatalogCompleteResult(RemoteCatalogResults, PollResult):
|
||||
state: TaskHandlerState = field(
|
||||
metadata=restrict_to(TaskHandlerState.Success,
|
||||
TaskHandlerState.Failed),
|
||||
metadata=restrict_to(TaskHandlerState.Success, TaskHandlerState.Failed),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_result(
|
||||
cls: Type['PollCatalogCompleteResult'],
|
||||
cls: Type["PollCatalogCompleteResult"],
|
||||
base: RemoteCatalogResults,
|
||||
tags: TaskTags,
|
||||
timing: TaskTiming,
|
||||
logs: List[LogMessage],
|
||||
) -> 'PollCatalogCompleteResult':
|
||||
) -> "PollCatalogCompleteResult":
|
||||
return cls(
|
||||
nodes=base.nodes,
|
||||
sources=base.sources,
|
||||
@@ -706,27 +699,26 @@ class PollCatalogCompleteResult(RemoteCatalogResults, PollResult):
|
||||
|
||||
|
||||
@dataclass
|
||||
@schema_version('poll-remote-in-progress-result', 1)
|
||||
@schema_version("poll-remote-in-progress-result", 1)
|
||||
class PollInProgressResult(PollResult):
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
@schema_version('poll-remote-get-manifest-result', 1)
|
||||
@schema_version("poll-remote-get-manifest-result", 1)
|
||||
class PollGetManifestResult(GetManifestResult, PollResult):
|
||||
state: TaskHandlerState = field(
|
||||
metadata=restrict_to(TaskHandlerState.Success,
|
||||
TaskHandlerState.Failed),
|
||||
metadata=restrict_to(TaskHandlerState.Success, TaskHandlerState.Failed),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_result(
|
||||
cls: Type['PollGetManifestResult'],
|
||||
cls: Type["PollGetManifestResult"],
|
||||
base: GetManifestResult,
|
||||
tags: TaskTags,
|
||||
timing: TaskTiming,
|
||||
logs: List[LogMessage],
|
||||
) -> 'PollGetManifestResult':
|
||||
) -> "PollGetManifestResult":
|
||||
return cls(
|
||||
manifest=base.manifest,
|
||||
logs=logs,
|
||||
@@ -739,21 +731,20 @@ class PollGetManifestResult(GetManifestResult, PollResult):
|
||||
|
||||
|
||||
@dataclass
|
||||
@schema_version('poll-remote-freshness-result', 1)
|
||||
@schema_version("poll-remote-freshness-result", 1)
|
||||
class PollFreshnessResult(RemoteFreshnessResult, PollResult):
|
||||
state: TaskHandlerState = field(
|
||||
metadata=restrict_to(TaskHandlerState.Success,
|
||||
TaskHandlerState.Failed),
|
||||
metadata=restrict_to(TaskHandlerState.Success, TaskHandlerState.Failed),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_result(
|
||||
cls: Type['PollFreshnessResult'],
|
||||
cls: Type["PollFreshnessResult"],
|
||||
base: RemoteFreshnessResult,
|
||||
tags: TaskTags,
|
||||
timing: TaskTiming,
|
||||
logs: List[LogMessage],
|
||||
) -> 'PollFreshnessResult':
|
||||
) -> "PollFreshnessResult":
|
||||
return cls(
|
||||
logs=logs,
|
||||
tags=tags,
|
||||
@@ -766,18 +757,19 @@ class PollFreshnessResult(RemoteFreshnessResult, PollResult):
|
||||
elapsed_time=base.elapsed_time,
|
||||
)
|
||||
|
||||
|
||||
# Manifest parsing types
|
||||
|
||||
|
||||
class ManifestStatus(StrEnum):
|
||||
Init = 'init'
|
||||
Compiling = 'compiling'
|
||||
Ready = 'ready'
|
||||
Error = 'error'
|
||||
Init = "init"
|
||||
Compiling = "compiling"
|
||||
Ready = "ready"
|
||||
Error = "error"
|
||||
|
||||
|
||||
@dataclass
|
||||
@schema_version('remote-status-result', 1)
|
||||
@schema_version("remote-status-result", 1)
|
||||
class LastParse(RemoteResult):
|
||||
state: ManifestStatus = ManifestStatus.Init
|
||||
logs: List[LogMessage] = field(default_factory=list)
|
||||
|
||||
@@ -8,7 +8,7 @@ from typing import List, Dict, Any, Union
|
||||
class SelectorDefinition(dbtClassMixin):
|
||||
name: str
|
||||
definition: Union[str, Dict[str, Any]]
|
||||
description: str = ''
|
||||
description: str = ""
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -9,7 +9,7 @@ class PreviousState:
|
||||
self.path: Path = path
|
||||
self.manifest: Optional[WritableManifest] = None
|
||||
|
||||
manifest_path = self.path / 'manifest.json'
|
||||
manifest_path = self.path / "manifest.json"
|
||||
if manifest_path.exists() and manifest_path.is_file():
|
||||
try:
|
||||
self.manifest = WritableManifest.read(str(manifest_path))
|
||||
|
||||
@@ -1,9 +1,7 @@
import dataclasses
import os
from datetime import datetime
from typing import (
List, Tuple, ClassVar, Type, TypeVar, Dict, Any, Optional
)
from typing import List, Tuple, ClassVar, Type, TypeVar, Dict, Any, Optional

from dbt.clients.system import write_json, read_json
from dbt.exceptions import (
@@ -57,9 +55,7 @@ class Mergeable(Replaceable):

class Writable:
def write(self, path: str):
write_json(
path, self.to_dict(omit_none=False)  # type: ignore
)
write_json(path, self.to_dict(omit_none=False))  # type: ignore

class AdditionalPropertiesMixin:
@@ -68,6 +64,7 @@ class AdditionalPropertiesMixin:
The underlying class definition must include a type definition for a field
named '_extra' that is of type `Dict[str, Any]`.
"""

ADDITIONAL_PROPERTIES = True

# This takes attributes in the dictionary that are
@@ -86,10 +83,10 @@ class AdditionalPropertiesMixin:
cls_keys = cls._get_field_names()
new_dict = {}
for key, value in data.items():
if key not in cls_keys and key != '_extra':
if '_extra' not in new_dict:
new_dict['_extra'] = {}
new_dict['_extra'][key] = value
if key not in cls_keys and key != "_extra":
if "_extra" not in new_dict:
new_dict["_extra"] = {}
new_dict["_extra"][key] = value
else:
new_dict[key] = value
data = new_dict
@@ -99,8 +96,8 @@ class AdditionalPropertiesMixin:
def __post_serialize__(self, dct):
data = super().__post_serialize__(dct)
data.update(self.extra)
if '_extra' in data:
del data['_extra']
if "_extra" in data:
del data["_extra"]
return data

def replace(self, **kwargs):
@@ -126,8 +123,8 @@ class Readable:
return cls.from_dict(data)  # type: ignore

BASE_SCHEMAS_URL = 'https://schemas.getdbt.com/'
SCHEMA_PATH = 'dbt/{name}/v{version}.json'
BASE_SCHEMAS_URL = "https://schemas.getdbt.com/"
SCHEMA_PATH = "dbt/{name}/v{version}.json"

@dataclasses.dataclass
@@ -137,24 +134,22 @@ class SchemaVersion:

@property
def path(self) -> str:
return SCHEMA_PATH.format(
name=self.name,
version=self.version
)
return SCHEMA_PATH.format(name=self.name, version=self.version)

def __str__(self) -> str:
return BASE_SCHEMAS_URL + self.path

SCHEMA_VERSION_KEY = 'dbt_schema_version'
SCHEMA_VERSION_KEY = "dbt_schema_version"

METADATA_ENV_PREFIX = 'DBT_ENV_CUSTOM_ENV_'
METADATA_ENV_PREFIX = "DBT_ENV_CUSTOM_ENV_"

def get_metadata_env() -> Dict[str, str]:
return {
k[len(METADATA_ENV_PREFIX):]: v for k, v in os.environ.items()
k[len(METADATA_ENV_PREFIX) :]: v
for k, v in os.environ.items()
if k.startswith(METADATA_ENV_PREFIX)
}

@@ -163,12 +158,8 @@ def get_metadata_env() -> Dict[str, str]:
class BaseArtifactMetadata(dbtClassMixin):
dbt_schema_version: str
dbt_version: str = __version__
generated_at: datetime = dataclasses.field(
default_factory=datetime.utcnow
)
invocation_id: Optional[str] = dataclasses.field(
default_factory=get_invocation_id
)
generated_at: datetime = dataclasses.field(default_factory=datetime.utcnow)
invocation_id: Optional[str] = dataclasses.field(default_factory=get_invocation_id)
env: Dict[str, str] = dataclasses.field(default_factory=get_metadata_env)

@@ -179,6 +170,7 @@ def schema_version(name: str, version: int):
version=version,
)
return cls

return inner

@@ -190,11 +182,11 @@ class VersionedSchema(dbtClassMixin):
def json_schema(cls, embeddable: bool = False) -> Dict[str, Any]:
result = super().json_schema(embeddable=embeddable)
if not embeddable:
result['$id'] = str(cls.dbt_schema_version)
result["$id"] = str(cls.dbt_schema_version)
return result

T = TypeVar('T', bound='ArtifactMixin')
T = TypeVar("T", bound="ArtifactMixin")

# metadata should really be a Generic[T_M] where T_M is a TypeVar bound to
@@ -208,6 +200,4 @@ class ArtifactMixin(VersionedSchema, Writable, Readable):
def validate(cls, data):
super().validate(data)
if cls.dbt_schema_version is None:
raise InternalException(
'Cannot call from_dict with no schema version!'
)
raise InternalException("Cannot call from_dict with no schema version!")

@@ -1,5 +1,7 @@
from typing import (
Type, ClassVar, cast,
Type,
ClassVar,
cast,
)
import re
from dataclasses import fields
@@ -11,9 +13,7 @@ from hologram import JsonSchemaMixin, FieldEncoder, ValidationError

# type: ignore
from mashumaro import DataClassDictMixin
from mashumaro.config import (
TO_DICT_ADD_OMIT_NONE_FLAG, BaseConfig as MashBaseConfig
)
from mashumaro.config import TO_DICT_ADD_OMIT_NONE_FLAG, BaseConfig as MashBaseConfig
from mashumaro.types import SerializableType, SerializationStrategy

@@ -26,9 +26,7 @@ class DateTimeSerialization(SerializationStrategy):
return out

def deserialize(self, value):
return (
value if isinstance(value, datetime) else parse(cast(str, value))
)
return value if isinstance(value, datetime) else parse(cast(str, value))

# This class pulls in both JsonSchemaMixin from Hologram and
@@ -60,8 +58,8 @@ class dbtClassMixin(DataClassDictMixin, JsonSchemaMixin):
if self._hyphenated:
new_dict = {}
for key in dct:
if '_' in key:
new_key = key.replace('_', '-')
if "_" in key:
new_key = key.replace("_", "-")
new_dict[new_key] = dct[key]
else:
new_dict[key] = dct[key]
@@ -76,8 +74,8 @@ class dbtClassMixin(DataClassDictMixin, JsonSchemaMixin):
if cls._hyphenated:
new_dict = {}
for key in data:
if '-' in key:
new_key = key.replace('-', '_')
if "-" in key:
new_key = key.replace("-", "_")
new_dict[new_key] = data[key]
else:
new_dict[key] = data[key]
@@ -89,16 +87,16 @@ class dbtClassMixin(DataClassDictMixin, JsonSchemaMixin):
# hologram and in mashumaro.
def _local_to_dict(self, **kwargs):
args = {}
if 'omit_none' in kwargs:
args['omit_none'] = kwargs['omit_none']
if "omit_none" in kwargs:
args["omit_none"] = kwargs["omit_none"]
return self.to_dict(**args)

class ValidatedStringMixin(str, SerializableType):
ValidationRegex = ''
ValidationRegex = ""

@classmethod
def _deserialize(cls, value: str) -> 'ValidatedStringMixin':
def _deserialize(cls, value: str) -> "ValidatedStringMixin":
cls.validate(value)
return ValidatedStringMixin(value)

@@ -14,39 +14,31 @@ class DBTDeprecation:
def name(self) -> str:
if self._name is not None:
return self._name
raise NotImplementedError(
'name not implemented for {}'.format(self)
)
raise NotImplementedError("name not implemented for {}".format(self))

def track_deprecation_warn(self) -> None:
if dbt.tracking.active_user is not None:
dbt.tracking.track_deprecation_warn({
"deprecation_name": self.name
})
dbt.tracking.track_deprecation_warn({"deprecation_name": self.name})

@property
def description(self) -> str:
if self._description is not None:
return self._description
raise NotImplementedError(
'description not implemented for {}'.format(self)
)
raise NotImplementedError("description not implemented for {}".format(self))

def show(self, *args, **kwargs) -> None:
if self.name not in active_deprecations:
desc = self.description.format(**kwargs)
msg = ui.line_wrap_message(
desc, prefix='* Deprecation Warning: '
)
msg = ui.line_wrap_message(desc, prefix="* Deprecation Warning: ")
dbt.exceptions.warn_or_error(msg)
self.track_deprecation_warn()
active_deprecations.add(self.name)

class MaterializationReturnDeprecation(DBTDeprecation):
_name = 'materialization-return'
_name = "materialization-return"

_description = '''\
_description = """\
The materialization ("{materialization}") did not explicitly return a list
of relations to add to the cache. By default the target relation will be
added, but this behavior will be removed in a future version of dbt.
@@ -56,22 +48,22 @@ class MaterializationReturnDeprecation(DBTDeprecation):
For more information, see:

https://docs.getdbt.com/v0.15/docs/creating-new-materializations#section-6-returning-relations
'''
"""

class NotADictionaryDeprecation(DBTDeprecation):
_name = 'not-a-dictionary'
_name = "not-a-dictionary"

_description = '''\
_description = """\
The object ("{obj}") was used as a dictionary. In a future version of dbt
this capability will be removed from objects of this type.
'''
"""

class ColumnQuotingDeprecation(DBTDeprecation):
_name = 'column-quoting-unset'
_name = "column-quoting-unset"

_description = '''\
_description = """\
The quote_columns parameter was not set for seeds, so the default value of
False was chosen. The default will change to True in a future release.

@@ -80,13 +72,13 @@ class ColumnQuotingDeprecation(DBTDeprecation):
For more information, see:

https://docs.getdbt.com/v0.15/docs/seeds#section-specify-column-quoting
'''
"""

class ModelsKeyNonModelDeprecation(DBTDeprecation):
_name = 'models-key-mismatch'
_name = "models-key-mismatch"

_description = '''\
_description = """\
"{node.name}" is a {node.resource_type} node, but it is specified in
the {patch.yaml_key} section of {patch.original_file_path}.

@@ -96,25 +88,25 @@ class ModelsKeyNonModelDeprecation(DBTDeprecation):
the {expected_key} key instead.

This warning will become an error in a future release.
'''
"""

class ExecuteMacrosReleaseDeprecation(DBTDeprecation):
_name = 'execute-macro-release'
_description = '''\
_name = "execute-macro-release"
_description = """\
The "release" argument to execute_macro is now ignored, and will be removed
in a future relase of dbt. At that time, providing a `release` argument
will result in an error.
'''
"""

class AdapterMacroDeprecation(DBTDeprecation):
_name = 'adapter-macro'
_description = '''\
_name = "adapter-macro"
_description = """\
The "adapter_macro" macro has been deprecated. Instead, use the
`adapter.dispatch` method to find a macro and call the result.
adapter_macro was called for: {macro_name}
'''
"""

_adapter_renamed_description = """\
@@ -128,11 +120,11 @@ Documentation for {new_name} can be found here:

def renamed_method(old_name: str, new_name: str):

class AdapterDeprecationWarning(DBTDeprecation):
_name = 'adapter:{}'.format(old_name)
_description = _adapter_renamed_description.format(old_name=old_name,
new_name=new_name)
_name = "adapter:{}".format(old_name)
_description = _adapter_renamed_description.format(
old_name=old_name, new_name=new_name
)

dep = AdapterDeprecationWarning()
deprecations_list.append(dep)
@@ -142,9 +134,7 @@ def renamed_method(old_name: str, new_name: str):
def warn(name, *args, **kwargs):
if name not in deprecations:
# this should (hopefully) never happen
raise RuntimeError(
"Error showing deprecation warning: {}".format(name)
)
raise RuntimeError("Error showing deprecation warning: {}".format(name))

deprecations[name].show(*args, **kwargs)

@@ -163,9 +153,7 @@ deprecations_list: List[DBTDeprecation] = [
AdapterMacroDeprecation(),
]

deprecations: Dict[str, DBTDeprecation] = {
d.name: d for d in deprecations_list
}
deprecations: Dict[str, DBTDeprecation] = {d.name: d for d in deprecations_list}

def reset_deprecations():

@@ -22,12 +22,12 @@ def downloads_directory():
# the user might have set an environment variable. Set it to that, and do
# not remove it when finished.
if DOWNLOADS_PATH is None:
DOWNLOADS_PATH = os.getenv('DBT_DOWNLOADS_DIR')
DOWNLOADS_PATH = os.getenv("DBT_DOWNLOADS_DIR")
remove_downloads = False
# if we are making a per-run temp directory, remove it at the end of
# successful runs
if DOWNLOADS_PATH is None:
DOWNLOADS_PATH = tempfile.mkdtemp(prefix='dbt-downloads-')
DOWNLOADS_PATH = tempfile.mkdtemp(prefix="dbt-downloads-")
remove_downloads = True

system.make_directory(DOWNLOADS_PATH)
@@ -62,7 +62,7 @@ class PinnedPackage(BasePackage):
if not version:
return self.name

return '{}@{}'.format(self.name, version)
return "{}@{}".format(self.name, version)

@abc.abstractmethod
def get_version(self) -> Optional[str]:
@@ -94,8 +94,8 @@ class PinnedPackage(BasePackage):
return os.path.join(project.modules_path, dest_dirname)

SomePinned = TypeVar('SomePinned', bound=PinnedPackage)
SomeUnpinned = TypeVar('SomeUnpinned', bound='UnpinnedPackage')
SomePinned = TypeVar("SomePinned", bound=PinnedPackage)
SomeUnpinned = TypeVar("SomeUnpinned", bound="UnpinnedPackage")

class UnpinnedPackage(Generic[SomePinned], BasePackage):
@@ -8,18 +8,16 @@ from dbt.contracts.project import (
ProjectPackageMetadata,
GitPackage,
)
from dbt.deps.base import PinnedPackage, UnpinnedPackage, get_downloads_path
from dbt.exceptions import (
ExecutableError, warn_or_error, raise_dependency_error
)
from dbt.deps import PinnedPackage, UnpinnedPackage, get_downloads_path
from dbt.exceptions import ExecutableError, warn_or_error, raise_dependency_error
from dbt.logger import GLOBAL_LOGGER as logger
from dbt import ui

PIN_PACKAGE_URL = 'https://docs.getdbt.com/docs/package-management#section-specifying-package-versions'  # noqa
PIN_PACKAGE_URL = "https://docs.getdbt.com/docs/package-management#section-specifying-package-versions"  # noqa

def md5sum(s: str):
return hashlib.md5(s.encode('latin-1')).hexdigest()
return hashlib.md5(s.encode("latin-1")).hexdigest()

class GitPackageMixin:
@@ -32,13 +30,11 @@ class GitPackageMixin:
return self.git

def source_type(self) -> str:
return 'git'
return "git"

class GitPinnedPackage(GitPackageMixin, PinnedPackage):
def __init__(
self, git: str, revision: str, warn_unpinned: bool = True
) -> None:
def __init__(self, git: str, revision: str, warn_unpinned: bool = True) -> None:
super().__init__(git)
self.revision = revision
self.warn_unpinned = warn_unpinned
@@ -48,15 +44,15 @@ class GitPinnedPackage(GitPackageMixin, PinnedPackage):
return self.revision

def nice_version_name(self):
if self.revision == 'HEAD':
return 'HEAD (default branch)'
if self.revision == "HEAD":
return "HEAD (default branch)"
else:
return 'revision {}'.format(self.revision)
return "revision {}".format(self.revision)

def unpinned_msg(self):
if self.revision == 'HEAD':
return 'not pinned, using HEAD (default branch)'
elif self.revision in ('main', 'master'):
if self.revision == "HEAD":
return "not pinned, using HEAD (default branch)"
elif self.revision in ("main", "master"):
return f'pinned to the "{self.revision}" branch'
else:
return None
@@ -68,15 +64,17 @@ class GitPinnedPackage(GitPackageMixin, PinnedPackage):
the path to the checked out directory."""
try:
dir_ = git.clone_and_checkout(
self.git, get_downloads_path(), branch=self.revision,
dirname=self._checkout_name
self.git,
get_downloads_path(),
branch=self.revision,
dirname=self._checkout_name,
)
except ExecutableError as exc:
if exc.cmd and exc.cmd[0] == 'git':
if exc.cmd and exc.cmd[0] == "git":
logger.error(
'Make sure git is installed on your machine. More '
'information: '
'https://docs.getdbt.com/docs/package-management'
"Make sure git is installed on your machine. More "
"information: "
"https://docs.getdbt.com/docs/package-management"
)
raise
return os.path.join(get_downloads_path(), dir_)
@@ -87,9 +85,10 @@ class GitPinnedPackage(GitPackageMixin, PinnedPackage):
if self.unpinned_msg() and self.warn_unpinned:
warn_or_error(
'The git package "{}" \n\tis {}.\n\tThis can introduce '
'breaking changes into your project without warning!\n\nSee {}'
.format(self.git, self.unpinned_msg(), PIN_PACKAGE_URL),
log_fmt=ui.yellow('WARNING: {}')
"breaking changes into your project without warning!\n\nSee {}".format(
self.git, self.unpinned_msg(), PIN_PACKAGE_URL
),
log_fmt=ui.yellow("WARNING: {}"),
)
loaded = Project.from_project_root(path, renderer)
return ProjectPackageMetadata.from_project(loaded)
@@ -114,26 +113,21 @@ class GitUnpinnedPackage(GitPackageMixin, UnpinnedPackage[GitPinnedPackage]):
self.warn_unpinned = warn_unpinned

@classmethod
def from_contract(
cls, contract: GitPackage
) -> 'GitUnpinnedPackage':
def from_contract(cls, contract: GitPackage) -> "GitUnpinnedPackage":
revisions = contract.get_revisions()

# we want to map None -> True
warn_unpinned = contract.warn_unpinned is not False
return cls(git=contract.git, revisions=revisions,
warn_unpinned=warn_unpinned)
return cls(git=contract.git, revisions=revisions, warn_unpinned=warn_unpinned)

def all_names(self) -> List[str]:
if self.git.endswith('.git'):
if self.git.endswith(".git"):
other = self.git[:-4]
else:
other = self.git + '.git'
other = self.git + ".git"
return [self.git, other]

def incorporate(
self, other: 'GitUnpinnedPackage'
) -> 'GitUnpinnedPackage':
def incorporate(self, other: "GitUnpinnedPackage") -> "GitUnpinnedPackage":
warn_unpinned = self.warn_unpinned and other.warn_unpinned

return GitUnpinnedPackage(
@@ -145,13 +139,13 @@ class GitUnpinnedPackage(GitPackageMixin, UnpinnedPackage[GitPinnedPackage]):
def resolved(self) -> GitPinnedPackage:
requested = set(self.revisions)
if len(requested) == 0:
requested = {'HEAD'}
requested = {"HEAD"}
elif len(requested) > 1:
raise_dependency_error(
'git dependencies should contain exactly one version. '
'{} contains: {}'.format(self.git, requested))
"git dependencies should contain exactly one version. "
"{} contains: {}".format(self.git, requested)
)

return GitPinnedPackage(
git=self.git, revision=requested.pop(),
warn_unpinned=self.warn_unpinned
git=self.git, revision=requested.pop(), warn_unpinned=self.warn_unpinned
)

@@ -1,7 +1,7 @@
import shutil

from dbt.clients import system
from dbt.deps.base import PinnedPackage, UnpinnedPackage
from dbt.deps import PinnedPackage, UnpinnedPackage
from dbt.contracts.project import (
ProjectPackageMetadata,
LocalPackage,
@@ -19,7 +19,7 @@ class LocalPackageMixin:
return self.local

def source_type(self):
return 'local'
return "local"

class LocalPinnedPackage(LocalPackageMixin, PinnedPackage):
@@ -30,7 +30,7 @@ class LocalPinnedPackage(LocalPackageMixin, PinnedPackage):
return None

def nice_version_name(self):
return '<local @ {}>'.format(self.local)
return "<local @ {}>".format(self.local)

def resolve_path(self, project):
return system.resolve_path_from_base(
@@ -39,9 +39,7 @@ class LocalPinnedPackage(LocalPackageMixin, PinnedPackage):
)

def _fetch_metadata(self, project, renderer):
loaded = project.from_project_root(
self.resolve_path(project), renderer
)
loaded = project.from_project_root(self.resolve_path(project), renderer)
return ProjectPackageMetadata.from_project(loaded)

def install(self, project, renderer):
@@ -57,27 +55,22 @@ class LocalPinnedPackage(LocalPackageMixin, PinnedPackage):
system.remove_file(dest_path)

if can_create_symlink:
logger.debug(' Creating symlink to local dependency.')
logger.debug(" Creating symlink to local dependency.")
system.make_symlink(src_path, dest_path)

else:
logger.debug(' Symlinks are not available on this '
'OS, copying dependency.')
logger.debug(
" Symlinks are not available on this " "OS, copying dependency."
)
shutil.copytree(src_path, dest_path)

class LocalUnpinnedPackage(
LocalPackageMixin, UnpinnedPackage[LocalPinnedPackage]
):
class LocalUnpinnedPackage(LocalPackageMixin, UnpinnedPackage[LocalPinnedPackage]):
@classmethod
def from_contract(
cls, contract: LocalPackage
) -> 'LocalUnpinnedPackage':
def from_contract(cls, contract: LocalPackage) -> "LocalUnpinnedPackage":
return cls(local=contract.local)

def incorporate(
self, other: 'LocalUnpinnedPackage'
) -> 'LocalUnpinnedPackage':
def incorporate(self, other: "LocalUnpinnedPackage") -> "LocalUnpinnedPackage":
return LocalUnpinnedPackage(local=self.local)

def resolved(self) -> LocalPinnedPackage:

@@ -7,7 +7,7 @@ from dbt.contracts.project import (
RegistryPackageMetadata,
RegistryPackage,
)
from dbt.deps.base import PinnedPackage, UnpinnedPackage, get_downloads_path
from dbt.deps import PinnedPackage, UnpinnedPackage, get_downloads_path
from dbt.exceptions import (
package_version_not_found,
VersionsNotCompatibleException,
@@ -26,7 +26,7 @@ class RegistryPackageMixin:
return self.package

def source_type(self) -> str:
return 'hub'
return "hub"

class RegistryPinnedPackage(RegistryPackageMixin, PinnedPackage):
@@ -39,13 +39,13 @@ class RegistryPinnedPackage(RegistryPackageMixin, PinnedPackage):
return self.package

def source_type(self):
return 'hub'
return "hub"

def get_version(self):
return self.version

def nice_version_name(self):
return 'version {}'.format(self.version)
return "version {}".format(self.version)

def _fetch_metadata(self, project, renderer) -> RegistryPackageMetadata:
dct = registry.package_version(self.package, self.version)
@@ -54,10 +54,8 @@ class RegistryPinnedPackage(RegistryPackageMixin, PinnedPackage):
def install(self, project, renderer):
metadata = self.fetch_metadata(project, renderer)

tar_name = '{}.{}.tar.gz'.format(self.package, self.version)
tar_path = os.path.realpath(
os.path.join(get_downloads_path(), tar_name)
)
tar_name = "{}.{}.tar.gz".format(self.package, self.version)
tar_path = os.path.realpath(os.path.join(get_downloads_path(), tar_name))
system.make_directory(os.path.dirname(tar_path))

download_url = metadata.downloads.tarball
@@ -70,9 +68,7 @@ class RegistryPinnedPackage(RegistryPackageMixin, PinnedPackage):
class RegistryUnpinnedPackage(
RegistryPackageMixin, UnpinnedPackage[RegistryPinnedPackage]
):
def __init__(
self, package: str, versions: List[semver.VersionSpecifier]
) -> None:
def __init__(self, package: str, versions: List[semver.VersionSpecifier]) -> None:
super().__init__(package)
self.versions = versions

@@ -82,20 +78,15 @@ class RegistryUnpinnedPackage(
package_not_found(self.package)

@classmethod
def from_contract(
cls, contract: RegistryPackage
) -> 'RegistryUnpinnedPackage':
def from_contract(cls, contract: RegistryPackage) -> "RegistryUnpinnedPackage":
raw_version = contract.get_versions()

versions = [
semver.VersionSpecifier.from_version_string(v)
for v in raw_version
]
versions = [semver.VersionSpecifier.from_version_string(v) for v in raw_version]
return cls(package=contract.package, versions=versions)

def incorporate(
self, other: 'RegistryUnpinnedPackage'
) -> 'RegistryUnpinnedPackage':
self, other: "RegistryUnpinnedPackage"
) -> "RegistryUnpinnedPackage":
return RegistryUnpinnedPackage(
package=self.package,
versions=self.versions + other.versions,
@@ -106,8 +97,7 @@ class RegistryUnpinnedPackage(
try:
range_ = semver.reduce_versions(*self.versions)
except VersionsNotCompatibleException as e:
new_msg = ('Version error for package {}: {}'
.format(self.name, e))
new_msg = "Version error for package {}: {}".format(self.name, e)
raise DependencyException(new_msg) from e

available = registry.get_available_versions(self.package)

@@ -6,7 +6,7 @@ from dbt.exceptions import raise_dependency_error, InternalException
from dbt.context.target import generate_target_context
from dbt.config import Project, RuntimeConfig
from dbt.config.renderer import DbtProjectYamlRenderer
from dbt.deps.base import BasePackage, PinnedPackage, UnpinnedPackage
from dbt.deps import BasePackage, PinnedPackage, UnpinnedPackage
from dbt.deps.local import LocalUnpinnedPackage
from dbt.deps.git import GitUnpinnedPackage
from dbt.deps.registry import RegistryUnpinnedPackage
@@ -49,12 +49,10 @@ class PackageListing:
key_str: str = self._pick_key(key)
self.packages[key_str] = value

def _mismatched_types(
self, old: UnpinnedPackage, new: UnpinnedPackage
) -> NoReturn:
def _mismatched_types(self, old: UnpinnedPackage, new: UnpinnedPackage) -> NoReturn:
raise_dependency_error(
f'Cannot incorporate {new} ({new.__class__.__name__}) in {old} '
f'({old.__class__.__name__}): mismatched types'
f"Cannot incorporate {new} ({new.__class__.__name__}) in {old} "
f"({old.__class__.__name__}): mismatched types"
)

def incorporate(self, package: UnpinnedPackage):
@@ -78,14 +76,14 @@ class PackageListing:
pkg = RegistryUnpinnedPackage.from_contract(contract)
else:
raise InternalException(
'Invalid package type {}'.format(type(contract))
"Invalid package type {}".format(type(contract))
)
self.incorporate(pkg)

@classmethod
def from_contracts(
cls: Type['PackageListing'], src: List[PackageContract]
) -> 'PackageListing':
cls: Type["PackageListing"], src: List[PackageContract]
) -> "PackageListing":
self = cls({})
self.update_from(src)
return self
@@ -108,14 +106,14 @@ def _check_for_duplicate_project_names(
if project_name in seen:
raise_dependency_error(
f'Found duplicate project "{project_name}". This occurs when '
'a dependency has the same project name as some other '
'dependency.'
"a dependency has the same project name as some other "
"dependency."
)
elif project_name == config.project_name:
raise_dependency_error(
'Found a dependency with the same name as the root project '
"Found a dependency with the same name as the root project "
f'"{project_name}". Package names must be unique in a project.'
' Please rename one of these packages.'
" Please rename one of these packages."
)
seen.add(project_name)

File diff suppressed because it is too large
@@ -1,6 +1,7 @@
import os
import multiprocessing
if os.name != 'nt':

if os.name != "nt":
# https://bugs.python.org/issue41567
import multiprocessing.popen_spawn_posix  # type: ignore
from pathlib import Path
@@ -23,7 +24,7 @@ def env_set_truthy(key: str) -> Optional[str]:
otherwise.
"""
value = os.getenv(key)
if not value or value.lower() in ('0', 'false', 'f'):
if not value or value.lower() in ("0", "false", "f"):
return None
return value

@@ -36,24 +37,23 @@ def env_set_path(key: str) -> Optional[Path]:
return Path(value)

SINGLE_THREADED_WEBSERVER = env_set_truthy('DBT_SINGLE_THREADED_WEBSERVER')
SINGLE_THREADED_HANDLER = env_set_truthy('DBT_SINGLE_THREADED_HANDLER')
MACRO_DEBUGGING = env_set_truthy('DBT_MACRO_DEBUGGING')
DEFER_MODE = env_set_truthy('DBT_DEFER_TO_STATE')
ARTIFACT_STATE_PATH = env_set_path('DBT_ARTIFACT_STATE_PATH')
SINGLE_THREADED_WEBSERVER = env_set_truthy("DBT_SINGLE_THREADED_WEBSERVER")
SINGLE_THREADED_HANDLER = env_set_truthy("DBT_SINGLE_THREADED_HANDLER")
MACRO_DEBUGGING = env_set_truthy("DBT_MACRO_DEBUGGING")
DEFER_MODE = env_set_truthy("DBT_DEFER_TO_STATE")
ARTIFACT_STATE_PATH = env_set_path("DBT_ARTIFACT_STATE_PATH")

def _get_context():
# TODO: change this back to use fork() on linux when we have made that safe
return multiprocessing.get_context('spawn')
return multiprocessing.get_context("spawn")

MP_CONTEXT = _get_context()

def reset():
global STRICT_MODE, FULL_REFRESH, USE_CACHE, WARN_ERROR, TEST_NEW_PARSER, \
WRITE_JSON, PARTIAL_PARSE, MP_CONTEXT, USE_COLORS
global STRICT_MODE, FULL_REFRESH, USE_CACHE, WARN_ERROR, TEST_NEW_PARSER, WRITE_JSON, PARTIAL_PARSE, MP_CONTEXT, USE_COLORS

STRICT_MODE = False
FULL_REFRESH = False
@@ -67,26 +67,22 @@ def reset():

def set_from_args(args):
global STRICT_MODE, FULL_REFRESH, USE_CACHE, WARN_ERROR, TEST_NEW_PARSER, \
WRITE_JSON, PARTIAL_PARSE, MP_CONTEXT, USE_COLORS
global STRICT_MODE, FULL_REFRESH, USE_CACHE, WARN_ERROR, TEST_NEW_PARSER, WRITE_JSON, PARTIAL_PARSE, MP_CONTEXT, USE_COLORS

USE_CACHE = getattr(args, 'use_cache', USE_CACHE)
USE_CACHE = getattr(args, "use_cache", USE_CACHE)

FULL_REFRESH = getattr(args, 'full_refresh', FULL_REFRESH)
STRICT_MODE = getattr(args, 'strict', STRICT_MODE)
WARN_ERROR = (
STRICT_MODE or
getattr(args, 'warn_error', STRICT_MODE or WARN_ERROR)
)
FULL_REFRESH = getattr(args, "full_refresh", FULL_REFRESH)
STRICT_MODE = getattr(args, "strict", STRICT_MODE)
WARN_ERROR = STRICT_MODE or getattr(args, "warn_error", STRICT_MODE or WARN_ERROR)

TEST_NEW_PARSER = getattr(args, 'test_new_parser', TEST_NEW_PARSER)
WRITE_JSON = getattr(args, 'write_json', WRITE_JSON)
PARTIAL_PARSE = getattr(args, 'partial_parse', None)
TEST_NEW_PARSER = getattr(args, "test_new_parser", TEST_NEW_PARSER)
WRITE_JSON = getattr(args, "write_json", WRITE_JSON)
PARTIAL_PARSE = getattr(args, "partial_parse", None)
MP_CONTEXT = _get_context()

# The use_colors attribute will always have a value because it is assigned
# None by default from the add_mutually_exclusive_group function
use_colors_override = getattr(args, 'use_colors')
use_colors_override = getattr(args, "use_colors")

if use_colors_override is not None:
USE_COLORS = use_colors_override

@@ -2,9 +2,7 @@
import itertools
from dbt.clients.yaml_helper import yaml, Loader, Dumper  # noqa: F401

from typing import (
Dict, List, Optional, Tuple, Any, Union
)
from typing import Dict, List, Optional, Tuple, Any, Union

from dbt.contracts.selection import SelectorDefinition, SelectorFile
from dbt.exceptions import InternalException, ValidationException
@@ -17,21 +15,17 @@ from .selector_spec import (
SelectionCriteria,
)

INTERSECTION_DELIMITER = ','
INTERSECTION_DELIMITER = ","

DEFAULT_INCLUDES: List[str] = ['fqn:*', 'source:*', 'exposure:*']
DEFAULT_INCLUDES: List[str] = ["fqn:*", "source:*", "exposure:*"]
DEFAULT_EXCLUDES: List[str] = []
DATA_TEST_SELECTOR: str = 'test_type:data'
SCHEMA_TEST_SELECTOR: str = 'test_type:schema'
DATA_TEST_SELECTOR: str = "test_type:data"
SCHEMA_TEST_SELECTOR: str = "test_type:schema"

def parse_union(
components: List[str], expect_exists: bool
) -> SelectionUnion:
def parse_union(components: List[str], expect_exists: bool) -> SelectionUnion:
# turn ['a b', 'c'] -> ['a', 'b', 'c']
raw_specs = itertools.chain.from_iterable(
r.split(' ') for r in components
)
raw_specs = itertools.chain.from_iterable(r.split(" ") for r in components)
union_components: List[SelectionSpec] = []

# ['a', 'b', 'c,d'] -> union('a', 'b', intersection('c', 'd'))
@@ -40,11 +34,13 @@ def parse_union(
SelectionCriteria.from_single_spec(part)
for part in raw_spec.split(INTERSECTION_DELIMITER)
]
union_components.append(SelectionIntersection(
union_components.append(
SelectionIntersection(
components=intersection_components,
expect_exists=expect_exists,
raw=raw_spec,
))
)
)

return SelectionUnion(
components=union_components,
@@ -78,9 +74,7 @@ def parse_test_selectors(
union_components = []

if data:
union_components.append(
SelectionCriteria.from_single_spec(DATA_TEST_SELECTOR)
)
union_components.append(SelectionCriteria.from_single_spec(DATA_TEST_SELECTOR))
if schema:
union_components.append(
SelectionCriteria.from_single_spec(SCHEMA_TEST_SELECTOR)
@@ -98,27 +92,21 @@ def parse_test_selectors(
raw=[DATA_TEST_SELECTOR, SCHEMA_TEST_SELECTOR],
)

return SelectionIntersection(
components=[base, intersect_with], expect_exists=True
)
return SelectionIntersection(components=[base, intersect_with], expect_exists=True)

RawDefinition = Union[str, Dict[str, Any]]

def _get_list_dicts(
dct: Dict[str, Any], key: str
) -> List[RawDefinition]:
def _get_list_dicts(dct: Dict[str, Any], key: str) -> List[RawDefinition]:
result: List[RawDefinition] = []
if key not in dct:
raise InternalException(
f'Expected to find key {key} in dict, only found {list(dct)}'
f"Expected to find key {key} in dict, only found {list(dct)}"
)
values = dct[key]
if not isinstance(values, list):
raise ValidationException(
f'Invalid value for key "{key}". Expected a list.'
)
raise ValidationException(f'Invalid value for key "{key}". Expected a list.')
for value in values:
if isinstance(value, dict):
for value_key in value:
@@ -133,36 +121,31 @@ def _get_list_dicts(
else:
raise ValidationException(
f'Invalid value type {type(value)} in key "{key}", expected '
f'dict or str (value: {value}).'
f"dict or str (value: {value})."
)

return result

def _parse_exclusions(definition) -> Optional[SelectionSpec]:
exclusions = _get_list_dicts(definition, 'exclude')
parsed_exclusions = [
parse_from_definition(excl) for excl in exclusions
]
exclusions = _get_list_dicts(definition, "exclude")
parsed_exclusions = [parse_from_definition(excl) for excl in exclusions]
if len(parsed_exclusions) == 1:
return parsed_exclusions[0]
elif len(parsed_exclusions) > 1:
return SelectionUnion(
components=parsed_exclusions,
raw=exclusions
)
return SelectionUnion(components=parsed_exclusions, raw=exclusions)
else:
return None

def _parse_include_exclude_subdefs(
definitions: List[RawDefinition]
definitions: List[RawDefinition],
) -> Tuple[List[SelectionSpec], Optional[SelectionSpec]]:
include_parts: List[SelectionSpec] = []
diff_arg: Optional[SelectionSpec] = None

for definition in definitions:
if isinstance(definition, dict) and 'exclude' in definition:
if isinstance(definition, dict) and "exclude" in definition:
# do not allow multiple exclude: defs at the same level
if diff_arg is not None:
yaml_sel_cfg = yaml.dump(definition)
@@ -178,7 +161,7 @@ def _parse_include_exclude_subdefs(

def parse_union_definition(definition: Dict[str, Any]) -> SelectionSpec:
union_def_parts = _get_list_dicts(definition, 'union')
union_def_parts = _get_list_dicts(definition, "union")
include, exclude = _parse_include_exclude_subdefs(union_def_parts)

union = SelectionUnion(components=include)
@@ -187,16 +170,11 @@ def parse_union_definition(definition: Dict[str, Any]) -> SelectionSpec:
union.raw = definition
return union
else:
return SelectionDifference(
components=[union, exclude],
raw=definition
)
return SelectionDifference(components=[union, exclude], raw=definition)

def parse_intersection_definition(
definition: Dict[str, Any]
) -> SelectionSpec:
intersection_def_parts = _get_list_dicts(definition, 'intersection')
def parse_intersection_definition(definition: Dict[str, Any]) -> SelectionSpec:
intersection_def_parts = _get_list_dicts(definition, "intersection")
include, exclude = _parse_include_exclude_subdefs(intersection_def_parts)
intersection = SelectionIntersection(components=include)

@@ -204,10 +182,7 @@ def parse_intersection_definition(
intersection.raw = definition
return intersection
else:
return SelectionDifference(
components=[intersection, exclude],
raw=definition
)
return SelectionDifference(components=[intersection, exclude], raw=definition)

def parse_dict_definition(definition: Dict[str, Any]) -> SelectionSpec:
@@ -221,14 +196,14 @@ def parse_dict_definition(definition: Dict[str, Any]) -> SelectionSpec:
f'"{type(key)}" ({key})'
)
dct = {
'method': key,
'value': value,
"method": key,
"value": value,
}
elif 'method' in definition and 'value' in definition:
elif "method" in definition and "value" in definition:
dct = definition
if 'exclude' in definition:
if "exclude" in definition:
diff_arg = _parse_exclusions(definition)
dct = {k: v for k, v in dct.items() if k != 'exclude'}
dct = {k: v for k, v in dct.items() if k != "exclude"}
else:
raise ValidationException(
f'Expected either 1 key or else "method" '
@@ -243,13 +218,14 @@ def parse_dict_definition(definition: Dict[str, Any]) -> SelectionSpec:
return SelectionDifference(components=[base, diff_arg])

def parse_from_definition(
definition: RawDefinition, rootlevel=False
) -> SelectionSpec:
def parse_from_definition(definition: RawDefinition, rootlevel=False) -> SelectionSpec:

if (isinstance(definition, dict) and
('union' in definition or 'intersection' in definition) and
rootlevel and len(definition) > 1):
if (
isinstance(definition, dict)
and ("union" in definition or "intersection" in definition)
and rootlevel
and len(definition) > 1
):
keys = ",".join(definition.keys())
raise ValidationException(
f"Only a single 'union' or 'intersection' key is allowed "
@@ -257,25 +233,24 @@ def parse_from_definition(
)
if isinstance(definition, str):
return SelectionCriteria.from_single_spec(definition)
elif 'union' in definition:
elif "union" in definition:
return parse_union_definition(definition)
elif 'intersection' in definition:
elif "intersection" in definition:
return parse_intersection_definition(definition)
elif isinstance(definition, dict):
return parse_dict_definition(definition)
else:
raise ValidationException(
f'Expected to find union, intersection, str or dict, instead '
f'found {type(definition)}: {definition}'
f"Expected to find union, intersection, str or dict, instead "
f"found {type(definition)}: {definition}"
)

def parse_from_selectors_definition(
source: SelectorFile
) -> Dict[str, SelectionSpec]:
def parse_from_selectors_definition(source: SelectorFile) -> Dict[str, SelectionSpec]:
result: Dict[str, SelectionSpec] = {}
selector: SelectorDefinition
for selector in source.selectors:
result[selector.name] = parse_from_definition(selector.definition,
rootlevel=True)
result[selector.name] = parse_from_definition(
selector.definition, rootlevel=True
)
return result

@@ -1,17 +1,16 @@
from typing import (
Set, Iterable, Iterator, Optional, NewType
)
from typing import Set, Iterable, Iterator, Optional, NewType
import networkx as nx  # type: ignore

from dbt.exceptions import InternalException

UniqueId = NewType('UniqueId', str)
UniqueId = NewType("UniqueId", str)

class Graph:
"""A wrapper around the networkx graph that understands SelectionCriteria
and how they interact with the graph.
"""

def __init__(self, graph):
self.graph = graph

@@ -29,12 +28,11 @@ class Graph:
) -> Set[UniqueId]:
"""Returns all nodes having a path to `node` in `graph`"""
if not self.graph.has_node(node):
raise InternalException(f'Node {node} not found in the graph!')
raise InternalException(f"Node {node} not found in the graph!")
with nx.utils.reversed(self.graph):
anc = nx.single_source_shortest_path_length(G=self.graph,
source=node,
cutoff=max_depth)\
.keys()
anc = nx.single_source_shortest_path_length(
G=self.graph, source=node, cutoff=max_depth
).keys()
return anc - {node}

def descendants(
@@ -42,16 +40,13 @@ class Graph:
) -> Set[UniqueId]:
"""Returns all nodes reachable from `node` in `graph`"""
if not self.graph.has_node(node):
raise InternalException(f'Node {node} not found in the graph!')
des = nx.single_source_shortest_path_length(G=self.graph,
source=node,
cutoff=max_depth)\
.keys()
raise InternalException(f"Node {node} not found in the graph!")
des = nx.single_source_shortest_path_length(
G=self.graph, source=node, cutoff=max_depth
).keys()
return des - {node}

def select_childrens_parents(
self, selected: Set[UniqueId]
) -> Set[UniqueId]:
def select_childrens_parents(self, selected: Set[UniqueId]) -> Set[UniqueId]:
ancestors_for = self.select_children(selected) | selected
return self.select_parents(ancestors_for) | ancestors_for

@@ -77,7 +72,7 @@ class Graph:
successors.update(self.graph.successors(node))
return successors

def get_subset_graph(self, selected: Iterable[UniqueId]) -> 'Graph':
def get_subset_graph(self, selected: Iterable[UniqueId]) -> "Graph":
"""Create and return a new graph that is a shallow copy of the graph,
but with only the nodes in include_nodes. Transitive edges across
removed nodes are preserved as explicit new edges.
@@ -98,7 +93,7 @@ class Graph:
)
return Graph(new_graph)

def subgraph(self, nodes: Iterable[UniqueId]) -> 'Graph':
def subgraph(self, nodes: Iterable[UniqueId]) -> "Graph":
return Graph(self.graph.subgraph(nodes))

def get_dependent_nodes(self, node: UniqueId):

@@ -1,8 +1,6 @@
import threading
from queue import PriorityQueue
from typing import (
Dict, Set, Optional
)
from typing import Dict, Set, Optional

import networkx as nx  # type: ignore

@@ -21,9 +19,8 @@ class GraphQueue:
that separate threads do not call `.empty()` or `__len__()` and `.get()` at
the same time, as there is an unlocked race!
"""
def __init__(
self, graph: nx.DiGraph, manifest: Manifest, selected: Set[UniqueId]
):

def __init__(self, graph: nx.DiGraph, manifest: Manifest, selected: Set[UniqueId]):
self.graph = graph
self.manifest = manifest
self._selected = selected
@@ -75,10 +72,13 @@ class GraphQueue:
"""
scores = {}
for node in self.graph.nodes():
score = -1 * len([
d for d in nx.descendants(self.graph, node)
score = -1 * len(
[
d
for d in nx.descendants(self.graph, node)
if self._include_in_cost(d)
])
]
)
scores[node] = score
return scores

@@ -1,4 +1,3 @@

from typing import Set, List, Optional

from .graph import Graph, UniqueId
@@ -25,14 +24,13 @@ def get_package_names(nodes):
def alert_non_existence(raw_spec, nodes):
if len(nodes) == 0:
warn_or_error(
f"The selection criterion '{str(raw_spec)}' does not match"
f" any nodes"
f"The selection criterion '{str(raw_spec)}' does not match" f" any nodes"
)

class NodeSelector(MethodManager):
"""The node selector is aware of the graph and manifest,
"""
"""The node selector is aware of the graph and manifest,"""

def __init__(
self,
graph: Graph,
@@ -45,13 +43,16 @@ class NodeSelector(MethodManager):
# build a subgraph containing only non-empty, enabled nodes and enabled
# sources.
graph_members = {
unique_id for unique_id in self.full_graph.nodes()
unique_id
for unique_id in self.full_graph.nodes()
if self._is_graph_member(unique_id)
}
self.graph = self.full_graph.subgraph(graph_members)

def select_included(
self, included_nodes: Set[UniqueId], spec: SelectionCriteria,
self,
included_nodes: Set[UniqueId],
spec: SelectionCriteria,
) -> Set[UniqueId]:
"""Select the explicitly included nodes, using the given spec. Return
the selected set of unique IDs.
@@ -116,10 +117,7 @@ class NodeSelector(MethodManager):
if isinstance(spec, SelectionCriteria):
result = self.get_nodes_from_criteria(spec)
else:
node_selections = [
self.select_nodes(component)
for component in spec
]
node_selections = [self.select_nodes(component) for component in spec]
result = spec.combined(node_selections)
if spec.expect_exists:
alert_non_existence(spec.raw, result)
@@ -149,18 +147,14 @@ class NodeSelector(MethodManager):
elif unique_id in self.manifest.exposures:
node = self.manifest.exposures[unique_id]
else:
raise InternalException(
f'Node {unique_id} not found in the manifest!'
)
raise InternalException(f"Node {unique_id} not found in the manifest!")
return self.node_is_match(node)

def filter_selection(self, selected: Set[UniqueId]) -> Set[UniqueId]:
"""Return the subset of selected nodes that is a match for this
selector.
"""
return {
unique_id for unique_id in selected if self._is_match(unique_id)
}
return {unique_id for unique_id in selected if self._is_match(unique_id)}

def expand_selection(self, selected: Set[UniqueId]) -> Set[UniqueId]:
"""Perform selector-specific expansion."""

@@ -31,28 +31,28 @@ from dbt.node_types import NodeType
from dbt.ui import warning_tag

SELECTOR_GLOB = '*'
SELECTOR_DELIMITER = ':'
SELECTOR_GLOB = "*"
SELECTOR_DELIMITER = ":"

class MethodName(StrEnum):
FQN = 'fqn'
Tag = 'tag'
Source = 'source'
Path = 'path'
Package = 'package'
Config = 'config'
TestName = 'test_name'
TestType = 'test_type'
ResourceType = 'resource_type'
State = 'state'
Exposure = 'exposure'
FQN = "fqn"
Tag = "tag"
Source = "source"
Path = "path"
Package = "package"
Config = "config"
TestName = "test_name"
TestType = "test_type"
ResourceType = "resource_type"
State = "state"
Exposure = "exposure"

def is_selected_node(real_node, node_selector):
for i, selector_part in enumerate(node_selector):

is_last = (i == len(node_selector) - 1)
is_last = i == len(node_selector) - 1

# if we hit a GLOB, then this node is selected
if selector_part == SELECTOR_GLOB:
@@ -83,15 +83,14 @@ class SelectorMethod(metaclass=abc.ABCMeta):
self,
manifest: Manifest,
previous_state: Optional[PreviousState],
arguments: List[str]
arguments: List[str],
):
self.manifest: Manifest = manifest
self.previous_state = previous_state
self.arguments: List[str] = arguments

def parsed_nodes(
self,
included_nodes: Set[UniqueId]
self, included_nodes: Set[UniqueId]
) -> Iterator[Tuple[UniqueId, ManifestNode]]:

for key, node in self.manifest.nodes.items():
@@ -101,8 +100,7 @@ class SelectorMethod(metaclass=abc.ABCMeta):
yield unique_id, node

def source_nodes(
self,
included_nodes: Set[UniqueId]
self, included_nodes: Set[UniqueId]
) -> Iterator[Tuple[UniqueId, ParsedSourceDefinition]]:

for key, source in self.manifest.sources.items():
@@ -112,8 +110,7 @@ class SelectorMethod(metaclass=abc.ABCMeta):
yield unique_id, source

def exposure_nodes(
self,
included_nodes: Set[UniqueId]
self, included_nodes: Set[UniqueId]
) -> Iterator[Tuple[UniqueId, ParsedExposure]]:

for key, exposure in self.manifest.exposures.items():
@@ -123,26 +120,28 @@ class SelectorMethod(metaclass=abc.ABCMeta):
yield unique_id, exposure

def all_nodes(
self,
included_nodes: Set[UniqueId]
self, included_nodes: Set[UniqueId]
) -> Iterator[Tuple[UniqueId, SelectorTarget]]:
yield from chain(self.parsed_nodes(included_nodes),
yield from chain(
self.parsed_nodes(included_nodes),
self.source_nodes(included_nodes),
self.exposure_nodes(included_nodes))
self.exposure_nodes(included_nodes),
)

def configurable_nodes(
self,
included_nodes: Set[UniqueId]
self, included_nodes: Set[UniqueId]
) -> Iterator[Tuple[UniqueId, CompileResultNode]]:
yield from chain(self.parsed_nodes(included_nodes),
self.source_nodes(included_nodes))
yield from chain(
self.parsed_nodes(included_nodes), self.source_nodes(included_nodes)
)

def non_source_nodes(
self,
included_nodes: Set[UniqueId],
) -> Iterator[Tuple[UniqueId, Union[ParsedExposure, ManifestNode]]]:
yield from chain(self.parsed_nodes(included_nodes),
self.exposure_nodes(included_nodes))
yield from chain(
self.parsed_nodes(included_nodes), self.exposure_nodes(included_nodes)
)

@abc.abstractmethod
def search(
@@ -150,7 +149,7 @@ class SelectorMethod(metaclass=abc.ABCMeta):
included_nodes: Set[UniqueId],
selector: str,
) -> Iterator[UniqueId]:
raise NotImplementedError('subclasses should implement this')
raise NotImplementedError("subclasses should implement this")

class QualifiedNameSelectorMethod(SelectorMethod):
@@ -216,7 +215,7 @@ class SourceSelectorMethod(SelectorMethod):
self, included_nodes: Set[UniqueId], selector: str
) -> Iterator[UniqueId]:
"""yields nodes from included are the specified source."""
parts = selector.split('.')
parts = selector.split(".")
target_package = SELECTOR_GLOB
if len(parts) == 1:
target_source, target_table = parts[0], None
@@ -227,9 +226,9 @@ class SourceSelectorMethod(SelectorMethod):
else:  # len(parts) > 3 or len(parts) == 0
msg = (
'Invalid source selector value "{}". Sources must be of the '
'form `${{source_name}}`, '
'`${{source_name}}.${{target_name}}`, or '
'`${{package_name}}.${{source_name}}.${{target_name}}'
"form `${{source_name}}`, "
"`${{source_name}}.${{target_name}}`, or "
"`${{package_name}}.${{source_name}}.${{target_name}}"
).format(selector)
raise RuntimeException(msg)

@@ -248,7 +247,7 @@ class ExposureSelectorMethod(SelectorMethod):
def search(
self, included_nodes: Set[UniqueId], selector: str
) -> Iterator[UniqueId]:
parts = selector.split('.')
parts = selector.split(".")
target_package = SELECTOR_GLOB
if len(parts) == 1:
target_name = parts[0]
@@ -257,8 +256,8 @@ class ExposureSelectorMethod(SelectorMethod):
else:
msg = (
'Invalid exposure selector value "{}". Exposures must be of '
'the form ${{exposure_name}} or '
'${{exposure_package.exposure_name}}'
"the form ${{exposure_name}} or "
"${{exposure_package.exposure_name}}"
).format(selector)
raise RuntimeException(msg)

@@ -275,9 +274,7 @@ class PathSelectorMethod(SelectorMethod):
def search(
self, included_nodes: Set[UniqueId], selector: str
) -> Iterator[UniqueId]:
"""Yields nodes from inclucded that match the given path.

"""
"""Yields nodes from inclucded that match the given path."""
# use '.' and not 'root' for easy comparison
root = Path.cwd()
paths = set(p.relative_to(root) for p in root.glob(selector))
@@ -336,7 +333,7 @@ class ConfigSelectorMethod(SelectorMethod):
parts = self.arguments
# special case: if the user wanted to compare test severity,
# make the comparison case-insensitive
if parts == ['severity']:
if parts == ["severity"]:
selector = CaseInsensitive(selector)

# search sources is kind of useless now source configs only have
@@ -382,14 +379,13 @@ class TestTypeSelectorMethod(SelectorMethod):
self, included_nodes: Set[UniqueId], selector: str
) -> Iterator[UniqueId]:
search_types: Tuple[Type, ...]
if selector == 'schema':
if selector == "schema":
search_types = (ParsedSchemaTestNode, CompiledSchemaTestNode)
elif selector == 'data':
elif selector == "data":
search_types = (ParsedDataTestNode, CompiledDataTestNode)
else:
raise RuntimeException(
f'Invalid test type selector {selector}: expected "data" or '
'"schema"'
f'Invalid test type selector {selector}: expected "data" or ' '"schema"'
)

for node, real_node in self.parsed_nodes(included_nodes):
@@ -405,25 +401,23 @@ class StateSelectorMethod(SelectorMethod):
def _macros_modified(self) -> List[str]:
# we checked in the caller!
if self.previous_state is None or self.previous_state.manifest is None:
raise InternalException(
'No comparison manifest in _macros_modified'
)
raise InternalException("No comparison manifest in _macros_modified")
old_macros = self.previous_state.manifest.macros
new_macros = self.manifest.macros

modified = []
for uid, macro in new_macros.items():
name = f'{macro.package_name}.{macro.name}'
name = f"{macro.package_name}.{macro.name}"
if uid in old_macros:
old_macro = old_macros[uid]
if macro.macro_sql != old_macro.macro_sql:
modified.append(f'{name} changed')
modified.append(f"{name} changed")
else:
modified.append(f'{name} added')
modified.append(f"{name} added")

for uid, macro in old_macros.items():
if uid not in new_macros:
modified.append(f'{macro.package_name}.{macro.name} removed')
modified.append(f"{macro.package_name}.{macro.name} removed")

return modified[:3]

@@ -437,12 +431,14 @@ class StateSelectorMethod(SelectorMethod):
if self.macros_were_modified is None:
self.macros_were_modified = self._macros_modified()
if self.macros_were_modified:
log_str = ', '.join(self.macros_were_modified)
logger.warning(warning_tag(
f'During a state comparison, dbt detected a change in '
f'macros. This will not be marked as a modification. Some '
f'macros: {log_str}'
))
log_str = ", ".join(self.macros_were_modified)
logger.warning(
warning_tag(
f"During a state comparison, dbt detected a change in "
f"macros. This will not be marked as a modification. Some "
f"macros: {log_str}"
)
)

return not new.same_contents(old)  # type: ignore

@@ -458,12 +454,12 @@ class StateSelectorMethod(SelectorMethod):
) -> Iterator[UniqueId]:
if self.previous_state is None or self.previous_state.manifest is None:
raise RuntimeException(
'Got a state selector method, but no comparison manifest'
"Got a state selector method, but no comparison manifest"
)

state_checks = {
'modified': self.check_modified,
'new': self.check_new,
"modified": self.check_modified,
"new": self.check_new,
}
if selector in state_checks:
checker = state_checks[selector]
@@ -517,7 +513,7 @@ class MethodManager:
if method not in self.SELECTOR_METHODS:
raise InternalException(
f'Method name "{method}" is a valid node selection '
f'method name, but it is not handled'
f"method name, but it is not handled"
)
cls: Type[SelectorMethod] = self.SELECTOR_METHODS[method]
return cls(self.manifest, self.previous_state, method_arguments)

@@ -3,23 +3,21 @@ import re
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
|
||||
from typing import (
|
||||
Set, Iterator, List, Optional, Dict, Union, Any, Iterable, Tuple
|
||||
)
|
||||
from typing import Set, Iterator, List, Optional, Dict, Union, Any, Iterable, Tuple
|
||||
from .graph import UniqueId
|
||||
from .selector_methods import MethodName
|
||||
from dbt.exceptions import RuntimeException, InvalidSelectorException
|
||||
|
||||
|
||||
RAW_SELECTOR_PATTERN = re.compile(
|
||||
r'\A'
|
||||
r'(?P<childrens_parents>(\@))?'
|
||||
r'(?P<parents>((?P<parents_depth>(\d*))\+))?'
|
||||
r'((?P<method>([\w.]+)):)?(?P<value>(.*?))'
|
||||
r'(?P<children>(\+(?P<children_depth>(\d*))))?'
|
||||
r'\Z'
|
||||
r"\A"
|
||||
r"(?P<childrens_parents>(\@))?"
|
||||
r"(?P<parents>((?P<parents_depth>(\d*))\+))?"
|
||||
r"((?P<method>([\w.]+)):)?(?P<value>(.*?))"
|
||||
r"(?P<children>(\+(?P<children_depth>(\d*))))?"
|
||||
r"\Z"
|
||||
)
|
||||
SELECTOR_METHOD_SEPARATOR = '.'
|
||||
SELECTOR_METHOD_SEPARATOR = "."
|
||||
|
||||
|
||||
def _probably_path(value: str):
|
||||
@@ -43,15 +41,15 @@ def _match_to_int(match: Dict[str, str], key: str) -> Optional[int]:
|
||||
return int(raw)
|
||||
except ValueError as exc:
|
||||
raise RuntimeException(
|
||||
f'Invalid node spec - could not handle parent depth {raw}'
|
||||
f"Invalid node spec - could not handle parent depth {raw}"
|
||||
) from exc
|
||||
|
||||
|
||||
SelectionSpec = Union[
|
||||
'SelectionCriteria',
|
||||
'SelectionIntersection',
|
||||
'SelectionDifference',
|
||||
'SelectionUnion',
|
||||
"SelectionCriteria",
|
||||
"SelectionIntersection",
|
||||
"SelectionDifference",
|
||||
"SelectionUnion",
|
||||
]
|
||||
|
||||
|
||||
@@ -71,7 +69,7 @@ class SelectionCriteria:
|
||||
if self.children and self.childrens_parents:
|
||||
raise RuntimeException(
|
||||
f'Invalid node spec {self.raw} - "@" prefix and "+" suffix '
|
||||
'are incompatible'
|
||||
"are incompatible"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -82,12 +80,10 @@ class SelectionCriteria:
|
||||
return MethodName.FQN
|
||||
|
||||
@classmethod
|
||||
def parse_method(
|
||||
cls, groupdict: Dict[str, Any]
|
||||
) -> Tuple[MethodName, List[str]]:
|
||||
raw_method = groupdict.get('method')
|
||||
def parse_method(cls, groupdict: Dict[str, Any]) -> Tuple[MethodName, List[str]]:
|
||||
raw_method = groupdict.get("method")
|
||||
if raw_method is None:
|
||||
return cls.default_method(groupdict['value']), []
|
||||
return cls.default_method(groupdict["value"]), []
|
||||
|
||||
method_parts: List[str] = raw_method.split(SELECTOR_METHOD_SEPARATOR)
|
||||
try:
|
||||
@@ -104,24 +100,22 @@ class SelectionCriteria:
|
||||
@classmethod
|
||||
def selection_criteria_from_dict(
|
||||
cls, raw: Any, dct: Dict[str, Any]
|
||||
) -> 'SelectionCriteria':
|
||||
if 'value' not in dct:
|
||||
raise RuntimeException(
|
||||
f'Invalid node spec "{raw}" - no search value!'
|
||||
)
|
||||
) -> "SelectionCriteria":
|
||||
if "value" not in dct:
|
||||
raise RuntimeException(f'Invalid node spec "{raw}" - no search value!')
|
||||
method_name, method_arguments = cls.parse_method(dct)
|
||||
|
||||
parents_depth = _match_to_int(dct, 'parents_depth')
|
||||
children_depth = _match_to_int(dct, 'children_depth')
|
||||
parents_depth = _match_to_int(dct, "parents_depth")
|
||||
children_depth = _match_to_int(dct, "children_depth")
|
||||
return cls(
|
||||
raw=raw,
|
||||
method=method_name,
|
||||
method_arguments=method_arguments,
|
||||
value=dct['value'],
|
||||
childrens_parents=bool(dct.get('childrens_parents')),
|
||||
parents=bool(dct.get('parents')),
|
||||
value=dct["value"],
|
||||
childrens_parents=bool(dct.get("childrens_parents")),
|
||||
parents=bool(dct.get("parents")),
|
||||
parents_depth=parents_depth,
|
||||
children=bool(dct.get('children')),
|
||||
children=bool(dct.get("children")),
|
||||
children_depth=children_depth,
|
||||
)
|
||||
|
||||
@@ -129,24 +123,24 @@ class SelectionCriteria:
|
||||
def dict_from_single_spec(cls, raw: str):
|
||||
result = RAW_SELECTOR_PATTERN.match(raw)
|
||||
if result is None:
|
||||
return {'error': 'Invalid selector spec'}
|
||||
return {"error": "Invalid selector spec"}
|
||||
dct: Dict[str, Any] = result.groupdict()
|
||||
method_name, method_arguments = cls.parse_method(dct)
|
||||
meth_name = str(method_name)
|
||||
if method_arguments:
|
||||
meth_name = meth_name + '.' + '.'.join(method_arguments)
|
||||
dct['method'] = meth_name
|
||||
dct = {k: v for k, v in dct.items() if (v is not None and v != '')}
|
||||
if 'childrens_parents' in dct:
|
||||
dct['childrens_parents'] = bool(dct.get('childrens_parents'))
|
||||
if 'parents' in dct:
|
||||
dct['parents'] = bool(dct.get('parents'))
|
||||
if 'children' in dct:
|
||||
dct['children'] = bool(dct.get('children'))
|
||||
meth_name = meth_name + "." + ".".join(method_arguments)
|
||||
dct["method"] = meth_name
|
||||
dct = {k: v for k, v in dct.items() if (v is not None and v != "")}
|
||||
if "childrens_parents" in dct:
|
||||
dct["childrens_parents"] = bool(dct.get("childrens_parents"))
|
||||
if "parents" in dct:
|
||||
dct["parents"] = bool(dct.get("parents"))
|
||||
if "children" in dct:
|
||||
dct["children"] = bool(dct.get("children"))
|
||||
return dct
|
||||
|
||||
@classmethod
|
||||
def from_single_spec(cls, raw: str) -> 'SelectionCriteria':
|
||||
def from_single_spec(cls, raw: str) -> "SelectionCriteria":
|
||||
result = RAW_SELECTOR_PATTERN.match(raw)
|
||||
if result is None:
|
||||
# bad spec!
|
||||
@@ -175,9 +169,7 @@ class BaseSelectionGroup(Iterable[SelectionSpec], metaclass=ABCMeta):
|
||||
self,
|
||||
selections: List[Set[UniqueId]],
|
||||
) -> Set[UniqueId]:
|
||||
raise NotImplementedError(
|
||||
'_combine_selections not implemented!'
|
||||
)
|
||||
raise NotImplementedError("_combine_selections not implemented!")
|
||||
|
||||
def combined(self, selections: List[Set[UniqueId]]) -> Set[UniqueId]:
|
||||
if not selections:
|
||||
|
||||
@@ -5,7 +5,9 @@ from pathlib import Path
|
||||
from typing import Tuple, AbstractSet, Union
|
||||
|
||||
from dbt.dataclass_schema import (
|
||||
dbtClassMixin, ValidationError, StrEnum,
|
||||
dbtClassMixin,
|
||||
ValidationError,
|
||||
StrEnum,
|
||||
)
|
||||
from hologram import FieldEncoder, JsonDict
|
||||
from mashumaro.types import SerializableType
|
||||
@@ -13,11 +15,11 @@ from mashumaro.types import SerializableType
|
||||
|
||||
class Port(int, SerializableType):
|
||||
@classmethod
|
||||
def _deserialize(cls, value: Union[int, str]) -> 'Port':
|
||||
def _deserialize(cls, value: Union[int, str]) -> "Port":
|
||||
try:
|
||||
value = int(value)
|
||||
except ValueError:
|
||||
raise ValidationError(f'Cannot encode {value} into port number')
|
||||
raise ValidationError(f"Cannot encode {value} into port number")
|
||||
|
||||
return Port(value)
|
||||
|
||||
@@ -28,7 +30,7 @@ class Port(int, SerializableType):
|
||||
class PortEncoder(FieldEncoder):
|
||||
@property
|
||||
def json_schema(self):
|
||||
return {'type': 'integer', 'minimum': 0, 'maximum': 65535}
|
||||
return {"type": "integer", "minimum": 0, "maximum": 65535}
|
||||
|
||||
|
||||
class TimeDeltaFieldEncoder(FieldEncoder[timedelta]):
|
||||
@@ -44,12 +46,12 @@ class TimeDeltaFieldEncoder(FieldEncoder[timedelta]):
|
||||
return timedelta(seconds=value)
|
||||
except TypeError:
|
||||
raise ValidationError(
|
||||
'cannot encode {} into timedelta'.format(value)
|
||||
"cannot encode {} into timedelta".format(value)
|
||||
) from None
|
||||
|
||||
@property
|
||||
def json_schema(self) -> JsonDict:
|
||||
return {'type': 'number'}
|
||||
return {"type": "number"}
|
||||
|
||||
|
||||
class PathEncoder(FieldEncoder):
|
||||
@@ -63,16 +65,16 @@ class PathEncoder(FieldEncoder):
|
||||
return Path(value)
|
||||
except TypeError:
|
||||
raise ValidationError(
|
||||
'cannot encode {} into timedelta'.format(value)
|
||||
"cannot encode {} into timedelta".format(value)
|
||||
) from None
|
||||
|
||||
@property
|
||||
def json_schema(self) -> JsonDict:
|
||||
return {'type': 'string'}
|
||||
return {"type": "string"}
|
||||
|
||||
|
||||
class NVEnum(StrEnum):
|
||||
novalue = 'novalue'
|
||||
novalue = "novalue"
|
||||
|
||||
def __eq__(self, other):
|
||||
return isinstance(other, NVEnum)
|
||||
@@ -81,14 +83,17 @@ class NVEnum(StrEnum):
|
||||
@dataclass
|
||||
class NoValue(dbtClassMixin):
|
||||
"""Sometimes, you want a way to say none that isn't None"""
|
||||
|
||||
novalue: NVEnum = NVEnum.novalue
|
||||
|
||||
|
||||
dbtClassMixin.register_field_encoders({
|
||||
dbtClassMixin.register_field_encoders(
|
||||
{
|
||||
Port: PortEncoder(),
|
||||
timedelta: TimeDeltaFieldEncoder(),
|
||||
Path: PathEncoder(),
|
||||
})
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
FQNPath = Tuple[str, ...]
|
||||
|
||||
@@ -5,8 +5,8 @@ from typing import Union, Dict, Any
|
||||
|
||||
|
||||
class ModelHookType(StrEnum):
|
||||
PreHook = 'pre-hook'
|
||||
PostHook = 'post-hook'
|
||||
PreHook = "pre-hook"
|
||||
PostHook = "post-hook"
|
||||
|
||||
|
||||
def get_hook_dict(source: Union[str, Dict[str, Any]]) -> Dict[str, Any]:
|
||||
@@ -18,4 +18,4 @@ def get_hook_dict(source: Union[str, Dict[str, Any]]) -> Dict[str, Any]:
|
||||
try:
|
||||
return json.loads(source)
|
||||
except ValueError:
|
||||
return {'sql': source}
|
||||
return {"sql": source}
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import os
|
||||
|
||||
PACKAGE_PATH = os.path.dirname(__file__)
|
||||
PROJECT_NAME = 'dbt'
|
||||
PROJECT_NAME = "dbt"
|
||||
|
||||
DOCS_INDEX_FILE_PATH = os.path.normpath(
|
||||
os.path.join(PACKAGE_PATH, '..', "index.html"))
|
||||
DOCS_INDEX_FILE_PATH = os.path.normpath(os.path.join(PACKAGE_PATH, "..", "index.html"))
|
||||
|
||||
@@ -287,4 +287,3 @@
|
||||
{% macro set_sql_header(config) -%}
|
||||
{{ config.set('sql_header', caller()) }}
|
||||
{%- endmacro %}
|
||||
|
||||
|
||||
@@ -23,5 +23,3 @@
|
||||
values ({{ insert_cols_csv }})
|
||||
;
|
||||
{% endmacro %}
|
||||
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
ProfileConfigDocs = 'https://docs.getdbt.com/docs/configure-your-profile'
|
||||
SnowflakeQuotingDocs = 'https://docs.getdbt.com/v0.10/docs/configuring-quoting'
|
||||
IncrementalDocs = 'https://docs.getdbt.com/docs/configuring-incremental-models'
|
||||
BigQueryNewPartitionBy = 'https://docs.getdbt.com/docs/upgrading-to-0-16-0'
|
||||
ProfileConfigDocs = "https://docs.getdbt.com/docs/configure-your-profile"
|
||||
SnowflakeQuotingDocs = "https://docs.getdbt.com/v0.10/docs/configuring-quoting"
|
||||
IncrementalDocs = "https://docs.getdbt.com/docs/configuring-incremental-models"
|
||||
BigQueryNewPartitionBy = "https://docs.getdbt.com/docs/upgrading-to-0-16-0"
|
||||
|
||||
@@ -26,21 +26,21 @@ colorama_wrap = True
|
||||
colorama.init(wrap=colorama_wrap)
|
||||
|
||||
|
||||
if sys.platform == 'win32' and not os.getenv('TERM'):
|
||||
if sys.platform == "win32" and not os.getenv("TERM"):
|
||||
colorama_wrap = False
|
||||
colorama_stdout = colorama.AnsiToWin32(sys.stdout).stream
|
||||
|
||||
elif sys.platform == 'win32':
|
||||
elif sys.platform == "win32":
|
||||
colorama_wrap = False
|
||||
|
||||
colorama.init(wrap=colorama_wrap)
|
||||
|
||||
|
||||
STDOUT_LOG_FORMAT = '{record.message}'
|
||||
STDOUT_LOG_FORMAT = "{record.message}"
|
||||
DEBUG_LOG_FORMAT = (
|
||||
'{record.time:%Y-%m-%d %H:%M:%S.%f%z} '
|
||||
'({record.thread_name}): '
|
||||
'{record.message}'
|
||||
"{record.time:%Y-%m-%d %H:%M:%S.%f%z} "
|
||||
"({record.thread_name}): "
|
||||
"{record.message}"
|
||||
)
|
||||
|
||||
|
||||
@@ -94,6 +94,7 @@ class JsonFormatter(LogMessageFormatter):
|
||||
"""Return a the record converted to LogMessage's JSON form"""
|
||||
# utils imports exceptions which imports logger...
|
||||
import dbt.utils
|
||||
|
||||
log_message = super().__call__(record, handler)
|
||||
dct = log_message.to_dict(omit_none=True)
|
||||
return json.dumps(dct, cls=dbt.utils.JSONEncoder)
|
||||
@@ -117,9 +118,7 @@ class FormatterMixin:
|
||||
self.format_string = self._text_format_string
|
||||
|
||||
def reset(self):
|
||||
raise NotImplementedError(
|
||||
'reset() not implemented in FormatterMixin subclass'
|
||||
)
|
||||
raise NotImplementedError("reset() not implemented in FormatterMixin subclass")
|
||||
|
||||
|
||||
class OutputHandler(logbook.StreamHandler, FormatterMixin):
|
||||
@@ -164,9 +163,9 @@ class OutputHandler(logbook.StreamHandler, FormatterMixin):
|
||||
if record.level < self.level:
|
||||
return False
|
||||
text_mode = self.formatter_class is logbook.StringFormatter
|
||||
if text_mode and record.extra.get('json_only', False):
|
||||
if text_mode and record.extra.get("json_only", False):
|
||||
return False
|
||||
elif not text_mode and record.extra.get('text_only', False):
|
||||
elif not text_mode and record.extra.get("text_only", False):
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
@@ -177,7 +176,7 @@ def _redirect_std_logging():
|
||||
|
||||
|
||||
def _root_channel(record: logbook.LogRecord) -> str:
|
||||
return record.channel.split('.')[0]
|
||||
return record.channel.split(".")[0]
|
||||
|
||||
|
||||
class Relevel(logbook.Processor):
|
||||
@@ -195,7 +194,7 @@ class Relevel(logbook.Processor):
|
||||
def process(self, record):
|
||||
if _root_channel(record) in self.allowed:
|
||||
return
|
||||
record.extra['old_level'] = record.level
|
||||
record.extra["old_level"] = record.level
|
||||
# suppress logs at/below our min level by lowering them to NOTSET
|
||||
if record.level < self.min_level:
|
||||
record.level = logbook.NOTSET
|
||||
@@ -207,12 +206,12 @@ class Relevel(logbook.Processor):
|
||||
|
||||
class JsonOnly(logbook.Processor):
|
||||
def process(self, record):
|
||||
record.extra['json_only'] = True
|
||||
record.extra["json_only"] = True
|
||||
|
||||
|
||||
class TextOnly(logbook.Processor):
|
||||
def process(self, record):
|
||||
record.extra['text_only'] = True
|
||||
record.extra["text_only"] = True
|
||||
|
||||
|
||||
class TimingProcessor(logbook.Processor):
|
||||
@@ -222,8 +221,7 @@ class TimingProcessor(logbook.Processor):
|
||||
|
||||
def process(self, record):
|
||||
if self.timing_info is not None:
|
||||
record.extra['timing_info'] = self.timing_info.to_dict(
|
||||
omit_none=True)
|
||||
record.extra["timing_info"] = self.timing_info.to_dict(omit_none=True)
|
||||
|
||||
|
||||
class DbtProcessState(logbook.Processor):
|
||||
@@ -233,11 +231,10 @@ class DbtProcessState(logbook.Processor):
|
||||
|
||||
def process(self, record):
|
||||
overwrite = (
|
||||
'run_state' not in record.extra or
|
||||
record.extra['run_state'] == 'internal'
|
||||
"run_state" not in record.extra or record.extra["run_state"] == "internal"
|
||||
)
|
||||
if overwrite:
|
||||
record.extra['run_state'] = self.value
|
||||
record.extra["run_state"] = self.value
|
||||
|
||||
|
||||
class DbtModelState(logbook.Processor):
|
||||
@@ -251,7 +248,7 @@ class DbtModelState(logbook.Processor):
|
||||
|
||||
class DbtStatusMessage(logbook.Processor):
|
||||
def process(self, record):
|
||||
record.extra['is_status_message'] = True
|
||||
record.extra["is_status_message"] = True
|
||||
|
||||
|
||||
class UniqueID(logbook.Processor):
|
||||
@@ -260,7 +257,7 @@ class UniqueID(logbook.Processor):
|
||||
super().__init__()
|
||||
|
||||
def process(self, record):
|
||||
record.extra['unique_id'] = self.unique_id
|
||||
record.extra["unique_id"] = self.unique_id
|
||||
|
||||
|
||||
class NodeCount(logbook.Processor):
|
||||
@@ -269,7 +266,7 @@ class NodeCount(logbook.Processor):
|
||||
super().__init__()
|
||||
|
||||
def process(self, record):
|
||||
record.extra['node_count'] = self.node_count
|
||||
record.extra["node_count"] = self.node_count
|
||||
|
||||
|
||||
class NodeMetadata(logbook.Processor):
|
||||
@@ -289,26 +286,26 @@ class NodeMetadata(logbook.Processor):
|
||||
|
||||
def process(self, record):
|
||||
self.process_keys(record)
|
||||
record.extra['node_index'] = self.index
|
||||
record.extra["node_index"] = self.index
|
||||
|
||||
|
||||
class ModelMetadata(NodeMetadata):
|
||||
def mapping_keys(self):
|
||||
return [
|
||||
('alias', 'node_alias'),
|
||||
('schema', 'node_schema'),
|
||||
('database', 'node_database'),
|
||||
('original_file_path', 'node_path'),
|
||||
('name', 'node_name'),
|
||||
('resource_type', 'resource_type'),
|
||||
('depends_on_nodes', 'depends_on'),
|
||||
("alias", "node_alias"),
|
||||
("schema", "node_schema"),
|
||||
("database", "node_database"),
|
||||
("original_file_path", "node_path"),
|
||||
("name", "node_name"),
|
||||
("resource_type", "resource_type"),
|
||||
("depends_on_nodes", "depends_on"),
|
||||
]
|
||||
|
||||
def process_config(self, record):
|
||||
if hasattr(self.node, 'config'):
|
||||
materialized = getattr(self.node.config, 'materialized', None)
|
||||
if hasattr(self.node, "config"):
|
||||
materialized = getattr(self.node.config, "materialized", None)
|
||||
if materialized is not None:
|
||||
record.extra['node_materialized'] = materialized
|
||||
record.extra["node_materialized"] = materialized
|
||||
|
||||
def process(self, record):
|
||||
super().process(record)
|
||||
@@ -318,8 +315,8 @@ class ModelMetadata(NodeMetadata):
|
||||
class HookMetadata(NodeMetadata):
|
||||
def mapping_keys(self):
|
||||
return [
|
||||
('name', 'node_name'),
|
||||
('resource_type', 'resource_type'),
|
||||
("name", "node_name"),
|
||||
("resource_type", "resource_type"),
|
||||
]
|
||||
|
||||
|
||||
@@ -333,30 +330,31 @@ class TimestampNamed(logbook.Processor):
|
||||
record.extra[self.name] = datetime.utcnow().isoformat()
|
||||
|
||||
|
||||
logger = logbook.Logger('dbt')
|
||||
logger = logbook.Logger("dbt")
|
||||
# provide this for the cache, disabled by default
|
||||
CACHE_LOGGER = logbook.Logger('dbt.cache')
|
||||
CACHE_LOGGER = logbook.Logger("dbt.cache")
|
||||
CACHE_LOGGER.disable()
|
||||
|
||||
warnings.filterwarnings("ignore", category=ResourceWarning,
|
||||
message="unclosed.*<socket.socket.*>")
|
||||
warnings.filterwarnings(
|
||||
"ignore", category=ResourceWarning, message="unclosed.*<socket.socket.*>"
|
||||
)
|
||||
|
||||
initialized = False
|
||||
|
||||
|
||||
def make_log_dir_if_missing(log_dir):
|
||||
import dbt.clients.system
|
||||
|
||||
dbt.clients.system.make_directory(log_dir)
|
||||
|
||||
|
||||
class DebugWarnings(logbook.compat.redirected_warnings):
|
||||
"""Log warnings, except send them to 'debug' instead of 'warning' level.
|
||||
"""
|
||||
"""Log warnings, except send them to 'debug' instead of 'warning' level."""
|
||||
|
||||
def make_record(self, message, exception, filename, lineno):
|
||||
rv = super().make_record(message, exception, filename, lineno)
|
||||
rv.level = logbook.DEBUG
|
||||
rv.extra['from_warnings'] = True
|
||||
rv.extra["from_warnings"] = True
|
||||
return rv
|
||||
|
||||
|
||||
@@ -408,14 +406,14 @@ class DelayedFileHandler(logbook.RotatingFileHandler, FormatterMixin):
|
||||
if self.disabled:
|
||||
return
|
||||
|
||||
assert not self.initialized, 'set_path called after being set'
|
||||
assert not self.initialized, "set_path called after being set"
|
||||
|
||||
if log_dir is None:
|
||||
self.disabled = True
|
||||
return
|
||||
|
||||
make_log_dir_if_missing(log_dir)
|
||||
log_path = os.path.join(log_dir, 'dbt.log')
|
||||
log_path = os.path.join(log_dir, "dbt.log")
|
||||
self._super_init(log_path)
|
||||
self._replay_buffered()
|
||||
self._log_path = log_path
|
||||
@@ -435,8 +433,9 @@ class DelayedFileHandler(logbook.RotatingFileHandler, FormatterMixin):
|
||||
FormatterMixin.__init__(self, DEBUG_LOG_FORMAT)
|
||||
|
||||
def _replay_buffered(self):
|
||||
assert self._msg_buffer is not None, \
|
||||
'_msg_buffer should never be None in _replay_buffered'
|
||||
assert (
|
||||
self._msg_buffer is not None
|
||||
), "_msg_buffer should never be None in _replay_buffered"
|
||||
for record in self._msg_buffer:
|
||||
super().emit(record)
|
||||
self._msg_buffer = None
|
||||
@@ -445,7 +444,7 @@ class DelayedFileHandler(logbook.RotatingFileHandler, FormatterMixin):
|
||||
msg = super().format(record)
|
||||
subbed = str(msg)
|
||||
for escape_sequence in dbt.ui.COLORS.values():
|
||||
subbed = subbed.replace(escape_sequence, '')
|
||||
subbed = subbed.replace(escape_sequence, "")
|
||||
return subbed
|
||||
|
||||
def emit(self, record: logbook.LogRecord):
|
||||
@@ -457,11 +456,13 @@ class DelayedFileHandler(logbook.RotatingFileHandler, FormatterMixin):
|
||||
elif self.initialized:
|
||||
super().emit(record)
|
||||
else:
|
||||
assert self._msg_buffer is not None, \
|
||||
'_msg_buffer should never be None if _log_path is set'
|
||||
assert (
|
||||
self._msg_buffer is not None
|
||||
), "_msg_buffer should never be None if _log_path is set"
|
||||
self._msg_buffer.append(record)
|
||||
assert len(self._msg_buffer) < self._bufmax, \
|
||||
'too many messages received before initialization!'
|
||||
assert (
|
||||
len(self._msg_buffer) < self._bufmax
|
||||
), "too many messages received before initilization!"
|
||||
|
||||
|
||||
class LogManager(logbook.NestedSetup):
|
||||
@@ -471,19 +472,21 @@ class LogManager(logbook.NestedSetup):
|
||||
self._null_handler = logbook.NullHandler()
|
||||
self._output_handler = OutputHandler(self.stdout)
|
||||
self._file_handler = DelayedFileHandler()
|
||||
self._relevel_processor = Relevel(allowed=['dbt', 'werkzeug'])
|
||||
self._state_processor = DbtProcessState('internal')
|
||||
self._relevel_processor = Relevel(allowed=["dbt", "werkzeug"])
|
||||
self._state_processor = DbtProcessState("internal")
|
||||
# keep track of whether we've already entered to decide if we should
|
||||
# be actually pushing. This allows us to log in main() and also
|
||||
# support entering dbt execution via handle_and_check.
|
||||
self._stack_depth = 0
|
||||
super().__init__([
|
||||
super().__init__(
|
||||
[
|
||||
self._null_handler,
|
||||
self._output_handler,
|
||||
self._file_handler,
|
||||
self._relevel_processor,
|
||||
self._state_processor,
|
||||
])
|
||||
]
|
||||
)
|
||||
|
||||
def push_application(self):
|
||||
self._stack_depth += 1
|
||||
@@ -499,8 +502,7 @@ class LogManager(logbook.NestedSetup):
|
||||
self.add_handler(logbook.NullHandler())
|
||||
|
||||
def add_handler(self, handler):
|
||||
"""add an handler to the log manager that runs before the file handler.
|
||||
"""
|
||||
"""add an handler to the log manager that runs before the file handler."""
|
||||
self.objects.append(handler)
|
||||
|
||||
# this is used by `dbt ls` to allow piping stdout to jq, etc
|
||||
@@ -558,8 +560,7 @@ log_manager = LogManager()
|
||||
|
||||
|
||||
def log_cache_events(flag):
|
||||
"""Set the cache logger to propagate its messages based on the given flag.
|
||||
"""
|
||||
"""Set the cache logger to propagate its messages based on the given flag."""
|
||||
# the flag is True if we should log, and False if we shouldn't, so disabled
|
||||
# is the inverse.
|
||||
CACHE_LOGGER.disabled = not flag
|
||||
@@ -583,7 +584,7 @@ class ListLogHandler(LogMessageHandler):
|
||||
level: int = logbook.NOTSET,
|
||||
filter: Callable = None,
|
||||
bubble: bool = False,
|
||||
lst: Optional[List[LogMessage]] = None
|
||||
lst: Optional[List[LogMessage]] = None,
|
||||
) -> None:
|
||||
super().__init__(level, filter, bubble)
|
||||
if lst is None:
|
||||
@@ -592,7 +593,7 @@ class ListLogHandler(LogMessageHandler):
|
||||
|
||||
def should_handle(self, record):
|
||||
"""Only ever emit dbt-sourced log messages to the ListHandler."""
|
||||
if _root_channel(record) != 'dbt':
|
||||
if _root_channel(record) != "dbt":
|
||||
return False
|
||||
return super().should_handle(record)
|
||||
|
||||
@@ -609,28 +610,27 @@ def _env_log_level(var_name: str) -> int:
|
||||
return logging.ERROR
|
||||
|
||||
|
||||
LOG_LEVEL_GOOGLE = _env_log_level('DBT_GOOGLE_DEBUG_LOGGING')
|
||||
LOG_LEVEL_SNOWFLAKE = _env_log_level('DBT_SNOWFLAKE_CONNECTOR_DEBUG_LOGGING')
|
||||
LOG_LEVEL_BOTOCORE = _env_log_level('DBT_BOTOCORE_DEBUG_LOGGING')
|
||||
LOG_LEVEL_HTTP = _env_log_level('DBT_HTTP_DEBUG_LOGGING')
|
||||
LOG_LEVEL_WERKZEUG = _env_log_level('DBT_WERKZEUG_DEBUG_LOGGING')
|
||||
LOG_LEVEL_GOOGLE = _env_log_level("DBT_GOOGLE_DEBUG_LOGGING")
|
||||
LOG_LEVEL_SNOWFLAKE = _env_log_level("DBT_SNOWFLAKE_CONNECTOR_DEBUG_LOGGING")
|
||||
LOG_LEVEL_BOTOCORE = _env_log_level("DBT_BOTOCORE_DEBUG_LOGGING")
|
||||
LOG_LEVEL_HTTP = _env_log_level("DBT_HTTP_DEBUG_LOGGING")
|
||||
LOG_LEVEL_WERKZEUG = _env_log_level("DBT_WERKZEUG_DEBUG_LOGGING")
|
||||
|
||||
logging.getLogger('botocore').setLevel(LOG_LEVEL_BOTOCORE)
|
||||
logging.getLogger('requests').setLevel(LOG_LEVEL_HTTP)
|
||||
logging.getLogger('urllib3').setLevel(LOG_LEVEL_HTTP)
|
||||
logging.getLogger('google').setLevel(LOG_LEVEL_GOOGLE)
|
||||
logging.getLogger('snowflake.connector').setLevel(LOG_LEVEL_SNOWFLAKE)
|
||||
logging.getLogger("botocore").setLevel(LOG_LEVEL_BOTOCORE)
|
||||
logging.getLogger("requests").setLevel(LOG_LEVEL_HTTP)
|
||||
logging.getLogger("urllib3").setLevel(LOG_LEVEL_HTTP)
|
||||
logging.getLogger("google").setLevel(LOG_LEVEL_GOOGLE)
|
||||
logging.getLogger("snowflake.connector").setLevel(LOG_LEVEL_SNOWFLAKE)
|
||||
|
||||
logging.getLogger('parsedatetime').setLevel(logging.ERROR)
|
||||
logging.getLogger('werkzeug').setLevel(LOG_LEVEL_WERKZEUG)
|
||||
logging.getLogger("parsedatetime").setLevel(logging.ERROR)
|
||||
logging.getLogger("werkzeug").setLevel(LOG_LEVEL_WERKZEUG)
|
||||
|
||||
|
||||
def list_handler(
|
||||
lst: Optional[List[LogMessage]],
|
||||
level=logbook.NOTSET,
|
||||
) -> ContextManager:
|
||||
"""Return a context manager that temporarly attaches a list to the logger.
|
||||
"""
|
||||
"""Return a context manager that temporarly attaches a list to the logger."""
|
||||
return ListLogHandler(lst=lst, level=level, bubble=True)
|
||||
|
||||
|
||||
|
||||
660 core/dbt/main.py
File diff suppressed because it is too large
@@ -4,20 +4,20 @@ from dbt.dataclass_schema import StrEnum
|
||||
|
||||
|
||||
class NodeType(StrEnum):
|
||||
Model = 'model'
|
||||
Analysis = 'analysis'
|
||||
Test = 'test'
|
||||
Snapshot = 'snapshot'
|
||||
Operation = 'operation'
|
||||
Seed = 'seed'
|
||||
RPCCall = 'rpc'
|
||||
Documentation = 'docs'
|
||||
Source = 'source'
|
||||
Macro = 'macro'
|
||||
Exposure = 'exposure'
|
||||
Model = "model"
|
||||
Analysis = "analysis"
|
||||
Test = "test"
|
||||
Snapshot = "snapshot"
|
||||
Operation = "operation"
|
||||
Seed = "seed"
|
||||
RPCCall = "rpc"
|
||||
Documentation = "docs"
|
||||
Source = "source"
|
||||
Macro = "macro"
|
||||
Exposure = "exposure"
|
||||
|
||||
@classmethod
|
||||
def executable(cls) -> List['NodeType']:
|
||||
def executable(cls) -> List["NodeType"]:
|
||||
return [
|
||||
cls.Model,
|
||||
cls.Test,
|
||||
@@ -30,7 +30,7 @@ class NodeType(StrEnum):
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def refable(cls) -> List['NodeType']:
|
||||
def refable(cls) -> List["NodeType"]:
|
||||
return [
|
||||
cls.Model,
|
||||
cls.Seed,
|
||||
@@ -38,7 +38,7 @@ class NodeType(StrEnum):
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def documentable(cls) -> List['NodeType']:
|
||||
def documentable(cls) -> List["NodeType"]:
|
||||
return [
|
||||
cls.Model,
|
||||
cls.Seed,
|
||||
@@ -46,16 +46,16 @@ class NodeType(StrEnum):
|
||||
cls.Source,
|
||||
cls.Macro,
|
||||
cls.Analysis,
|
||||
cls.Exposure
|
||||
cls.Exposure,
|
||||
]
|
||||
|
||||
def pluralize(self) -> str:
|
||||
if self == 'analysis':
|
||||
return 'analyses'
|
||||
if self == "analysis":
|
||||
return "analyses"
|
||||
else:
|
||||
return f'{self}s'
|
||||
return f"{self}s"
|
||||
|
||||
|
||||
class RunHookType(StrEnum):
|
||||
Start = 'on-run-start'
|
||||
End = 'on-run-end'
|
||||
Start = "on-run-start"
|
||||
End = "on-run-end"
|
||||
|
||||
@@ -11,6 +11,14 @@ from .seeds import SeedParser # noqa
|
||||
from .snapshots import SnapshotParser # noqa
|
||||
|
||||
from . import ( # noqa
|
||||
analysis, base, data_test, docs, hooks, macros, models, results, schemas,
|
||||
snapshots
|
||||
analysis,
|
||||
base,
|
||||
data_test,
|
||||
docs,
|
||||
hooks,
|
||||
macros,
|
||||
models,
|
||||
results,
|
||||
schemas,
|
||||
snapshots,
|
||||
)
|
||||
|
||||
@@ -8,9 +8,7 @@ from dbt.parser.search import FilesystemSearcher, FileBlock
|
||||
|
||||
class AnalysisParser(SimpleSQLParser[ParsedAnalysisNode]):
|
||||
def get_paths(self):
|
||||
return FilesystemSearcher(
|
||||
self.project, self.project.analysis_paths, '.sql'
|
||||
)
|
||||
return FilesystemSearcher(self.project, self.project.analysis_paths, ".sql")
|
||||
|
||||
def parse_from_dict(self, dct, validate=True) -> ParsedAnalysisNode:
|
||||
if validate:
|
||||
@@ -23,4 +21,4 @@ class AnalysisParser(SimpleSQLParser[ParsedAnalysisNode]):
|
||||
|
||||
@classmethod
|
||||
def get_compiled_path(cls, block: FileBlock):
|
||||
return os.path.join('analysis', block.path.relative_path)
|
||||
return os.path.join("analysis", block.path.relative_path)
|
||||
|
||||
@@ -1,9 +1,7 @@
|
||||
import abc
|
||||
import itertools
|
||||
import os
|
||||
from typing import (
|
||||
List, Dict, Any, Iterable, Generic, TypeVar
|
||||
)
|
||||
from typing import List, Dict, Any, Iterable, Generic, TypeVar
|
||||
|
||||
from dbt.dataclass_schema import ValidationError
|
||||
|
||||
@@ -17,17 +15,15 @@ from dbt.context.providers import (
|
||||
from dbt.adapters.factory import get_adapter
|
||||
from dbt.clients.jinja import get_rendered
|
||||
from dbt.config import Project, RuntimeConfig
|
||||
from dbt.context.context_config import (
|
||||
ContextConfig
|
||||
)
|
||||
from dbt.contracts.files import (
|
||||
SourceFile, FilePath, FileHash
|
||||
)
|
||||
from dbt.context.context_config import ContextConfig
|
||||
from dbt.contracts.files import SourceFile, FilePath, FileHash
|
||||
from dbt.contracts.graph.manifest import MacroManifest
|
||||
from dbt.contracts.graph.parsed import HasUniqueID
|
||||
from dbt.contracts.graph.unparsed import UnparsedNode
|
||||
from dbt.exceptions import (
|
||||
CompilationException, validator_error_message, InternalException
|
||||
CompilationException,
|
||||
validator_error_message,
|
||||
InternalException,
|
||||
)
|
||||
from dbt import hooks
|
||||
from dbt.node_types import NodeType
|
||||
@@ -37,14 +33,14 @@ from dbt.parser.search import FileBlock
|
||||
# internally, the parser may store a less-restrictive type that will be
|
||||
# transformed into the final type. But it will have to be derived from
|
||||
# ParsedNode to be operable.
|
||||
FinalValue = TypeVar('FinalValue', bound=HasUniqueID)
|
||||
IntermediateValue = TypeVar('IntermediateValue', bound=HasUniqueID)
|
||||
FinalValue = TypeVar("FinalValue", bound=HasUniqueID)
|
||||
IntermediateValue = TypeVar("IntermediateValue", bound=HasUniqueID)
|
||||
|
||||
IntermediateNode = TypeVar('IntermediateNode', bound=Any)
|
||||
FinalNode = TypeVar('FinalNode', bound=ManifestNodes)
|
||||
IntermediateNode = TypeVar("IntermediateNode", bound=Any)
|
||||
FinalNode = TypeVar("FinalNode", bound=ManifestNodes)
|
||||
|
||||
|
||||
ConfiguredBlockType = TypeVar('ConfiguredBlockType', bound=FileBlock)
|
||||
ConfiguredBlockType = TypeVar("ConfiguredBlockType", bound=FileBlock)
|
||||
|
||||
|
||||
class BaseParser(Generic[FinalValue]):
|
||||
@@ -73,9 +69,9 @@ class BaseParser(Generic[FinalValue]):
|
||||
|
||||
def generate_unique_id(self, resource_name: str) -> str:
|
||||
"""Returns a unique identifier for a resource"""
|
||||
return "{}.{}.{}".format(self.resource_type,
|
||||
self.project.project_name,
|
||||
resource_name)
|
||||
return "{}.{}.{}".format(
|
||||
self.resource_type, self.project.project_name, resource_name
|
||||
)
|
||||
|
||||
def load_file(
|
||||
self,
|
||||
@@ -89,7 +85,7 @@ class BaseParser(Generic[FinalValue]):
|
||||
if set_contents:
|
||||
source_file.contents = file_contents.strip()
|
||||
else:
|
||||
source_file.contents = ''
|
||||
source_file.contents = ""
|
||||
return source_file
|
||||
|
||||
|
||||
@@ -108,8 +104,7 @@ class Parser(BaseParser[FinalValue], Generic[FinalValue]):
|
||||
|
||||
class RelationUpdate:
|
||||
def __init__(
|
||||
self, config: RuntimeConfig, macro_manifest: MacroManifest,
|
||||
component: str
|
||||
self, config: RuntimeConfig, macro_manifest: MacroManifest, component: str
|
||||
) -> None:
|
||||
macro = macro_manifest.find_generate_macro_by_name(
|
||||
component=component,
|
||||
@@ -117,7 +112,7 @@ class RelationUpdate:
|
||||
)
|
||||
if macro is None:
|
||||
raise InternalException(
|
||||
f'No macro with name generate_{component}_name found'
|
||||
f"No macro with name generate_{component}_name found"
|
||||
)
|
||||
|
||||
root_context = generate_generate_component_name_macro(
|
||||
@@ -126,9 +121,7 @@ class RelationUpdate:
|
||||
self.updater = MacroGenerator(macro, root_context)
|
||||
self.component = component
|
||||
|
||||
def __call__(
|
||||
self, parsed_node: Any, config_dict: Dict[str, Any]
|
||||
) -> None:
|
||||
def __call__(self, parsed_node: Any, config_dict: Dict[str, Any]) -> None:
|
||||
override = config_dict.get(self.component)
|
||||
new_value = self.updater(override, parsed_node)
|
||||
if isinstance(new_value, str):
|
||||
@@ -150,16 +143,13 @@ class ConfiguredParser(
|
||||
super().__init__(results, project, root_project, macro_manifest)
|
||||
|
||||
self._update_node_database = RelationUpdate(
|
||||
macro_manifest=macro_manifest, config=root_project,
|
||||
component='database'
|
||||
macro_manifest=macro_manifest, config=root_project, component="database"
|
||||
)
|
||||
self._update_node_schema = RelationUpdate(
|
||||
macro_manifest=macro_manifest, config=root_project,
|
||||
component='schema'
|
||||
macro_manifest=macro_manifest, config=root_project, component="schema"
|
||||
)
|
||||
self._update_node_alias = RelationUpdate(
|
||||
macro_manifest=macro_manifest, config=root_project,
|
||||
component='alias'
|
||||
macro_manifest=macro_manifest, config=root_project, component="alias"
|
||||
)
|
||||
|
||||
@abc.abstractclassmethod
|
||||
@@ -206,7 +196,11 @@ class ConfiguredParser(
|
||||
config[key] = [hooks.get_hook_dict(h) for h in config[key]]
|
||||
|
||||
def _create_error_node(
|
||||
self, name: str, path: str, original_file_path: str, raw_sql: str,
|
||||
self,
|
||||
name: str,
|
||||
path: str,
|
||||
original_file_path: str,
|
||||
raw_sql: str,
|
||||
) -> UnparsedNode:
|
||||
"""If we hit an error before we've actually parsed a node, provide some
|
||||
level of useful information by attaching this to the exception.
|
||||
@@ -239,20 +233,20 @@ class ConfiguredParser(
|
||||
if name is None:
|
||||
name = block.name
|
||||
dct = {
|
||||
'alias': name,
|
||||
'schema': self.default_schema,
|
||||
'database': self.default_database,
|
||||
'fqn': fqn,
|
||||
'name': name,
|
||||
'root_path': self.project.project_root,
|
||||
'resource_type': self.resource_type,
|
||||
'path': path,
|
||||
'original_file_path': block.path.original_file_path,
|
||||
'package_name': self.project.project_name,
|
||||
'raw_sql': block.contents,
|
||||
'unique_id': self.generate_unique_id(name),
|
||||
'config': self.config_dict(config),
|
||||
'checksum': block.file.checksum.to_dict(omit_none=True),
|
||||
"alias": name,
|
||||
"schema": self.default_schema,
|
||||
"database": self.default_database,
|
||||
"fqn": fqn,
|
||||
"name": name,
|
||||
"root_path": self.project.project_root,
|
||||
"resource_type": self.resource_type,
|
||||
"path": path,
|
||||
"original_file_path": block.path.original_file_path,
|
||||
"package_name": self.project.project_name,
|
||||
"raw_sql": block.contents,
|
||||
"unique_id": self.generate_unique_id(name),
|
||||
"config": self.config_dict(config),
|
||||
"checksum": block.file.checksum.to_dict(omit_none=True),
|
||||
}
|
||||
dct.update(kwargs)
|
||||
try:
|
||||
@@ -290,9 +284,7 @@ class ConfiguredParser(
|
||||
|
||||
# this goes through the process of rendering, but just throws away
|
||||
# the rendered result. The "macro capture" is the point?
|
||||
get_rendered(
|
||||
parsed_node.raw_sql, context, parsed_node, capture_macros=True
|
||||
)
|
||||
get_rendered(parsed_node.raw_sql, context, parsed_node, capture_macros=True)
|
||||
|
||||
# This is taking the original config for the node, converting it to a dict,
|
||||
# updating the config with new config passed in, then re-creating the
|
||||
@@ -324,12 +316,10 @@ class ConfiguredParser(
|
||||
config_dict = config.build_config_dict()
|
||||
|
||||
# Set tags on node provided in config blocks
|
||||
model_tags = config_dict.get('tags', [])
|
||||
model_tags = config_dict.get("tags", [])
|
||||
parsed_node.tags.extend(model_tags)
|
||||
|
||||
parsed_node.unrendered_config = config.build_config_dict(
|
||||
rendered=False
|
||||
)
|
||||
parsed_node.unrendered_config = config.build_config_dict(rendered=False)
|
||||
|
||||
# do this once before we parse the node database/schema/alias, so
|
||||
# parsed_node.config is what it would be if they did nothing
|
||||
@@ -338,8 +328,9 @@ class ConfiguredParser(
|
||||
|
||||
# at this point, we've collected our hooks. Use the node context to
|
||||
# render each hook and collect refs/sources
|
||||
hooks = list(itertools.chain(parsed_node.config.pre_hook,
|
||||
parsed_node.config.post_hook))
|
||||
hooks = list(
|
||||
itertools.chain(parsed_node.config.pre_hook, parsed_node.config.post_hook)
|
||||
)
|
||||
# skip context rebuilding if there aren't any hooks
|
||||
if not hooks:
|
||||
return
|
||||
@@ -362,20 +353,18 @@ class ConfiguredParser(
|
||||
)
|
||||
else:
|
||||
raise InternalException(
|
||||
f'Got an unexpected project version={config_version}, '
|
||||
f'expected 2'
|
||||
f"Got an unexpected project version={config_version}, " f"expected 2"
|
||||
)
|
||||
|
||||
def config_dict(
|
||||
self, config: ContextConfig,
|
||||
self,
|
||||
config: ContextConfig,
|
||||
) -> Dict[str, Any]:
|
||||
config_dict = config.build_config_dict(base=True)
|
||||
self._mangle_hooks(config_dict)
|
||||
return config_dict
|
||||
|
||||
def render_update(
|
||||
self, node: IntermediateNode, config: ContextConfig
|
||||
) -> None:
|
||||
def render_update(self, node: IntermediateNode, config: ContextConfig) -> None:
|
||||
try:
|
||||
self.render_with_context(node, config)
|
||||
self.update_parsed_node(node, config)
|
||||
@@ -418,7 +407,7 @@ class ConfiguredParser(
|
||||
|
||||
class SimpleParser(
|
||||
ConfiguredParser[ConfiguredBlockType, FinalNode, FinalNode],
|
||||
Generic[ConfiguredBlockType, FinalNode]
|
||||
Generic[ConfiguredBlockType, FinalNode],
|
||||
):
|
||||
def transform(self, node):
|
||||
return node
|
||||
@@ -426,14 +415,12 @@ class SimpleParser(
|
||||
|
||||
class SQLParser(
|
||||
ConfiguredParser[FileBlock, IntermediateNode, FinalNode],
|
||||
Generic[IntermediateNode, FinalNode]
|
||||
Generic[IntermediateNode, FinalNode],
|
||||
):
|
||||
def parse_file(self, file_block: FileBlock) -> None:
|
||||
self.parse_node(file_block)
|
||||
|
||||
|
||||
class SimpleSQLParser(
|
||||
SQLParser[FinalNode, FinalNode]
|
||||
):
|
||||
class SimpleSQLParser(SQLParser[FinalNode, FinalNode]):
|
||||
def transform(self, node):
|
||||
return node
|
||||
|
||||
@@ -7,9 +7,7 @@ from dbt.utils import get_pseudo_test_path
|
||||
|
||||
class DataTestParser(SimpleSQLParser[ParsedDataTestNode]):
|
||||
def get_paths(self):
|
||||
return FilesystemSearcher(
|
||||
self.project, self.project.test_paths, '.sql'
|
||||
)
|
||||
return FilesystemSearcher(self.project, self.project.test_paths, ".sql")
|
||||
|
||||
def parse_from_dict(self, dct, validate=True) -> ParsedDataTestNode:
|
||||
if validate:
|
||||
@@ -21,11 +19,10 @@ class DataTestParser(SimpleSQLParser[ParsedDataTestNode]):
|
||||
return NodeType.Test
|
||||
|
||||
def transform(self, node):
|
||||
if 'data' not in node.tags:
|
||||
node.tags.append('data')
|
||||
if "data" not in node.tags:
|
||||
node.tags.append("data")
|
||||
return node
|
||||
|
||||
@classmethod
|
||||
def get_compiled_path(cls, block: FileBlock):
|
||||
return get_pseudo_test_path(block.name, block.path.relative_path,
|
||||
'data_test')
|
||||
return get_pseudo_test_path(block.name, block.path.relative_path, "data_test")
|
||||
|
||||
@@ -7,11 +7,14 @@ from dbt.contracts.graph.parsed import ParsedDocumentation
|
||||
from dbt.node_types import NodeType
|
||||
from dbt.parser.base import Parser
|
||||
from dbt.parser.search import (
|
||||
BlockContents, FileBlock, FilesystemSearcher, BlockSearcher
|
||||
BlockContents,
|
||||
FileBlock,
|
||||
FilesystemSearcher,
|
||||
BlockSearcher,
|
||||
)
|
||||
|
||||
|
||||
SHOULD_PARSE_RE = re.compile(r'{[{%]')
|
||||
SHOULD_PARSE_RE = re.compile(r"{[{%]")
|
||||
|
||||
|
||||
class DocumentationParser(Parser[ParsedDocumentation]):
|
||||
@@ -19,7 +22,7 @@ class DocumentationParser(Parser[ParsedDocumentation]):
|
||||
return FilesystemSearcher(
|
||||
project=self.project,
|
||||
relative_dirs=self.project.docs_paths,
|
||||
extension='.md',
|
||||
extension=".md",
|
||||
)
|
||||
|
||||
@property
|
||||
@@ -33,11 +36,9 @@ class DocumentationParser(Parser[ParsedDocumentation]):
|
||||
def generate_unique_id(self, resource_name: str) -> str:
|
||||
# because docs are in their own graph namespace, node type doesn't
|
||||
# need to be part of the unique ID.
|
||||
return '{}.{}'.format(self.project.project_name, resource_name)
|
||||
return "{}.{}".format(self.project.project_name, resource_name)
|
||||
|
||||
def parse_block(
|
||||
self, block: BlockContents
|
||||
) -> Iterable[ParsedDocumentation]:
|
||||
def parse_block(self, block: BlockContents) -> Iterable[ParsedDocumentation]:
|
||||
unique_id = self.generate_unique_id(block.name)
|
||||
contents = get_rendered(block.contents, {}).strip()
|
||||
|
||||
@@ -55,7 +56,7 @@ class DocumentationParser(Parser[ParsedDocumentation]):
|
||||
def parse_file(self, file_block: FileBlock):
|
||||
searcher: Iterable[BlockContents] = BlockSearcher(
|
||||
source=[file_block],
|
||||
allowed_blocks={'docs'},
|
||||
allowed_blocks={"docs"},
|
||||
source_tag_factory=BlockContents,
|
||||
)
|
||||
for block in searcher:
|
||||
|
||||
@@ -24,7 +24,7 @@ class HookBlock(FileBlock):
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
return '{}-{!s}-{!s}'.format(self.project, self.hook_type, self.index)
|
||||
return "{}-{!s}-{!s}".format(self.project, self.hook_type, self.index)
|
||||
|
||||
|
||||
class HookSearcher(Iterable[HookBlock]):
|
||||
@@ -33,9 +33,7 @@ class HookSearcher(Iterable[HookBlock]):
|
||||
self.source_file = source_file
|
||||
self.hook_type = hook_type
|
||||
|
||||
def _hook_list(
|
||||
self, hooks: Union[str, List[str], Tuple[str, ...]]
|
||||
) -> List[str]:
|
||||
def _hook_list(self, hooks: Union[str, List[str], Tuple[str, ...]]) -> List[str]:
|
||||
if isinstance(hooks, tuple):
|
||||
hooks = list(hooks)
|
||||
elif not isinstance(hooks, list):
|
||||
@@ -49,8 +47,9 @@ class HookSearcher(Iterable[HookBlock]):
|
||||
hooks = self.project.on_run_end
|
||||
else:
|
||||
raise InternalException(
|
||||
'hook_type must be one of "{}" or "{}" (got {})'
|
||||
.format(RunHookType.Start, RunHookType.End, self.hook_type)
|
||||
'hook_type must be one of "{}" or "{}" (got {})'.format(
|
||||
RunHookType.Start, RunHookType.End, self.hook_type
|
||||
)
|
||||
)
|
||||
return self._hook_list(hooks)
|
||||
|
||||
@@ -73,8 +72,8 @@ class HookParser(SimpleParser[HookBlock, ParsedHookNode]):
|
||||
def get_paths(self) -> List[FilePath]:
|
||||
path = FilePath(
|
||||
project_root=self.project.project_root,
|
||||
searched_path='.',
|
||||
relative_path='dbt_project.yml',
|
||||
searched_path=".",
|
||||
relative_path="dbt_project.yml",
|
||||
)
|
||||
return [path]
|
||||
|
||||
@@ -98,9 +97,13 @@ class HookParser(SimpleParser[HookBlock, ParsedHookNode]):
|
||||
) -> ParsedHookNode:
|
||||
|
||||
return super()._create_parsetime_node(
|
||||
block=block, path=path, config=config, fqn=fqn,
|
||||
index=block.index, name=name,
|
||||
tags=[str(block.hook_type)]
|
||||
block=block,
|
||||
path=path,
|
||||
config=config,
|
||||
fqn=fqn,
|
||||
index=block.index,
|
||||
name=name,
|
||||
tags=[str(block.hook_type)],
|
||||
)
|
||||
|
||||
@property
|
||||
|
||||
@@ -18,7 +18,7 @@ class MacroParser(BaseParser[ParsedMacro]):
|
||||
return FilesystemSearcher(
|
||||
project=self.project,
|
||||
relative_dirs=self.project.macro_paths,
|
||||
extension='.sql',
|
||||
extension=".sql",
|
||||
)
|
||||
|
||||
@property
|
||||
@@ -45,15 +45,13 @@ class MacroParser(BaseParser[ParsedMacro]):
|
||||
unique_id=unique_id,
|
||||
)
|
||||
|
||||
def parse_unparsed_macros(
|
||||
self, base_node: UnparsedMacro
|
||||
) -> Iterable[ParsedMacro]:
|
||||
def parse_unparsed_macros(self, base_node: UnparsedMacro) -> Iterable[ParsedMacro]:
|
||||
try:
|
||||
blocks: List[jinja.BlockTag] = [
|
||||
t for t in
|
||||
jinja.extract_toplevel_blocks(
|
||||
t
|
||||
for t in jinja.extract_toplevel_blocks(
|
||||
base_node.raw_sql,
|
||||
allowed_blocks={'macro', 'materialization'},
|
||||
allowed_blocks={"macro", "materialization"},
|
||||
collect_raw_data=False,
|
||||
)
|
||||
if isinstance(t, jinja.BlockTag)
|
||||
@@ -75,8 +73,8 @@ class MacroParser(BaseParser[ParsedMacro]):
|
||||
# things have gone disastrously wrong, we thought we only
|
||||
# parsed one block!
|
||||
raise CompilationException(
|
||||
f'Found multiple macros in {block.full_block}, expected 1',
|
||||
node=base_node
|
||||
f"Found multiple macros in {block.full_block}, expected 1",
|
||||
node=base_node,
|
||||
)
|
||||
|
||||
macro_name = macro_nodes[0].name
|
||||
@@ -84,7 +82,7 @@ class MacroParser(BaseParser[ParsedMacro]):
|
||||
if not macro_name.startswith(MACRO_PREFIX):
|
||||
continue
|
||||
|
||||
name: str = macro_name.replace(MACRO_PREFIX, '')
|
||||
name: str = macro_name.replace(MACRO_PREFIX, "")
|
||||
node = self.parse_macro(block, base_node, name)
|
||||
yield node
|
||||
|
||||
|
||||
@@ -3,7 +3,15 @@ from dataclasses import field
|
||||
import os
|
||||
import pickle
|
||||
from typing import (
|
||||
Dict, Optional, Mapping, Callable, Any, List, Type, Union, MutableMapping
|
||||
Dict,
|
||||
Optional,
|
||||
Mapping,
|
||||
Callable,
|
||||
Any,
|
||||
List,
|
||||
Type,
|
||||
Union,
|
||||
MutableMapping,
|
||||
)
|
||||
import time
|
||||
|
||||
@@ -23,11 +31,13 @@ from dbt.config import Project, RuntimeConfig
|
||||
from dbt.context.docs import generate_runtime_docs
|
||||
from dbt.contracts.files import FilePath, FileHash
|
||||
from dbt.contracts.graph.compiled import ManifestNode
|
||||
from dbt.contracts.graph.manifest import (
|
||||
Manifest, MacroManifest, AnyManifest, Disabled
|
||||
)
|
||||
from dbt.contracts.graph.manifest import Manifest, MacroManifest, AnyManifest, Disabled
|
||||
from dbt.contracts.graph.parsed import (
|
||||
ParsedSourceDefinition, ParsedNode, ParsedMacro, ColumnInfo, ParsedExposure
|
||||
ParsedSourceDefinition,
|
||||
ParsedNode,
|
||||
ParsedMacro,
|
||||
ColumnInfo,
|
||||
ParsedExposure,
|
||||
)
|
||||
from dbt.contracts.util import Writable
|
||||
from dbt.exceptions import (
|
||||
@@ -55,8 +65,8 @@ from dbt.version import __version__
|
||||
|
||||
from dbt.dataclass_schema import dbtClassMixin
|
||||
|
||||
PARTIAL_PARSE_FILE_NAME = 'partial_parse.pickle'
|
||||
PARSING_STATE = DbtProcessState('parsing')
|
||||
PARTIAL_PARSE_FILE_NAME = "partial_parse.pickle"
|
||||
PARSING_STATE = DbtProcessState("parsing")
|
||||
DEFAULT_PARTIAL_PARSE = False
|
||||
|
||||
|
||||
@@ -110,20 +120,22 @@ def make_parse_result(
|
||||
"""Make a ParseResult from the project configuration and the profile."""
|
||||
# if any of these change, we need to reject the parser
|
||||
vars_hash = FileHash.from_contents(
|
||||
'\x00'.join([
|
||||
getattr(config.args, 'vars', '{}') or '{}',
|
||||
getattr(config.args, 'profile', '') or '',
|
||||
getattr(config.args, 'target', '') or '',
|
||||
__version__
|
||||
])
|
||||
"\x00".join(
|
||||
[
|
||||
getattr(config.args, "vars", "{}") or "{}",
|
||||
getattr(config.args, "profile", "") or "",
|
||||
getattr(config.args, "target", "") or "",
|
||||
__version__,
|
||||
]
|
||||
)
|
||||
profile_path = os.path.join(config.args.profiles_dir, 'profiles.yml')
|
||||
)
|
||||
profile_path = os.path.join(config.args.profiles_dir, "profiles.yml")
|
||||
with open(profile_path) as fp:
|
||||
profile_hash = FileHash.from_contents(fp.read())
|
||||
|
||||
project_hashes = {}
|
||||
for name, project in all_projects.items():
|
||||
path = os.path.join(project.project_root, 'dbt_project.yml')
|
||||
path = os.path.join(project.project_root, "dbt_project.yml")
|
||||
with open(path) as fp:
|
||||
project_hashes[name] = FileHash.from_contents(fp.read())
|
||||
|
||||
@@ -153,7 +165,8 @@ class ManifestLoader:
|
||||
# in dictionaries: nodes, sources, docs, macros, exposures,
|
||||
# macro_patches, patches, source_patches, files, etc
|
||||
self.results: ParseResult = make_parse_result(
|
||||
root_project, all_projects,
|
||||
root_project,
|
||||
all_projects,
|
||||
)
|
||||
self._loaded_file_cache: Dict[str, FileBlock] = {}
|
||||
self._perf_info = ManifestLoaderInfo(
|
||||
@@ -162,20 +175,18 @@ class ManifestLoader:
|
||||
|
||||
def track_project_load(self):
|
||||
invocation_id = dbt.tracking.active_user.invocation_id
|
||||
dbt.tracking.track_project_load({
|
||||
dbt.tracking.track_project_load(
|
||||
{
|
||||
"invocation_id": invocation_id,
|
||||
"project_id": self.root_project.hashed_name(),
|
||||
"path_count": self._perf_info.path_count,
|
||||
"parse_project_elapsed": self._perf_info.parse_project_elapsed,
|
||||
"patch_sources_elapsed": self._perf_info.patch_sources_elapsed,
|
||||
"process_manifest_elapsed": (
|
||||
self._perf_info.process_manifest_elapsed
|
||||
),
|
||||
"process_manifest_elapsed": (self._perf_info.process_manifest_elapsed),
|
||||
"load_all_elapsed": self._perf_info.load_all_elapsed,
|
||||
"is_partial_parse_enabled": (
|
||||
self._perf_info.is_partial_parse_enabled
|
||||
),
|
||||
})
|
||||
"is_partial_parse_enabled": (self._perf_info.is_partial_parse_enabled),
|
||||
}
|
||||
)
|
||||
|
||||
def parse_with_cache(
|
||||
self,
|
||||
@@ -220,8 +231,7 @@ class ManifestLoader:
|
||||
) -> None:
|
||||
parsers: List[Parser] = []
|
||||
for cls in _parser_types:
|
||||
parser = cls(self.results, project, self.root_project,
|
||||
macro_manifest)
|
||||
parser = cls(self.results, project, self.root_project, macro_manifest)
|
||||
parsers.append(parser)
|
||||
|
||||
# per-project cache.
|
||||
@@ -238,11 +248,13 @@ class ManifestLoader:
|
||||
parser_path_count = parser_path_count + 1
|
||||
|
||||
if parser_path_count > 0:
|
||||
project_parser_info.append(ParserInfo(
|
||||
project_parser_info.append(
|
||||
ParserInfo(
|
||||
parser=parser.resource_type,
|
||||
path_count=parser_path_count,
|
||||
elapsed=time.perf_counter() - parser_start_timer
|
||||
))
|
||||
elapsed=time.perf_counter() - parser_start_timer,
|
||||
)
|
||||
)
|
||||
total_path_count = total_path_count + parser_path_count
|
||||
|
||||
elapsed = time.perf_counter() - start_timer
|
||||
@@ -250,12 +262,10 @@ class ManifestLoader:
|
||||
project_name=project.project_name,
|
||||
path_count=total_path_count,
|
||||
elapsed=elapsed,
|
||||
parsers=project_parser_info
|
||||
parsers=project_parser_info,
|
||||
)
|
||||
self._perf_info.projects.append(project_info)
|
||||
self._perf_info.path_count = (
|
||||
self._perf_info.path_count + total_path_count
|
||||
)
|
||||
self._perf_info.path_count = self._perf_info.path_count + total_path_count
|
||||
|
||||
def load_only_macros(self) -> MacroManifest:
|
||||
old_results = self.read_parse_results()
|
||||
@@ -267,8 +277,7 @@ class ManifestLoader:
|
||||
|
||||
# make a manifest with just the macros to get the context
|
||||
macro_manifest = MacroManifest(
|
||||
macros=self.results.macros,
|
||||
files=self.results.files
|
||||
macros=self.results.macros, files=self.results.files
|
||||
)
|
||||
self.macro_hook(macro_manifest)
|
||||
return macro_manifest
|
||||
@@ -278,7 +287,7 @@ class ManifestLoader:
|
||||
# if partial parse is enabled, load old results
|
||||
old_results = self.read_parse_results()
|
||||
if old_results is not None:
|
||||
logger.debug('Got an acceptable cached parse result')
|
||||
logger.debug("Got an acceptable cached parse result")
|
||||
# store the macros & files from the adapter macro manifest
|
||||
self.results.macros.update(macro_manifest.macros)
|
||||
self.results.files.update(macro_manifest.files)
|
||||
@@ -289,15 +298,12 @@ class ManifestLoader:
|
||||
# parse a single project
|
||||
self.parse_project(project, macro_manifest, old_results)
|
||||
|
||||
self._perf_info.parse_project_elapsed = (
|
||||
time.perf_counter() - start_timer
|
||||
)
|
||||
self._perf_info.parse_project_elapsed = time.perf_counter() - start_timer
|
||||
|
||||
def write_parse_results(self):
|
||||
path = os.path.join(self.root_project.target_path,
|
||||
PARTIAL_PARSE_FILE_NAME)
|
||||
path = os.path.join(self.root_project.target_path, PARTIAL_PARSE_FILE_NAME)
|
||||
make_directory(self.root_project.target_path)
|
||||
with open(path, 'wb') as fp:
|
||||
with open(path, "wb") as fp:
|
||||
pickle.dump(self.results, fp)
|
||||
|
||||
def matching_parse_results(self, result: ParseResult) -> bool:
|
||||
@@ -307,31 +313,32 @@ class ManifestLoader:
try:
if result.dbt_version != __version__:
logger.debug(
'dbt version mismatch: {} != {}, cache invalidated'
.format(result.dbt_version, __version__)
"dbt version mismatch: {} != {}, cache invalidated".format(
result.dbt_version, __version__
)
)
return False
except AttributeError:
logger.debug('malformed result file, cache invalidated')
logger.debug("malformed result file, cache invalidated")
return False

valid = True

if self.results.vars_hash != result.vars_hash:
logger.debug('vars hash mismatch, cache invalidated')
logger.debug("vars hash mismatch, cache invalidated")
valid = False
if self.results.profile_hash != result.profile_hash:
logger.debug('profile hash mismatch, cache invalidated')
logger.debug("profile hash mismatch, cache invalidated")
valid = False

missing_keys = {
k for k in self.results.project_hashes
if k not in result.project_hashes
k for k in self.results.project_hashes if k not in result.project_hashes
}
if missing_keys:
logger.debug(
'project hash mismatch: values missing, cache invalidated: {}'
.format(missing_keys)
"project hash mismatch: values missing, cache invalidated: {}".format(
missing_keys
)
)
valid = False

@@ -340,9 +347,8 @@ class ManifestLoader:
old_value = result.project_hashes[key]
if new_value != old_value:
logger.debug(
'For key {}, hash mismatch ({} -> {}), cache '
'invalidated'
.format(key, old_value, new_value)
"For key {}, hash mismatch ({} -> {}), cache "
"invalidated".format(key, old_value, new_value)
)
valid = False
return valid
@@ -359,14 +365,13 @@ class ManifestLoader:

def read_parse_results(self) -> Optional[ParseResult]:
if not self._partial_parse_enabled():
logger.debug('Partial parsing not enabled')
logger.debug("Partial parsing not enabled")
return None
path = os.path.join(self.root_project.target_path,
PARTIAL_PARSE_FILE_NAME)
path = os.path.join(self.root_project.target_path, PARTIAL_PARSE_FILE_NAME)

if os.path.exists(path):
try:
with open(path, 'rb') as fp:
with open(path, "rb") as fp:
result: ParseResult = pickle.load(fp)
# keep this check inside the try/except in case something about
# the file has changed in weird ways, perhaps due to being a
@@ -375,9 +380,8 @@ class ManifestLoader:
return result
except Exception as exc:
logger.debug(
'Failed to load parsed file from disk at {}: {}'
.format(path, exc),
exc_info=True
"Failed to load parsed file from disk at {}: {}".format(path, exc),
exc_info=True,
)

return None
@@ -394,9 +398,7 @@ class ManifestLoader:
|
||||
# list is created
|
||||
start_patch = time.perf_counter()
|
||||
sources = patch_sources(self.results, self.root_project)
|
||||
self._perf_info.patch_sources_elapsed = (
|
||||
time.perf_counter() - start_patch
|
||||
)
|
||||
self._perf_info.patch_sources_elapsed = time.perf_counter() - start_patch
|
||||
disabled = []
|
||||
for value in self.results.disabled.values():
|
||||
disabled.extend(value)
|
||||
@@ -421,9 +423,7 @@ class ManifestLoader:
|
||||
start_process = time.perf_counter()
|
||||
self.process_manifest(manifest)
|
||||
|
||||
self._perf_info.process_manifest_elapsed = (
|
||||
time.perf_counter() - start_process
|
||||
)
|
||||
self._perf_info.process_manifest_elapsed = time.perf_counter() - start_process
|
||||
|
||||
return manifest
|
||||
|
||||
@@ -445,9 +445,7 @@ class ManifestLoader:
|
||||
_check_manifest(manifest, root_config)
|
||||
manifest.build_flat_graph()
|
||||
|
||||
loader._perf_info.load_all_elapsed = (
|
||||
time.perf_counter() - start_load_all
|
||||
)
|
||||
loader._perf_info.load_all_elapsed = time.perf_counter() - start_load_all
|
||||
|
||||
loader.track_project_load()
|
||||
|
||||
@@ -465,8 +463,9 @@ class ManifestLoader:
|
||||
return loader.load_only_macros()
|
||||
|
||||
|
||||
def invalid_ref_fail_unless_test(node, target_model_name,
|
||||
target_model_package, disabled):
|
||||
def invalid_ref_fail_unless_test(
|
||||
node, target_model_name, target_model_package, disabled
|
||||
):
|
||||
|
||||
if node.resource_type == NodeType.Test:
|
||||
msg = get_target_not_found_or_disabled_msg(
|
||||
@@ -475,10 +474,7 @@ def invalid_ref_fail_unless_test(node, target_model_name,
|
||||
if disabled:
|
||||
logger.debug(warning_tag(msg))
|
||||
else:
|
||||
warn_or_error(
|
||||
msg,
|
||||
log_fmt=warning_tag('{}')
|
||||
)
|
||||
warn_or_error(msg, log_fmt=warning_tag("{}"))
|
||||
else:
|
||||
ref_target_not_found(
|
||||
node,
|
||||
@@ -488,9 +484,7 @@ def invalid_ref_fail_unless_test(node, target_model_name,
|
||||
)
|
||||
|
||||
|
||||
def invalid_source_fail_unless_test(
|
||||
node, target_name, target_table_name, disabled
|
||||
):
|
||||
def invalid_source_fail_unless_test(node, target_name, target_table_name, disabled):
|
||||
if node.resource_type == NodeType.Test:
|
||||
msg = get_source_not_found_or_disabled_msg(
|
||||
node, target_name, target_table_name, disabled
|
||||
@@ -498,17 +492,9 @@ def invalid_source_fail_unless_test(
|
||||
if disabled:
|
||||
logger.debug(warning_tag(msg))
|
||||
else:
|
||||
warn_or_error(
|
||||
msg,
|
||||
log_fmt=warning_tag('{}')
|
||||
)
|
||||
warn_or_error(msg, log_fmt=warning_tag("{}"))
|
||||
else:
|
||||
source_target_not_found(
|
||||
node,
|
||||
target_name,
|
||||
target_table_name,
|
||||
disabled=disabled
|
||||
)
|
||||
source_target_not_found(node, target_name, target_table_name, disabled=disabled)
|
||||
|
||||
|
||||
def _check_resource_uniqueness(
|
||||
@@ -532,15 +518,11 @@ def _check_resource_uniqueness(
|
||||
|
||||
existing_node = names_resources.get(name)
|
||||
if existing_node is not None:
|
||||
dbt.exceptions.raise_duplicate_resource_name(
|
||||
existing_node, node
|
||||
)
|
||||
dbt.exceptions.raise_duplicate_resource_name(existing_node, node)
|
||||
|
||||
existing_alias = alias_resources.get(full_node_name)
|
||||
if existing_alias is not None:
|
||||
dbt.exceptions.raise_ambiguous_alias(
|
||||
existing_alias, node, full_node_name
|
||||
)
|
||||
dbt.exceptions.raise_ambiguous_alias(existing_alias, node, full_node_name)
|
||||
|
||||
names_resources[name] = node
|
||||
alias_resources[full_node_name] = node
|
||||
@@ -565,8 +547,7 @@ def _load_projects(config, paths):
|
||||
project = config.new_project(path)
|
||||
except dbt.exceptions.DbtProjectError as e:
|
||||
raise dbt.exceptions.DbtProjectError(
|
||||
'Failed to read package at {}: {}'
|
||||
.format(path, e)
|
||||
"Failed to read package at {}: {}".format(path, e)
|
||||
)
|
||||
else:
|
||||
yield project.project_name, project
|
||||
@@ -587,8 +568,7 @@ def _get_node_column(node, column_name):
|
||||
|
||||
|
||||
DocsContextCallback = Callable[
|
||||
[Union[ParsedNode, ParsedSourceDefinition]],
|
||||
Dict[str, Any]
|
||||
[Union[ParsedNode, ParsedSourceDefinition]], Dict[str, Any]
|
||||
]
|
||||
|
||||
|
||||
@@ -618,9 +598,7 @@ def _process_docs_for_source(
|
||||
column.description = column_desc
|
||||
|
||||
|
||||
def _process_docs_for_macro(
|
||||
context: Dict[str, Any], macro: ParsedMacro
|
||||
) -> None:
|
||||
def _process_docs_for_macro(context: Dict[str, Any], macro: ParsedMacro) -> None:
|
||||
macro.description = get_rendered(macro.description, context)
|
||||
for arg in macro.arguments:
|
||||
arg.description = get_rendered(arg.description, context)
|
||||
@@ -682,7 +660,7 @@ def _process_refs_for_exposure(
|
||||
target_model_package, target_model_name = ref
|
||||
else:
|
||||
raise dbt.exceptions.InternalException(
|
||||
f'Refs should always be 1 or 2 arguments - got {len(ref)}'
|
||||
f"Refs should always be 1 or 2 arguments - got {len(ref)}"
|
||||
)
|
||||
|
||||
target_model = manifest.resolve_ref(
|
||||
@@ -696,8 +674,10 @@ def _process_refs_for_exposure(
|
||||
# This may raise. Even if it doesn't, we don't want to add
|
||||
# this exposure to the graph b/c there is no destination exposure
|
||||
invalid_ref_fail_unless_test(
|
||||
exposure, target_model_name, target_model_package,
|
||||
disabled=(isinstance(target_model, Disabled))
|
||||
exposure,
|
||||
target_model_name,
|
||||
target_model_package,
|
||||
disabled=(isinstance(target_model, Disabled)),
|
||||
)
|
||||
|
||||
continue
|
||||
@@ -723,7 +703,7 @@ def _process_refs_for_node(
|
||||
target_model_package, target_model_name = ref
|
||||
else:
|
||||
raise dbt.exceptions.InternalException(
|
||||
f'Refs should always be 1 or 2 arguments - got {len(ref)}'
|
||||
f"Refs should always be 1 or 2 arguments - got {len(ref)}"
|
||||
)
|
||||
|
||||
target_model = manifest.resolve_ref(
|
||||
@@ -738,8 +718,10 @@ def _process_refs_for_node(
|
||||
# this node to the graph b/c there is no destination node
|
||||
node.config.enabled = False
|
||||
invalid_ref_fail_unless_test(
|
||||
node, target_model_name, target_model_package,
|
||||
disabled=(isinstance(target_model, Disabled))
|
||||
node,
|
||||
target_model_name,
|
||||
target_model_package,
|
||||
disabled=(isinstance(target_model, Disabled)),
|
||||
)
|
||||
|
||||
continue
|
||||
@@ -777,7 +759,7 @@ def _process_sources_for_exposure(
|
||||
exposure,
|
||||
source_name,
|
||||
table_name,
|
||||
disabled=(isinstance(target_source, Disabled))
|
||||
disabled=(isinstance(target_source, Disabled)),
|
||||
)
|
||||
continue
|
||||
target_source_id = target_source.unique_id
|
||||
@@ -804,7 +786,7 @@ def _process_sources_for_node(
|
||||
node,
|
||||
source_name,
|
||||
table_name,
|
||||
disabled=(isinstance(target_source, Disabled))
|
||||
disabled=(isinstance(target_source, Disabled)),
|
||||
)
|
||||
continue
|
||||
target_source_id = target_source.unique_id
|
||||
@@ -835,13 +817,9 @@ def process_macro(
|
||||
_process_docs_for_macro(ctx, macro)
|
||||
|
||||
|
||||
def process_node(
|
||||
config: RuntimeConfig, manifest: Manifest, node: ManifestNode
|
||||
):
|
||||
def process_node(config: RuntimeConfig, manifest: Manifest, node: ManifestNode):
|
||||
|
||||
_process_sources_for_node(
|
||||
manifest, config.project_name, node
|
||||
)
|
||||
_process_sources_for_node(manifest, config.project_name, node)
|
||||
_process_refs_for_node(manifest, config.project_name, node)
|
||||
ctx = generate_runtime_docs(config, node, manifest, config.project_name)
|
||||
_process_docs_for_node(ctx, node)
|
||||
|
||||
@@ -6,9 +6,7 @@ from dbt.parser.search import FilesystemSearcher, FileBlock

class ModelParser(SimpleSQLParser[ParsedModelNode]):
def get_paths(self):
return FilesystemSearcher(
self.project, self.project.source_paths, '.sql'
)
return FilesystemSearcher(self.project, self.project.source_paths, ".sql")

def parse_from_dict(self, dct, validate=True) -> ParsedModelNode:
if validate:

@@ -25,9 +25,13 @@ from dbt.contracts.graph.parsed import (
|
||||
from dbt.contracts.graph.unparsed import SourcePatch
|
||||
from dbt.contracts.util import Writable, Replaceable, MacroKey, SourceKey
|
||||
from dbt.exceptions import (
|
||||
raise_duplicate_resource_name, raise_duplicate_patch_name,
|
||||
raise_duplicate_macro_patch_name, CompilationException, InternalException,
|
||||
raise_compiler_error, raise_duplicate_source_patch_name
|
||||
raise_duplicate_resource_name,
|
||||
raise_duplicate_patch_name,
|
||||
raise_duplicate_macro_patch_name,
|
||||
CompilationException,
|
||||
InternalException,
|
||||
raise_compiler_error,
|
||||
raise_duplicate_source_patch_name,
|
||||
)
|
||||
from dbt.node_types import NodeType
|
||||
from dbt.ui import line_wrap_message
|
||||
@@ -35,12 +39,10 @@ from dbt.version import __version__
|
||||
|
||||
|
||||
# Parsers can return anything as long as it's a unique ID
|
||||
ParsedValueType = TypeVar('ParsedValueType', bound=HasUniqueID)
|
||||
ParsedValueType = TypeVar("ParsedValueType", bound=HasUniqueID)
|
||||
|
||||
|
||||
def _check_duplicates(
|
||||
value: HasUniqueID, src: Mapping[str, HasUniqueID]
|
||||
):
|
||||
def _check_duplicates(value: HasUniqueID, src: Mapping[str, HasUniqueID]):
|
||||
if value.unique_id in src:
|
||||
raise_duplicate_resource_name(value, src[value.unique_id])
|
||||
|
||||
@@ -86,9 +88,7 @@ class ParseResult(dbtClassMixin, Writable, Replaceable):
|
||||
self.files[key] = source_file
|
||||
return self.files[key]
|
||||
|
||||
def add_source(
|
||||
self, source_file: SourceFile, source: UnpatchedSourceDefinition
|
||||
):
|
||||
def add_source(self, source_file: SourceFile, source: UnpatchedSourceDefinition):
|
||||
# sources can't be overwritten!
|
||||
_check_duplicates(source, self.sources)
|
||||
self.sources[source.unique_id] = source
|
||||
@@ -126,7 +126,7 @@ class ParseResult(dbtClassMixin, Writable, Replaceable):
|
||||
# note that the line wrap eats newlines, so if you want newlines,
|
||||
# this is the result :(
|
||||
msg = line_wrap_message(
|
||||
f'''\
|
||||
f"""\
|
||||
dbt found two macros named "{macro.name}" in the project
|
||||
"{macro.package_name}".
|
||||
|
||||
@@ -137,8 +137,8 @@ class ParseResult(dbtClassMixin, Writable, Replaceable):
|
||||
- {macro.original_file_path}
|
||||
|
||||
- {other_path}
|
||||
''',
|
||||
subtract=2
|
||||
""",
|
||||
subtract=2,
|
||||
)
|
||||
raise_compiler_error(msg)
|
||||
|
||||
@@ -150,18 +150,14 @@ class ParseResult(dbtClassMixin, Writable, Replaceable):
|
||||
self.docs[doc.unique_id] = doc
|
||||
self.get_file(source_file).docs.append(doc.unique_id)
|
||||
|
||||
def add_patch(
|
||||
self, source_file: SourceFile, patch: ParsedNodePatch
|
||||
) -> None:
|
||||
def add_patch(self, source_file: SourceFile, patch: ParsedNodePatch) -> None:
|
||||
# patches can't be overwritten
|
||||
if patch.name in self.patches:
|
||||
raise_duplicate_patch_name(patch, self.patches[patch.name])
|
||||
self.patches[patch.name] = patch
|
||||
self.get_file(source_file).patches.append(patch.name)
|
||||
|
||||
def add_macro_patch(
|
||||
self, source_file: SourceFile, patch: ParsedMacroPatch
|
||||
) -> None:
|
||||
def add_macro_patch(self, source_file: SourceFile, patch: ParsedMacroPatch) -> None:
|
||||
# macros are fully namespaced
|
||||
key = (patch.package_name, patch.name)
|
||||
if key in self.macro_patches:
|
||||
@@ -169,9 +165,7 @@ class ParseResult(dbtClassMixin, Writable, Replaceable):
|
||||
self.macro_patches[key] = patch
|
||||
self.get_file(source_file).macro_patches.append(key)
|
||||
|
||||
def add_source_patch(
|
||||
self, source_file: SourceFile, patch: SourcePatch
|
||||
) -> None:
|
||||
def add_source_patch(self, source_file: SourceFile, patch: SourcePatch) -> None:
|
||||
# source patches must be unique
|
||||
key = (patch.overrides, patch.name)
|
||||
if key in self.source_patches:
|
||||
@@ -186,11 +180,13 @@ class ParseResult(dbtClassMixin, Writable, Replaceable):
|
||||
) -> List[CompileResultNode]:
|
||||
if unique_id not in self.disabled:
|
||||
raise InternalException(
|
||||
'called _get_disabled with id={}, but it does not exist'
|
||||
.format(unique_id)
|
||||
"called _get_disabled with id={}, but it does not exist".format(
|
||||
unique_id
|
||||
)
|
||||
)
|
||||
return [
|
||||
n for n in self.disabled[unique_id]
|
||||
n
|
||||
for n in self.disabled[unique_id]
|
||||
if n.original_file_path == match_file.path.original_file_path
|
||||
]
|
||||
|
||||
@@ -199,7 +195,7 @@ class ParseResult(dbtClassMixin, Writable, Replaceable):
|
||||
node_id: str,
|
||||
source_file: SourceFile,
|
||||
old_file: SourceFile,
|
||||
old_result: 'ParseResult',
|
||||
old_result: "ParseResult",
|
||||
) -> None:
|
||||
"""Nodes are a special kind of complicated - there can be multiple
|
||||
with the same name, as long as all but one are disabled.
|
||||
@@ -224,14 +220,15 @@ class ParseResult(dbtClassMixin, Writable, Replaceable):
|
||||
if not found:
|
||||
raise CompilationException(
|
||||
'Expected to find "{}" in cached "manifest.nodes" or '
|
||||
'"manifest.disabled" based on cached file information: {}!'
|
||||
.format(node_id, old_file)
|
||||
'"manifest.disabled" based on cached file information: {}!'.format(
|
||||
node_id, old_file
|
||||
)
|
||||
)
|
||||
|
||||
def sanitized_update(
|
||||
self,
|
||||
source_file: SourceFile,
|
||||
old_result: 'ParseResult',
|
||||
old_result: "ParseResult",
|
||||
resource_type: NodeType,
|
||||
) -> bool:
|
||||
"""Perform a santized update. If the file can't be updated, invalidate
|
||||
@@ -246,15 +243,11 @@ class ParseResult(dbtClassMixin, Writable, Replaceable):
|
||||
self.add_doc(source_file, doc)
|
||||
|
||||
for macro_id in old_file.macros:
|
||||
macro = _expect_value(
|
||||
macro_id, old_result.macros, old_file, "macros"
|
||||
)
|
||||
macro = _expect_value(macro_id, old_result.macros, old_file, "macros")
|
||||
self.add_macro(source_file, macro)
|
||||
|
||||
for source_id in old_file.sources:
|
||||
source = _expect_value(
|
||||
source_id, old_result.sources, old_file, "sources"
|
||||
)
|
||||
source = _expect_value(source_id, old_result.sources, old_file, "sources")
|
||||
self.add_source(source_file, source)
|
||||
|
||||
# because we know this is how we _parsed_ the node, we can safely
|
||||
@@ -265,7 +258,7 @@ class ParseResult(dbtClassMixin, Writable, Replaceable):
|
||||
for node_id in old_file.nodes:
|
||||
# cheat: look at the first part of the node ID and compare it to
|
||||
# the parser resource type. On a mismatch, bail out.
|
||||
if resource_type != node_id.split('.')[0]:
|
||||
if resource_type != node_id.split(".")[0]:
|
||||
continue
|
||||
self._process_node(node_id, source_file, old_file, old_result)
|
||||
|
||||
@@ -277,9 +270,7 @@ class ParseResult(dbtClassMixin, Writable, Replaceable):
|
||||
|
||||
patched = False
|
||||
for name in old_file.patches:
|
||||
patch = _expect_value(
|
||||
name, old_result.patches, old_file, "patches"
|
||||
)
|
||||
patch = _expect_value(name, old_result.patches, old_file, "patches")
|
||||
self.add_patch(source_file, patch)
|
||||
patched = True
|
||||
if patched:
|
||||
@@ -312,8 +303,8 @@ class ParseResult(dbtClassMixin, Writable, Replaceable):
|
||||
return cls(FileHash.empty(), FileHash.empty(), {})
|
||||
|
||||
|
||||
K_T = TypeVar('K_T')
|
||||
V_T = TypeVar('V_T')
|
||||
K_T = TypeVar("K_T")
|
||||
V_T = TypeVar("V_T")
|
||||
|
||||
|
||||
def _expect_value(
|
||||
@@ -322,7 +313,6 @@ def _expect_value(
|
||||
if key not in src:
|
||||
raise CompilationException(
|
||||
'Expected to find "{}" in cached "result.{}" based '
|
||||
'on cached file information: {}!'
|
||||
.format(key, name, old_file)
|
||||
"on cached file information: {}!".format(key, name, old_file)
|
||||
)
|
||||
return src[key]
|
||||
|
||||
@@ -38,11 +38,11 @@ class RPCCallParser(SimpleSQLParser[ParsedRPCNode]):
|
||||
# we do it this way to make mypy happy
|
||||
if not isinstance(block, RPCBlock):
|
||||
raise InternalException(
|
||||
'While parsing RPC calls, got an actual file block instead of '
|
||||
'an RPC block: {}'.format(block)
|
||||
"While parsing RPC calls, got an actual file block instead of "
|
||||
"an RPC block: {}".format(block)
|
||||
)
|
||||
|
||||
return os.path.join('rpc', block.name)
|
||||
return os.path.join("rpc", block.name)
|
||||
|
||||
def parse_remote(self, sql: str, name: str) -> ParsedRPCNode:
|
||||
source_file = SourceFile.remote(contents=sql)
|
||||
@@ -53,8 +53,8 @@ class RPCCallParser(SimpleSQLParser[ParsedRPCNode]):
|
||||
class RPCMacroParser(MacroParser):
|
||||
def parse_remote(self, contents) -> Iterable[ParsedMacro]:
|
||||
base = UnparsedMacro(
|
||||
path='from remote system',
|
||||
original_file_path='from remote system',
|
||||
path="from remote system",
|
||||
original_file_path="from remote system",
|
||||
package_name=self.project.project_name,
|
||||
raw_sql=contents,
|
||||
root_path=self.project.project_root,
|
||||
|
||||
@@ -3,7 +3,13 @@ import re
|
||||
from copy import deepcopy
|
||||
from dataclasses import dataclass
|
||||
from typing import (
|
||||
Generic, TypeVar, Dict, Any, Tuple, Optional, List,
|
||||
Generic,
|
||||
TypeVar,
|
||||
Dict,
|
||||
Any,
|
||||
Tuple,
|
||||
Optional,
|
||||
List,
|
||||
)
|
||||
|
||||
from dbt.clients.jinja import get_rendered, SCHEMA_TEST_KWARGS_NAME
|
||||
@@ -25,7 +31,7 @@ def get_nice_schema_test_name(
|
||||
flat_args = []
|
||||
for arg_name in sorted(args):
|
||||
# the model is already embedded in the name, so skip it
|
||||
if arg_name == 'model':
|
||||
if arg_name == "model":
|
||||
continue
|
||||
arg_val = args[arg_name]
|
||||
|
||||
@@ -38,17 +44,17 @@ def get_nice_schema_test_name(
|
||||
|
||||
flat_args.extend([str(part) for part in parts])
|
||||
|
||||
clean_flat_args = [re.sub('[^0-9a-zA-Z_]+', '_', arg) for arg in flat_args]
|
||||
clean_flat_args = [re.sub("[^0-9a-zA-Z_]+", "_", arg) for arg in flat_args]
|
||||
unique = "__".join(clean_flat_args)
|
||||
|
||||
cutoff = 32
|
||||
if len(unique) <= cutoff:
|
||||
label = unique
|
||||
else:
|
||||
label = hashlib.md5(unique.encode('utf-8')).hexdigest()
|
||||
label = hashlib.md5(unique.encode("utf-8")).hexdigest()
|
||||
|
||||
filename = '{}_{}_{}'.format(test_type, test_name, label)
|
||||
name = '{}_{}_{}'.format(test_type, test_name, unique)
|
||||
filename = "{}_{}_{}".format(test_type, test_name, label)
|
||||
name = "{}_{}_{}".format(test_type, test_name, unique)
|
||||
|
||||
return filename, name
|
||||
|
||||
@@ -65,19 +71,17 @@ class YamlBlock(FileBlock):
|
||||
)
|
||||
|
||||
|
||||
Testable = TypeVar(
|
||||
'Testable', UnparsedNodeUpdate, UnpatchedSourceDefinition
|
||||
)
|
||||
Testable = TypeVar("Testable", UnparsedNodeUpdate, UnpatchedSourceDefinition)
|
||||
|
||||
ColumnTarget = TypeVar(
|
||||
'ColumnTarget',
|
||||
"ColumnTarget",
|
||||
UnparsedNodeUpdate,
|
||||
UnparsedAnalysisUpdate,
|
||||
UnpatchedSourceDefinition,
|
||||
)
|
||||
|
||||
Target = TypeVar(
|
||||
'Target',
|
||||
"Target",
|
||||
UnparsedNodeUpdate,
|
||||
UnparsedMacroUpdate,
|
||||
UnparsedAnalysisUpdate,
|
||||
@@ -103,9 +107,7 @@ class TargetBlock(YamlBlock, Generic[Target]):
|
||||
return []
|
||||
|
||||
@classmethod
|
||||
def from_yaml_block(
|
||||
cls, src: YamlBlock, target: Target
|
||||
) -> 'TargetBlock[Target]':
|
||||
def from_yaml_block(cls, src: YamlBlock, target: Target) -> "TargetBlock[Target]":
|
||||
return cls(
|
||||
file=src.file,
|
||||
data=src.data,
|
||||
@@ -137,9 +139,7 @@ class TestBlock(TargetColumnsBlock[Testable], Generic[Testable]):
|
||||
return self.target.quote_columns
|
||||
|
||||
@classmethod
|
||||
def from_yaml_block(
|
||||
cls, src: YamlBlock, target: Testable
|
||||
) -> 'TestBlock[Testable]':
|
||||
def from_yaml_block(cls, src: YamlBlock, target: Testable) -> "TestBlock[Testable]":
|
||||
return cls(
|
||||
file=src.file,
|
||||
data=src.data,
|
||||
@@ -160,7 +160,7 @@ class SchemaTestBlock(TestBlock[Testable], Generic[Testable]):
|
||||
test: Dict[str, Any],
|
||||
column_name: Optional[str],
|
||||
tags: List[str],
|
||||
) -> 'SchemaTestBlock':
|
||||
) -> "SchemaTestBlock":
|
||||
return cls(
|
||||
file=src.file,
|
||||
data=src.data,
|
||||
@@ -179,13 +179,14 @@ class TestBuilder(Generic[Testable]):
|
||||
- or it may not be namespaced (test)
|
||||
|
||||
"""
|
||||
|
||||
# The 'test_name' is used to find the 'macro' that implements the test
|
||||
TEST_NAME_PATTERN = re.compile(
|
||||
r'((?P<test_namespace>([a-zA-Z_][0-9a-zA-Z_]*))\.)?'
|
||||
r'(?P<test_name>([a-zA-Z_][0-9a-zA-Z_]*))'
|
||||
r"((?P<test_namespace>([a-zA-Z_][0-9a-zA-Z_]*))\.)?"
|
||||
r"(?P<test_name>([a-zA-Z_][0-9a-zA-Z_]*))"
|
||||
)
|
||||
# map magic keys to default values
|
||||
MODIFIER_ARGS = {'severity': 'ERROR', 'tags': []}
|
||||
MODIFIER_ARGS = {"severity": "ERROR", "tags": []}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -197,25 +198,24 @@ class TestBuilder(Generic[Testable]):
|
||||
) -> None:
|
||||
test_name, test_args = self.extract_test_args(test, column_name)
|
||||
self.args: Dict[str, Any] = test_args
|
||||
if 'model' in self.args:
|
||||
if "model" in self.args:
|
||||
raise_compiler_error(
|
||||
'Test arguments include "model", which is a reserved argument',
|
||||
)
|
||||
self.package_name: str = package_name
|
||||
self.target: Testable = target
|
||||
|
||||
self.args['model'] = self.build_model_str()
|
||||
self.args["model"] = self.build_model_str()
|
||||
|
||||
match = self.TEST_NAME_PATTERN.match(test_name)
|
||||
if match is None:
|
||||
raise_compiler_error(
|
||||
'Test name string did not match expected pattern: {}'
|
||||
.format(test_name)
|
||||
"Test name string did not match expected pattern: {}".format(test_name)
|
||||
)
|
||||
|
||||
groups = match.groupdict()
|
||||
self.name: str = groups['test_name']
|
||||
self.namespace: str = groups['test_namespace']
|
||||
self.name: str = groups["test_name"]
|
||||
self.namespace: str = groups["test_namespace"]
|
||||
self.modifiers: Dict[str, Any] = {}
|
||||
for key, default in self.MODIFIER_ARGS.items():
|
||||
value = self.args.pop(key, default)
|
||||
@@ -237,57 +237,52 @@ class TestBuilder(Generic[Testable]):
|
||||
def extract_test_args(test, name=None) -> Tuple[str, Dict[str, Any]]:
|
||||
if not isinstance(test, dict):
|
||||
raise_compiler_error(
|
||||
'test must be dict or str, got {} (value {})'.format(
|
||||
type(test), test
|
||||
)
|
||||
"test must be dict or str, got {} (value {})".format(type(test), test)
|
||||
)
|
||||
|
||||
test = list(test.items())
|
||||
if len(test) != 1:
|
||||
raise_compiler_error(
|
||||
'test definition dictionary must have exactly one key, got'
|
||||
' {} instead ({} keys)'.format(test, len(test))
|
||||
"test definition dictionary must have exactly one key, got"
|
||||
" {} instead ({} keys)".format(test, len(test))
|
||||
)
|
||||
test_name, test_args = test[0]
|
||||
|
||||
if not isinstance(test_args, dict):
|
||||
raise_compiler_error(
|
||||
'test arguments must be dict, got {} (value {})'.format(
|
||||
"test arguments must be dict, got {} (value {})".format(
|
||||
type(test_args), test_args
|
||||
)
|
||||
)
|
||||
if not isinstance(test_name, str):
|
||||
raise_compiler_error(
|
||||
'test name must be a str, got {} (value {})'.format(
|
||||
"test name must be a str, got {} (value {})".format(
|
||||
type(test_name), test_name
|
||||
)
|
||||
)
|
||||
test_args = deepcopy(test_args)
|
||||
if name is not None:
|
||||
test_args['column_name'] = name
|
||||
test_args["column_name"] = name
|
||||
return test_name, test_args
|
||||
|
||||
def severity(self) -> str:
|
||||
return self.modifiers.get('severity', 'ERROR').upper()
|
||||
return self.modifiers.get("severity", "ERROR").upper()
|
||||
|
||||
def tags(self) -> List[str]:
|
||||
tags = self.modifiers.get('tags', [])
|
||||
tags = self.modifiers.get("tags", [])
|
||||
if isinstance(tags, str):
|
||||
tags = [tags]
|
||||
if not isinstance(tags, list):
|
||||
raise_compiler_error(
|
||||
f'got {tags} ({type(tags)}) for tags, expected a list of '
|
||||
f'strings'
|
||||
f"got {tags} ({type(tags)}) for tags, expected a list of " f"strings"
|
||||
)
|
||||
for tag in tags:
|
||||
if not isinstance(tag, str):
|
||||
raise_compiler_error(
|
||||
f'got {tag} ({type(tag)}) for tag, expected a str'
|
||||
)
|
||||
raise_compiler_error(f"got {tag} ({type(tag)}) for tag, expected a str")
|
||||
return tags[:]
|
||||
|
||||
def macro_name(self) -> str:
|
||||
macro_name = 'test_{}'.format(self.name)
|
||||
macro_name = "test_{}".format(self.name)
|
||||
if self.namespace is not None:
|
||||
macro_name = "{}.{}".format(self.namespace, macro_name)
|
||||
return macro_name
|
||||
@@ -296,11 +291,11 @@ class TestBuilder(Generic[Testable]):
|
||||
if isinstance(self.target, UnparsedNodeUpdate):
|
||||
name = self.name
|
||||
elif isinstance(self.target, UnpatchedSourceDefinition):
|
||||
name = 'source_' + self.name
|
||||
name = "source_" + self.name
|
||||
else:
|
||||
raise self._bad_type()
|
||||
if self.namespace is not None:
|
||||
name = '{}_{}'.format(self.namespace, name)
|
||||
name = "{}_{}".format(self.namespace, name)
|
||||
return get_nice_schema_test_name(name, self.target.name, self.args)
|
||||
|
||||
# this is the 'raw_sql' that's used in 'render_update' and execution
|
||||
|
||||
@@ -2,9 +2,7 @@ import itertools
|
||||
import os
|
||||
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from typing import (
|
||||
Iterable, Dict, Any, Union, List, Optional, Generic, TypeVar, Type
|
||||
)
|
||||
from typing import Iterable, Dict, Any, Union, List, Optional, Generic, TypeVar, Type
|
||||
|
||||
from dbt.dataclass_schema import ValidationError, dbtClassMixin
|
||||
|
||||
@@ -20,9 +18,7 @@ from dbt.context.context_config import (
|
||||
)
|
||||
from dbt.context.configured import generate_schema_yml
|
||||
from dbt.context.target import generate_target_context
|
||||
from dbt.context.providers import (
|
||||
generate_parse_exposure, generate_test_context
|
||||
)
|
||||
from dbt.context.providers import generate_parse_exposure, generate_test_context
|
||||
from dbt.context.macro_resolver import MacroResolver
|
||||
from dbt.contracts.files import FileHash
|
||||
from dbt.contracts.graph.manifest import SourceFile
|
||||
@@ -50,20 +46,26 @@ from dbt.contracts.graph.unparsed import (
|
||||
UnparsedSourceDefinition,
|
||||
)
|
||||
from dbt.exceptions import (
|
||||
validator_error_message, JSONValidationException,
|
||||
raise_invalid_schema_yml_version, ValidationException,
|
||||
CompilationException, warn_or_error, InternalException
|
||||
validator_error_message,
|
||||
JSONValidationException,
|
||||
raise_invalid_schema_yml_version,
|
||||
ValidationException,
|
||||
CompilationException,
|
||||
warn_or_error,
|
||||
InternalException,
|
||||
)
|
||||
from dbt.node_types import NodeType
|
||||
from dbt.parser.base import SimpleParser
|
||||
from dbt.parser.search import FileBlock, FilesystemSearcher
|
||||
from dbt.parser.schema_test_builders import (
|
||||
TestBuilder, SchemaTestBlock, TargetBlock, YamlBlock,
|
||||
TestBlock, Testable
|
||||
)
|
||||
from dbt.utils import (
|
||||
get_pseudo_test_path, coerce_dict_str
|
||||
TestBuilder,
|
||||
SchemaTestBlock,
|
||||
TargetBlock,
|
||||
YamlBlock,
|
||||
TestBlock,
|
||||
Testable,
|
||||
)
|
||||
from dbt.utils import get_pseudo_test_path, coerce_dict_str
|
||||
|
||||
|
||||
UnparsedSchemaYaml = Union[
|
||||
@@ -80,19 +82,17 @@ def error_context(
|
||||
path: str,
|
||||
key: str,
|
||||
data: Any,
|
||||
cause: Union[str, ValidationException, JSONValidationException]
|
||||
cause: Union[str, ValidationException, JSONValidationException],
|
||||
) -> str:
|
||||
"""Provide contextual information about an error while parsing
|
||||
"""
|
||||
"""Provide contextual information about an error while parsing"""
|
||||
if isinstance(cause, str):
|
||||
reason = cause
|
||||
elif isinstance(cause, ValidationError):
|
||||
reason = validator_error_message(cause)
|
||||
else:
|
||||
reason = cause.msg
|
||||
return (
|
||||
'Invalid {key} config given in {path} @ {key}: {data} - {reason}'
|
||||
.format(key=key, path=path, data=data, reason=reason)
|
||||
return "Invalid {key} config given in {path} @ {key}: {data} - {reason}".format(
|
||||
key=key, path=path, data=data, reason=reason
|
||||
)
|
||||
|
||||
|
||||
@@ -110,7 +110,7 @@ class ParserRef:
|
||||
meta: Dict[str, Any],
|
||||
):
|
||||
tags: List[str] = []
|
||||
tags.extend(getattr(column, 'tags', ()))
|
||||
tags.extend(getattr(column, "tags", ()))
|
||||
quote: Optional[bool]
|
||||
if isinstance(column, UnparsedColumn):
|
||||
quote = column.quote
|
||||
@@ -123,13 +123,11 @@ class ParserRef:
|
||||
meta=meta,
|
||||
tags=tags,
|
||||
quote=quote,
|
||||
_extra=column.extra
|
||||
_extra=column.extra,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_target(
|
||||
cls, target: Union[HasColumnDocs, HasColumnTests]
|
||||
) -> 'ParserRef':
|
||||
def from_target(cls, target: Union[HasColumnDocs, HasColumnTests]) -> "ParserRef":
|
||||
refs = cls()
|
||||
for column in target.columns:
|
||||
description = column.description
|
||||
@@ -142,7 +140,7 @@ class ParserRef:
|
||||
def _trimmed(inp: str) -> str:
|
||||
if len(inp) < 50:
|
||||
return inp
|
||||
return inp[:44] + '...' + inp[-3:]
|
||||
return inp[:44] + "..." + inp[-3:]
|
||||
|
||||
|
||||
def merge_freshness(
|
||||
@@ -158,21 +156,20 @@ def merge_freshness(
|
||||
|
||||
class SchemaParser(SimpleParser[SchemaTestBlock, ParsedSchemaTestNode]):
|
||||
def __init__(
|
||||
self, results, project, root_project, macro_manifest,
|
||||
self,
|
||||
results,
|
||||
project,
|
||||
root_project,
|
||||
macro_manifest,
|
||||
) -> None:
|
||||
super().__init__(results, project, root_project, macro_manifest)
|
||||
all_v_2 = (
|
||||
self.root_project.config_version == 2 and
|
||||
self.project.config_version == 2
|
||||
self.root_project.config_version == 2 and self.project.config_version == 2
|
||||
)
|
||||
if all_v_2:
|
||||
ctx = generate_schema_yml(
|
||||
self.root_project, self.project.project_name
|
||||
)
|
||||
ctx = generate_schema_yml(self.root_project, self.project.project_name)
|
||||
else:
|
||||
ctx = generate_target_context(
|
||||
self.root_project, self.root_project.cli_vars
|
||||
)
|
||||
ctx = generate_target_context(self.root_project, self.root_project.cli_vars)
|
||||
|
||||
self.raw_renderer = SchemaYamlRenderer(ctx)
|
||||
|
||||
@@ -182,7 +179,7 @@ class SchemaParser(SimpleParser[SchemaTestBlock, ParsedSchemaTestNode]):
|
||||
self.macro_resolver = MacroResolver(
|
||||
self.macro_manifest.macros,
|
||||
self.root_project.project_name,
|
||||
internal_package_names
|
||||
internal_package_names,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -197,65 +194,55 @@ class SchemaParser(SimpleParser[SchemaTestBlock, ParsedSchemaTestNode]):
|
||||
def get_paths(self):
|
||||
# TODO: In order to support this, make FilesystemSearcher accept a list
|
||||
# of file patterns. eg: ['.yml', '.yaml']
|
||||
yaml_files = list(FilesystemSearcher(
|
||||
self.project, self.project.all_source_paths, '.yaml'
|
||||
))
|
||||
yaml_files = list(
|
||||
FilesystemSearcher(self.project, self.project.all_source_paths, ".yaml")
|
||||
)
|
||||
if yaml_files:
|
||||
warn_or_error(
|
||||
'A future version of dbt will parse files with both'
|
||||
' .yml and .yaml file extensions. dbt found'
|
||||
f' {len(yaml_files)} files with .yaml extensions in'
|
||||
' your dbt project. To avoid errors when upgrading'
|
||||
' to a future release, either remove these files from'
|
||||
' your dbt project, or change their extensions.'
|
||||
)
|
||||
return FilesystemSearcher(
|
||||
self.project, self.project.all_source_paths, '.yml'
|
||||
"A future version of dbt will parse files with both"
|
||||
" .yml and .yaml file extensions. dbt found"
|
||||
f" {len(yaml_files)} files with .yaml extensions in"
|
||||
" your dbt project. To avoid errors when upgrading"
|
||||
" to a future release, either remove these files from"
|
||||
" your dbt project, or change their extensions."
|
||||
)
|
||||
return FilesystemSearcher(self.project, self.project.all_source_paths, ".yml")
|
||||
|
||||
def parse_from_dict(self, dct, validate=True) -> ParsedSchemaTestNode:
|
||||
if validate:
|
||||
ParsedSchemaTestNode.validate(dct)
|
||||
return ParsedSchemaTestNode.from_dict(dct)
|
||||
|
||||
def _check_format_version(
|
||||
self, yaml: YamlBlock
|
||||
) -> None:
|
||||
def _check_format_version(self, yaml: YamlBlock) -> None:
|
||||
path = yaml.path.relative_path
|
||||
if 'version' not in yaml.data:
|
||||
raise_invalid_schema_yml_version(path, 'no version is specified')
|
||||
if "version" not in yaml.data:
|
||||
raise_invalid_schema_yml_version(path, "no version is specified")
|
||||
|
||||
version = yaml.data['version']
|
||||
version = yaml.data["version"]
|
||||
# if it's not an integer, the version is malformed, or not
|
||||
# set. Either way, only 'version: 2' is supported.
|
||||
if not isinstance(version, int):
|
||||
raise_invalid_schema_yml_version(
|
||||
path, 'the version is not an integer'
|
||||
)
|
||||
raise_invalid_schema_yml_version(path, "the version is not an integer")
|
||||
if version != 2:
|
||||
raise_invalid_schema_yml_version(
|
||||
path, 'version {} is not supported'.format(version)
|
||||
path, "version {} is not supported".format(version)
|
||||
)
|
||||
|
||||
def _yaml_from_file(
|
||||
self, source_file: SourceFile
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""If loading the yaml fails, raise an exception.
|
||||
"""
|
||||
def _yaml_from_file(self, source_file: SourceFile) -> Optional[Dict[str, Any]]:
|
||||
"""If loading the yaml fails, raise an exception."""
|
||||
path: str = source_file.path.relative_path
|
||||
try:
|
||||
return load_yaml_text(source_file.contents)
|
||||
except ValidationException as e:
|
||||
reason = validator_error_message(e)
|
||||
raise CompilationException(
|
||||
'Error reading {}: {} - {}'
|
||||
.format(self.project.project_name, path, reason)
|
||||
"Error reading {}: {} - {}".format(
|
||||
self.project.project_name, path, reason
|
||||
)
|
||||
)
|
||||
return None
|
||||
|
||||
def parse_column_tests(
|
||||
self, block: TestBlock, column: UnparsedColumn
|
||||
) -> None:
|
||||
def parse_column_tests(self, block: TestBlock, column: UnparsedColumn) -> None:
|
||||
if not column.tests:
|
||||
return
|
||||
|
||||
@@ -267,9 +254,7 @@ class SchemaParser(SimpleParser[SchemaTestBlock, ParsedSchemaTestNode]):
|
||||
if rendered:
|
||||
generator = ContextConfigGenerator(self.root_project)
|
||||
else:
|
||||
generator = UnrenderedConfigGenerator(
|
||||
self.root_project
|
||||
)
|
||||
generator = UnrenderedConfigGenerator(self.root_project)
|
||||
|
||||
return generator.calculate_node_config(
|
||||
config_calls=[],
|
||||
@@ -284,16 +269,14 @@ class SchemaParser(SimpleParser[SchemaTestBlock, ParsedSchemaTestNode]):
|
||||
relation_cls = adapter.Relation
|
||||
return str(relation_cls.create_from(self.root_project, node))
|
||||
|
||||
def parse_source(
|
||||
self, target: UnpatchedSourceDefinition
|
||||
) -> ParsedSourceDefinition:
|
||||
def parse_source(self, target: UnpatchedSourceDefinition) -> ParsedSourceDefinition:
|
||||
source = target.source
|
||||
table = target.table
|
||||
refs = ParserRef.from_target(table)
|
||||
unique_id = target.unique_id
|
||||
description = table.description or ''
|
||||
description = table.description or ""
|
||||
meta = table.meta or {}
|
||||
source_description = source.description or ''
|
||||
source_description = source.description or ""
|
||||
loaded_at_field = table.loaded_at_field or source.loaded_at_field
|
||||
|
||||
freshness = merge_freshness(source.freshness, table.freshness)
|
||||
@@ -316,8 +299,8 @@ class SchemaParser(SimpleParser[SchemaTestBlock, ParsedSchemaTestNode]):
|
||||
|
||||
if not isinstance(config, SourceConfig):
|
||||
raise InternalException(
|
||||
f'Calculated a {type(config)} for a source, but expected '
|
||||
f'a SourceConfig'
|
||||
f"Calculated a {type(config)} for a source, but expected "
|
||||
f"a SourceConfig"
|
||||
)
|
||||
|
||||
default_database = self.root_project.credentials.database
|
||||
@@ -369,23 +352,23 @@ class SchemaParser(SimpleParser[SchemaTestBlock, ParsedSchemaTestNode]):
|
||||
) -> ParsedSchemaTestNode:
|
||||
|
||||
dct = {
|
||||
'alias': name,
|
||||
'schema': self.default_schema,
|
||||
'database': self.default_database,
|
||||
'fqn': fqn,
|
||||
'name': name,
|
||||
'root_path': self.project.project_root,
|
||||
'resource_type': self.resource_type,
|
||||
'tags': tags,
|
||||
'path': path,
|
||||
'original_file_path': target.original_file_path,
|
||||
'package_name': self.project.project_name,
|
||||
'raw_sql': raw_sql,
|
||||
'unique_id': self.generate_unique_id(name),
|
||||
'config': self.config_dict(config),
|
||||
'test_metadata': test_metadata,
|
||||
'column_name': column_name,
|
||||
'checksum': FileHash.empty().to_dict(omit_none=True),
|
||||
"alias": name,
|
||||
"schema": self.default_schema,
|
||||
"database": self.default_database,
|
||||
"fqn": fqn,
|
||||
"name": name,
|
||||
"root_path": self.project.project_root,
|
||||
"resource_type": self.resource_type,
|
||||
"tags": tags,
|
||||
"path": path,
|
||||
"original_file_path": target.original_file_path,
|
||||
"package_name": self.project.project_name,
|
||||
"raw_sql": raw_sql,
|
||||
"unique_id": self.generate_unique_id(name),
|
||||
"config": self.config_dict(config),
|
||||
"test_metadata": test_metadata,
|
||||
"column_name": column_name,
|
||||
"checksum": FileHash.empty().to_dict(omit_none=True),
|
||||
}
|
||||
try:
|
||||
ParsedSchemaTestNode.validate(dct)
|
||||
@@ -424,18 +407,20 @@ class SchemaParser(SimpleParser[SchemaTestBlock, ParsedSchemaTestNode]):
|
||||
)
|
||||
except CompilationException as exc:
|
||||
context = _trimmed(str(target))
|
||||
msg = (
|
||||
'Invalid test config given in {}:'
|
||||
'\n\t{}\n\t@: {}'
|
||||
.format(target.original_file_path, exc.msg, context)
|
||||
msg = "Invalid test config given in {}:" "\n\t{}\n\t@: {}".format(
|
||||
target.original_file_path, exc.msg, context
|
||||
)
|
||||
raise CompilationException(msg) from exc
|
||||
original_name = os.path.basename(target.original_file_path)
|
||||
compiled_path = get_pseudo_test_path(
|
||||
builder.compiled_name, original_name, 'schema_test',
|
||||
builder.compiled_name,
|
||||
original_name,
|
||||
"schema_test",
|
||||
)
|
||||
fqn_path = get_pseudo_test_path(
|
||||
builder.fqn_name, original_name, 'schema_test',
|
||||
builder.fqn_name,
|
||||
original_name,
|
||||
"schema_test",
|
||||
)
|
||||
# the fqn for tests actually happens in the test target's name, which
|
||||
# is not necessarily this package's name
|
||||
@@ -445,13 +430,13 @@ class SchemaParser(SimpleParser[SchemaTestBlock, ParsedSchemaTestNode]):
|
||||
config = self.initial_config(fqn)
|
||||
|
||||
metadata = {
|
||||
'namespace': builder.namespace,
|
||||
'name': builder.name,
|
||||
'kwargs': builder.args,
|
||||
"namespace": builder.namespace,
|
||||
"name": builder.name,
|
||||
"kwargs": builder.args,
|
||||
}
|
||||
tags = sorted(set(itertools.chain(tags, builder.tags())))
|
||||
if 'schema' not in tags:
|
||||
tags.append('schema')
|
||||
if "schema" not in tags:
|
||||
tags.append("schema")
|
||||
|
||||
node = self.create_test_node(
|
||||
target=target,
|
||||
@@ -477,15 +462,15 @@ class SchemaParser(SimpleParser[SchemaTestBlock, ParsedSchemaTestNode]):
|
||||
# parsing to avoid jinja overhead.
|
||||
def render_test_update(self, node, config, builder):
|
||||
macro_unique_id = self.macro_resolver.get_macro_id(
|
||||
node.package_name, 'test_' + builder.name)
|
||||
node.package_name, "test_" + builder.name
|
||||
)
|
||||
# Add the depends_on here so we can limit the macros added
|
||||
# to the context in rendering processing
|
||||
node.depends_on.add_macro(macro_unique_id)
|
||||
if (macro_unique_id in
|
||||
['macro.dbt.test_not_null', 'macro.dbt.test_unique']):
|
||||
if macro_unique_id in ["macro.dbt.test_not_null", "macro.dbt.test_unique"]:
|
||||
self.update_parsed_node(node, config)
|
||||
node.unrendered_config['severity'] = builder.severity()
|
||||
node.config['severity'] = builder.severity()
|
||||
node.unrendered_config["severity"] = builder.severity()
|
||||
node.config["severity"] = builder.severity()
|
||||
# source node tests are processed at patch_source time
|
||||
if isinstance(builder.target, UnpatchedSourceDefinition):
|
||||
sources = [builder.target.fqn[-2], builder.target.fqn[-1]]
|
||||
@@ -496,15 +481,16 @@ class SchemaParser(SimpleParser[SchemaTestBlock, ParsedSchemaTestNode]):
|
||||
try:
|
||||
# make a base context that doesn't have the magic kwargs field
|
||||
context = generate_test_context(
|
||||
node, self.root_project, self.macro_manifest, config,
|
||||
node,
|
||||
self.root_project,
|
||||
self.macro_manifest,
|
||||
config,
|
||||
self.macro_resolver,
|
||||
)
|
||||
# update with rendered test kwargs (which collects any refs)
|
||||
add_rendered_test_kwargs(context, node, capture_macros=True)
|
||||
# the parsed node is not rendered in the native context.
|
||||
get_rendered(
|
||||
node.raw_sql, context, node, capture_macros=True
|
||||
)
|
||||
get_rendered(node.raw_sql, context, node, capture_macros=True)
|
||||
self.update_parsed_node(node, config)
|
||||
except ValidationError as exc:
|
||||
# we got a ValidationError - probably bad types in config()
|
||||
@@ -522,9 +508,8 @@ class SchemaParser(SimpleParser[SchemaTestBlock, ParsedSchemaTestNode]):
|
||||
column_name = None
|
||||
else:
|
||||
column_name = column.name
|
||||
should_quote = (
|
||||
column.quote or
|
||||
(column.quote is None and target.quote_columns)
|
||||
should_quote = column.quote or (
|
||||
column.quote is None and target.quote_columns
|
||||
)
|
||||
if should_quote:
|
||||
column_name = get_adapter(self.root_project).quote(column_name)
|
||||
@@ -535,10 +520,7 @@ class SchemaParser(SimpleParser[SchemaTestBlock, ParsedSchemaTestNode]):
|
||||
tags = list(itertools.chain.from_iterable(tags_sources))
|
||||
|
||||
node = self._parse_generic_test(
|
||||
target=target,
|
||||
test=test,
|
||||
tags=tags,
|
||||
column_name=column_name
|
||||
target=target, test=test, tags=tags, column_name=column_name
|
||||
)
|
||||
# we can't go through result.add_node - no file... instead!
|
||||
if node.config.enabled:
|
||||
@@ -562,7 +544,9 @@ class SchemaParser(SimpleParser[SchemaTestBlock, ParsedSchemaTestNode]):
|
||||
return node
|
||||
|
||||
def render_with_context(
|
||||
self, node: ParsedSchemaTestNode, config: ContextConfig,
|
||||
self,
|
||||
node: ParsedSchemaTestNode,
|
||||
config: ContextConfig,
|
||||
) -> None:
|
||||
"""Given the parsed node and a ContextConfig to use during
|
||||
parsing, collect all the refs that might be squirreled away in the test
|
||||
@@ -574,9 +558,7 @@ class SchemaParser(SimpleParser[SchemaTestBlock, ParsedSchemaTestNode]):
|
||||
add_rendered_test_kwargs(context, node, capture_macros=True)
|
||||
|
||||
# the parsed node is not rendered in the native context.
|
||||
get_rendered(
|
||||
node.raw_sql, context, node, capture_macros=True
|
||||
)
|
||||
get_rendered(node.raw_sql, context, node, capture_macros=True)
|
||||
|
||||
def parse_test(
|
||||
self,
|
||||
@@ -592,9 +574,8 @@ class SchemaParser(SimpleParser[SchemaTestBlock, ParsedSchemaTestNode]):
|
||||
column_tags: List[str] = []
|
||||
else:
|
||||
column_name = column.name
|
||||
should_quote = (
|
||||
column.quote or
|
||||
(column.quote is None and target_block.quote_columns)
|
||||
should_quote = column.quote or (
|
||||
column.quote is None and target_block.quote_columns
|
||||
)
|
||||
if should_quote:
|
||||
column_name = get_adapter(self.root_project).quote(column_name)
|
||||
@@ -632,8 +613,8 @@ class SchemaParser(SimpleParser[SchemaTestBlock, ParsedSchemaTestNode]):
|
||||
dct = self.raw_renderer.render_data(dct)
|
||||
except CompilationException as exc:
|
||||
raise CompilationException(
|
||||
f'Failed to render {block.path.original_file_path} from '
|
||||
f'project {self.project.project_name}: {exc}'
|
||||
f"Failed to render {block.path.original_file_path} from "
|
||||
f"project {self.project.project_name}: {exc}"
|
||||
) from exc
|
||||
|
||||
# contains the FileBlock and the data (dictionary)
|
||||
@@ -649,66 +630,57 @@ class SchemaParser(SimpleParser[SchemaTestBlock, ParsedSchemaTestNode]):
|
||||
|
||||
# NonSourceParser.parse(), TestablePatchParser is a variety of
|
||||
# NodePatchParser
|
||||
if 'models' in dct:
|
||||
parser = TestablePatchParser(self, yaml_block, 'models')
|
||||
if "models" in dct:
|
||||
parser = TestablePatchParser(self, yaml_block, "models")
|
||||
for test_block in parser.parse():
|
||||
self.parse_tests(test_block)
|
||||
|
||||
# NonSourceParser.parse()
|
||||
if 'seeds' in dct:
|
||||
parser = TestablePatchParser(self, yaml_block, 'seeds')
|
||||
if "seeds" in dct:
|
||||
parser = TestablePatchParser(self, yaml_block, "seeds")
|
||||
for test_block in parser.parse():
|
||||
self.parse_tests(test_block)
|
||||
|
||||
# NonSourceParser.parse()
|
||||
if 'snapshots' in dct:
|
||||
parser = TestablePatchParser(self, yaml_block, 'snapshots')
|
||||
if "snapshots" in dct:
|
||||
parser = TestablePatchParser(self, yaml_block, "snapshots")
|
||||
for test_block in parser.parse():
|
||||
self.parse_tests(test_block)
|
||||
|
||||
# This parser uses SourceParser.parse() which doesn't return
|
||||
# any test blocks. Source tests are handled at a later point
|
||||
# in the process.
|
||||
if 'sources' in dct:
|
||||
parser = SourceParser(self, yaml_block, 'sources')
|
||||
if "sources" in dct:
|
||||
parser = SourceParser(self, yaml_block, "sources")
|
||||
parser.parse()
|
||||
|
||||
# NonSourceParser.parse()
|
||||
if 'macros' in dct:
|
||||
parser = MacroPatchParser(self, yaml_block, 'macros')
|
||||
if "macros" in dct:
|
||||
parser = MacroPatchParser(self, yaml_block, "macros")
|
||||
for test_block in parser.parse():
|
||||
self.parse_tests(test_block)
|
||||
|
||||
# NonSourceParser.parse()
|
||||
if 'analyses' in dct:
|
||||
parser = AnalysisPatchParser(self, yaml_block, 'analyses')
|
||||
if "analyses" in dct:
|
||||
parser = AnalysisPatchParser(self, yaml_block, "analyses")
|
||||
for test_block in parser.parse():
|
||||
self.parse_tests(test_block)
|
||||
|
||||
# parse exposures
|
||||
if 'exposures' in dct:
|
||||
if "exposures" in dct:
|
||||
self.parse_exposures(yaml_block)
|
||||
|
||||
|
||||
Parsed = TypeVar(
|
||||
'Parsed',
|
||||
UnpatchedSourceDefinition, ParsedNodePatch, ParsedMacroPatch
|
||||
)
|
||||
NodeTarget = TypeVar(
|
||||
'NodeTarget',
|
||||
UnparsedNodeUpdate, UnparsedAnalysisUpdate
|
||||
)
|
||||
Parsed = TypeVar("Parsed", UnpatchedSourceDefinition, ParsedNodePatch, ParsedMacroPatch)
|
||||
NodeTarget = TypeVar("NodeTarget", UnparsedNodeUpdate, UnparsedAnalysisUpdate)
|
||||
NonSourceTarget = TypeVar(
|
||||
'NonSourceTarget',
|
||||
UnparsedNodeUpdate, UnparsedAnalysisUpdate, UnparsedMacroUpdate
|
||||
"NonSourceTarget", UnparsedNodeUpdate, UnparsedAnalysisUpdate, UnparsedMacroUpdate
|
||||
)
|
||||
|
||||
|
||||
# abstract base class (ABCMeta)
|
||||
class YamlReader(metaclass=ABCMeta):
|
||||
def __init__(
|
||||
self, schema_parser: SchemaParser, yaml: YamlBlock, key: str
|
||||
) -> None:
|
||||
def __init__(self, schema_parser: SchemaParser, yaml: YamlBlock, key: str) -> None:
|
||||
self.schema_parser = schema_parser
|
||||
# key: models, seeds, snapshots, sources, macros,
|
||||
# analyses, exposures
|
||||
@@ -738,8 +710,9 @@ class YamlReader(metaclass=ABCMeta):
|
||||
data = self.yaml.data.get(self.key, [])
|
||||
if not isinstance(data, list):
|
||||
raise CompilationException(
|
||||
'{} must be a list, got {} instead: ({})'
|
||||
.format(self.key, type(data), _trimmed(str(data)))
|
||||
"{} must be a list, got {} instead: ({})".format(
|
||||
self.key, type(data), _trimmed(str(data))
|
||||
)
|
||||
)
|
||||
path = self.yaml.path.original_file_path
|
||||
|
||||
@@ -751,7 +724,7 @@ class YamlReader(metaclass=ABCMeta):
|
||||
yield entry
|
||||
else:
|
||||
msg = error_context(
|
||||
path, self.key, data, 'expected a dict with string keys'
|
||||
path, self.key, data, "expected a dict with string keys"
|
||||
)
|
||||
raise CompilationException(msg)
|
||||
|
||||
@@ -759,10 +732,10 @@ class YamlReader(metaclass=ABCMeta):
|
||||
class YamlDocsReader(YamlReader):
|
||||
@abstractmethod
|
||||
def parse(self) -> List[TestBlock]:
|
||||
raise NotImplementedError('parse is abstract')
|
||||
raise NotImplementedError("parse is abstract")
|
||||
|
||||
|
||||
T = TypeVar('T', bound=dbtClassMixin)
|
||||
T = TypeVar("T", bound=dbtClassMixin)
|
||||
|
||||
|
||||
class SourceParser(YamlDocsReader):
|
||||
@@ -779,13 +752,11 @@ class SourceParser(YamlDocsReader):
|
||||
def parse(self) -> List[TestBlock]:
|
||||
# get a verified list of dicts for the key handled by this parser
|
||||
for data in self.get_key_dicts():
|
||||
data = self.project.credentials.translate_aliases(
|
||||
data, recurse=True
|
||||
)
|
||||
data = self.project.credentials.translate_aliases(data, recurse=True)
|
||||
|
||||
is_override = 'overrides' in data
|
||||
is_override = "overrides" in data
|
||||
if is_override:
|
||||
data['path'] = self.yaml.path.original_file_path
|
||||
data["path"] = self.yaml.path.original_file_path
|
||||
patch = self._target_from_dict(SourcePatch, data)
|
||||
self.results.add_source_patch(self.yaml.file, patch)
|
||||
else:
|
||||
@@ -797,10 +768,9 @@ class SourceParser(YamlDocsReader):
|
||||
original_file_path = self.yaml.path.original_file_path
|
||||
fqn_path = self.yaml.path.relative_path
|
||||
for table in source.tables:
|
||||
unique_id = '.'.join([
|
||||
NodeType.Source, self.project.project_name,
|
||||
source.name, table.name
|
||||
])
|
||||
unique_id = ".".join(
|
||||
[NodeType.Source, self.project.project_name, source.name, table.name]
|
||||
)
|
||||
|
||||
# the FQN is project name / path elements /source_name /table_name
|
||||
fqn = self.schema_parser.get_fqn_prefix(fqn_path)
|
||||
@@ -825,17 +795,15 @@ class SourceParser(YamlDocsReader):
|
||||
class NonSourceParser(YamlDocsReader, Generic[NonSourceTarget, Parsed]):
|
||||
@abstractmethod
|
||||
def _target_type(self) -> Type[NonSourceTarget]:
|
||||
raise NotImplementedError('_target_type not implemented')
|
||||
raise NotImplementedError("_target_type not implemented")
|
||||
|
||||
@abstractmethod
|
||||
def get_block(self, node: NonSourceTarget) -> TargetBlock:
|
||||
raise NotImplementedError('get_block is abstract')
|
||||
raise NotImplementedError("get_block is abstract")
|
||||
|
||||
@abstractmethod
|
||||
def parse_patch(
|
||||
self, block: TargetBlock[NonSourceTarget], refs: ParserRef
|
||||
) -> None:
|
||||
raise NotImplementedError('parse_patch is abstract')
|
||||
def parse_patch(self, block: TargetBlock[NonSourceTarget], refs: ParserRef) -> None:
|
||||
raise NotImplementedError("parse_patch is abstract")
|
||||
|
||||
def parse(self) -> List[TestBlock]:
|
||||
node: NonSourceTarget
|
||||
@@ -874,11 +842,13 @@ class NonSourceParser(YamlDocsReader, Generic[NonSourceTarget, Parsed]):
|
||||
for data in key_dicts:
|
||||
# add extra data to each dict. This updates the dicts
|
||||
# in the parser yaml
|
||||
data.update({
|
||||
'original_file_path': path,
|
||||
'yaml_key': self.key,
|
||||
'package_name': self.project.project_name,
|
||||
})
|
||||
data.update(
|
||||
{
|
||||
"original_file_path": path,
|
||||
"yaml_key": self.key,
|
||||
"package_name": self.project.project_name,
|
||||
}
|
||||
)
|
||||
try:
|
||||
# target_type: UnparsedNodeUpdate, UnparsedAnalysisUpdate,
|
||||
# or UnparsedMacroUpdate
|
||||
@@ -892,12 +862,9 @@ class NonSourceParser(YamlDocsReader, Generic[NonSourceTarget, Parsed]):
|
||||
|
||||
|
||||
class NodePatchParser(
|
||||
NonSourceParser[NodeTarget, ParsedNodePatch],
|
||||
Generic[NodeTarget]
|
||||
NonSourceParser[NodeTarget, ParsedNodePatch], Generic[NodeTarget]
|
||||
):
|
||||
def parse_patch(
|
||||
self, block: TargetBlock[NodeTarget], refs: ParserRef
|
||||
) -> None:
|
||||
def parse_patch(self, block: TargetBlock[NodeTarget], refs: ParserRef) -> None:
|
||||
result = ParsedNodePatch(
|
||||
name=block.target.name,
|
||||
original_file_path=block.target.original_file_path,
|
||||
@@ -958,7 +925,7 @@ class ExposureParser(YamlReader):
|
||||
|
||||
def parse_exposure(self, unparsed: UnparsedExposure) -> ParsedExposure:
|
||||
package_name = self.project.project_name
|
||||
unique_id = f'{NodeType.Exposure}.{package_name}.{unparsed.name}'
|
||||
unique_id = f"{NodeType.Exposure}.{package_name}.{unparsed.name}"
|
||||
path = self.yaml.path.relative_path
|
||||
|
||||
fqn = self.schema_parser.get_fqn_prefix(path)
|
||||
@@ -984,12 +951,10 @@ class ExposureParser(YamlReader):
|
||||
self.schema_parser.macro_manifest,
|
||||
package_name,
|
||||
)
|
||||
depends_on_jinja = '\n'.join(
|
||||
'{{ ' + line + '}}' for line in unparsed.depends_on
|
||||
)
|
||||
get_rendered(
|
||||
depends_on_jinja, ctx, parsed, capture_macros=True
|
||||
depends_on_jinja = "\n".join(
|
||||
"{{ " + line + "}}" for line in unparsed.depends_on
|
||||
)
|
||||
get_rendered(depends_on_jinja, ctx, parsed, capture_macros=True)
|
||||
# parsed now has a populated refs/sources
|
||||
return parsed
|
||||
|
||||
|
||||
@@ -1,8 +1,6 @@
import os
from dataclasses import dataclass
from typing import (
List, Callable, Iterable, Set, Union, Iterator, TypeVar, Generic
)
from typing import List, Callable, Iterable, Set, Union, Iterator, TypeVar, Generic

from dbt.clients.jinja import extract_toplevel_blocks, BlockTag
from dbt.clients.system import find_matching
@@ -72,13 +70,13 @@ class FilesystemSearcher(Iterable[FilePath]):
root = self.project.project_root

for result in find_matching(root, self.relative_dirs, ext):
if 'searched_path' not in result or 'relative_path' not in result:
if "searched_path" not in result or "relative_path" not in result:
raise InternalException(
'Invalid result from find_matching: {}'.format(result)
"Invalid result from find_matching: {}".format(result)
)
file_match = FilePath(
searched_path=result['searched_path'],
relative_path=result['relative_path'],
searched_path=result["searched_path"],
relative_path=result["relative_path"],
project_root=root,
)
yield file_match
@@ -86,7 +84,7 @@ class FilesystemSearcher(Iterable[FilePath]):

Block = Union[BlockContents, FullBlock]

BlockSearchResult = TypeVar('BlockSearchResult', BlockContents, FullBlock)
BlockSearchResult = TypeVar("BlockSearchResult", BlockContents, FullBlock)

BlockSearchResultFactory = Callable[[SourceFile, BlockTag], BlockSearchResult]

@@ -96,7 +94,7 @@ class BlockSearcher(Generic[BlockSearchResult], Iterable[BlockSearchResult]):
self,
source: List[FileBlock],
allowed_blocks: Set[str],
source_tag_factory: BlockSearchResultFactory
source_tag_factory: BlockSearchResultFactory,
) -> None:
self.source = source
self.allowed_blocks = allowed_blocks
@@ -107,7 +105,7 @@ class BlockSearcher(Generic[BlockSearchResult], Iterable[BlockSearchResult]):
blocks = extract_toplevel_blocks(
source_file.contents,
allowed_blocks=self.allowed_blocks,
collect_raw_data=False
collect_raw_data=False,
)
# this makes mypy happy, and this is an invariant we really need
for block in blocks:

@@ -8,9 +8,7 @@ from dbt.parser.search import FileBlock, FilesystemSearcher

class SeedParser(SimpleSQLParser[ParsedSeedNode]):
def get_paths(self):
return FilesystemSearcher(
self.project, self.project.data_paths, '.csv'
)
return FilesystemSearcher(self.project, self.project.data_paths, ".csv")

def parse_from_dict(self, dct, validate=True) -> ParsedSeedNode:
if validate:
@@ -30,9 +28,7 @@ class SeedParser(SimpleSQLParser[ParsedSeedNode]):
) -> None:
"""Seeds don't need to do any rendering."""

def load_file(
self, match: FilePath, *, set_contents: bool = False
) -> SourceFile:
def load_file(self, match: FilePath, *, set_contents: bool = False) -> SourceFile:
if match.seed_too_large():
# We don't want to calculate a hash of this file. Use the path.
return SourceFile.big_seed(match)

@@ -3,27 +3,22 @@ from typing import List

from dbt.dataclass_schema import ValidationError

from dbt.contracts.graph.parsed import (
IntermediateSnapshotNode, ParsedSnapshotNode
)
from dbt.exceptions import (
CompilationException, validator_error_message
)
from dbt.contracts.graph.parsed import IntermediateSnapshotNode, ParsedSnapshotNode
from dbt.exceptions import CompilationException, validator_error_message
from dbt.node_types import NodeType
from dbt.parser.base import SQLParser
from dbt.parser.search import (
FilesystemSearcher, BlockContents, BlockSearcher, FileBlock
FilesystemSearcher,
BlockContents,
BlockSearcher,
FileBlock,
)
from dbt.utils import split_path


class SnapshotParser(
SQLParser[IntermediateSnapshotNode, ParsedSnapshotNode]
):
class SnapshotParser(SQLParser[IntermediateSnapshotNode, ParsedSnapshotNode]):
def get_paths(self):
return FilesystemSearcher(
self.project, self.project.snapshot_paths, '.sql'
)
return FilesystemSearcher(self.project, self.project.snapshot_paths, ".sql")

def parse_from_dict(self, dct, validate=True) -> IntermediateSnapshotNode:
if validate:
@@ -78,7 +73,7 @@ class SnapshotParser(
def parse_file(self, file_block: FileBlock) -> None:
blocks = BlockSearcher(
source=[file_block],
allowed_blocks={'snapshot'},
allowed_blocks={"snapshot"},
source_tag_factory=BlockContents,
)
for block in blocks:

@@ -34,8 +34,7 @@ class SourcePatcher:
self.results = results
self.root_project = root_project
self.macro_manifest = MacroManifest(
macros=self.results.macros,
files=self.results.files
macros=self.results.macros, files=self.results.files
)
self.schema_parsers: Dict[str, SchemaParser] = {}
self.patches_used: Dict[SourceKey, Set[str]] = {}
@@ -65,9 +64,7 @@ class SourcePatcher:

source = UnparsedSourceDefinition.from_dict(source_dct)
table = UnparsedSourceTableDefinition.from_dict(table_dct)
return unpatched.replace(
source=source, table=table, patch_path=patch_path
)
return unpatched.replace(source=source, table=table, patch_path=patch_path)

def parse_source_docs(self, block: UnpatchedSourceDefinition) -> ParserRef:
refs = ParserRef()
@@ -78,7 +75,7 @@ class SourcePatcher:
refs.add(column, description, data_type, meta)
return refs

def get_schema_parser_for(self, package_name: str) -> 'SchemaParser':
def get_schema_parser_for(self, package_name: str) -> "SchemaParser":
if package_name in self.schema_parsers:
schema_parser = self.schema_parsers[package_name]
else:
@@ -157,31 +154,28 @@ class SourcePatcher:

if unused_tables:
msg = self.get_unused_msg(unused_tables)
warn_or_error(msg, log_fmt=ui.warning_tag('{}'))
warn_or_error(msg, log_fmt=ui.warning_tag("{}"))

def get_unused_msg(
self,
unused_tables: Dict[SourceKey, Optional[Set[str]]],
) -> str:
msg = [
'During parsing, dbt encountered source overrides that had no '
'target:',
"During parsing, dbt encountered source overrides that had no " "target:",
]
for key, table_names in unused_tables.items():
patch = self.results.source_patches[key]
patch_name = f'{patch.overrides}.{patch.name}'
patch_name = f"{patch.overrides}.{patch.name}"
if table_names is None:
msg.append(
f' - Source {patch_name} (in {patch.path})'
)
msg.append(f" - Source {patch_name} (in {patch.path})")
else:
for table_name in sorted(table_names):
msg.append(
f' - Source table {patch_name}.{table_name} '
f'(in {patch.path})'
f" - Source table {patch_name}.{table_name} "
f"(in {patch.path})"
)
msg.append('')
return '\n'.join(msg)
msg.append("")
return "\n".join(msg)


def patch_sources(

@@ -15,5 +15,5 @@ def profiler(enable, outfile):
if enable:
profiler.disable()
stats = Stats(profiler)
stats.sort_stats('tottime')
stats.sort_stats("tottime")
stats.dump_stats(outfile)

@@ -46,14 +46,14 @@ from dbt.rpc.task_handler import RequestTaskHandler


class GC(RemoteBuiltinMethod[GCParameters, GCResult]):
METHOD_NAME = 'gc'
METHOD_NAME = "gc"

def set_args(self, params: GCParameters):
super().set_args(params)

def handle_request(self) -> GCResult:
if self.params is None:
raise dbt.exceptions.InternalException('GC: params not set')
raise dbt.exceptions.InternalException("GC: params not set")
return self.task_manager.gc_safe(
task_ids=self.params.task_ids,
before=self.params.before,
@@ -62,14 +62,14 @@ class GC(RemoteBuiltinMethod[GCParameters, GCResult]):


class Kill(RemoteBuiltinMethod[KillParameters, KillResult]):
METHOD_NAME = 'kill'
METHOD_NAME = "kill"

def set_args(self, params: KillParameters):
super().set_args(params)

def handle_request(self) -> KillResult:
if self.params is None:
raise dbt.exceptions.InternalException('Kill: params not set')
raise dbt.exceptions.InternalException("Kill: params not set")
result = KillResult()
task: RequestTaskHandler
try:
@@ -99,7 +99,7 @@ class Kill(RemoteBuiltinMethod[KillParameters, KillResult]):


class Status(RemoteBuiltinMethod[StatusParameters, LastParse]):
METHOD_NAME = 'status'
METHOD_NAME = "status"

def set_args(self, params: StatusParameters):
super().set_args(params)
@@ -109,14 +109,14 @@ class Status(RemoteBuiltinMethod[StatusParameters, LastParse]):


class PS(RemoteBuiltinMethod[PSParameters, PSResult]):
METHOD_NAME = 'ps'
METHOD_NAME = "ps"

def set_args(self, params: PSParameters):
super().set_args(params)

def keep(self, row: TaskRow):
if self.params is None:
raise dbt.exceptions.InternalException('PS: params not set')
raise dbt.exceptions.InternalException("PS: params not set")
if row.state.finished and self.params.completed:
return True
elif not row.state.finished and self.params.active:
@@ -125,9 +125,7 @@ class PS(RemoteBuiltinMethod[PSParameters, PSResult]):
return False

def handle_request(self) -> PSResult:
rows = [
row for row in self.task_manager.task_table() if self.keep(row)
]
rows = [row for row in self.task_manager.task_table() if self.keep(row)]
rows.sort(key=lambda r: (r.state, r.start, r.method))
result = PSResult(rows=rows, logs=[])
return result
@@ -138,10 +136,11 @@ def poll_complete(
) -> PollResult:
if timing.state not in (TaskHandlerState.Success, TaskHandlerState.Failed):
raise dbt.exceptions.InternalException(
f'got invalid result state in poll_complete: {timing.state}'
f"got invalid result state in poll_complete: {timing.state}"
)

cls: Type[Union[
cls: Type[
Union[
PollExecuteCompleteResult,
PollRunCompleteResult,
PollCompileCompleteResult,
@@ -150,7 +149,8 @@ def poll_complete(
PollRunOperationCompleteResult,
PollGetManifestResult,
PollFreshnessResult,
]]
]
]

if isinstance(result, RemoteExecutionResult):
cls = PollExecuteCompleteResult
@@ -171,7 +171,7 @@ def poll_complete(
cls = PollFreshnessResult
else:
raise dbt.exceptions.InternalException(
'got invalid result in poll_complete: {}'.format(result)
"got invalid result in poll_complete: {}".format(result)
)
return cls.from_result(result, tags, timing, logs)

@@ -181,14 +181,14 @@ def _dict_logs(logs: List[LogMessage]) -> List[Dict[str, Any]]:


class Poll(RemoteBuiltinMethod[PollParameters, PollResult]):
METHOD_NAME = 'poll'
METHOD_NAME = "poll"

def set_args(self, params: PollParameters):
super().set_args(params)

def handle_request(self) -> PollResult:
if self.params is None:
raise dbt.exceptions.InternalException('Poll: params not set')
raise dbt.exceptions.InternalException("Poll: params not set")
task_id = self.params.request_token
task: RequestTaskHandler = self.task_manager.get_request(task_id)

@@ -216,7 +216,7 @@ class Poll(RemoteBuiltinMethod[PollParameters, PollResult]):
err = task.error
if err is None:
exc = dbt.exceptions.InternalException(
f'At end of task {task_id}, error state but error is None'
f"At end of task {task_id}, error state but error is None"
)
raise RPCException.from_error(
dbt_error(exc, logs=_dict_logs(task_logs))
@@ -228,17 +228,13 @@ class Poll(RemoteBuiltinMethod[PollParameters, PollResult]):

if task.result is None:
exc = dbt.exceptions.InternalException(
f'At end of task {task_id}, state={state} but result is '
'None'
f"At end of task {task_id}, state={state} but result is " "None"
)
raise RPCException.from_error(
dbt_error(exc, logs=_dict_logs(task_logs))
)
return poll_complete(
timing=timing,
result=task.result,
tags=task.tags,
logs=task_logs
timing=timing, result=task.result, tags=task.tags, logs=task_logs
)
elif state == TaskHandlerState.Killed:
return PollKilledResult(
@@ -251,8 +247,6 @@ class Poll(RemoteBuiltinMethod[PollParameters, PollResult]):
)
else:
exc = dbt.exceptions.InternalException(
f'Got unknown value state={state} for task {task_id}'
)
raise RPCException.from_error(
dbt_error(exc, logs=_dict_logs(task_logs))
f"Got unknown value state={state} for task {task_id}"
)
raise RPCException.from_error(dbt_error(exc, logs=_dict_logs(task_logs)))

@@ -12,45 +12,44 @@ class RPCException(JSONRPCDispatchException):
message: Optional[str] = None,
data: Optional[Dict[str, Any]] = None,
logs: Optional[List[Dict[str, Any]]] = None,
tags: Optional[Dict[str, Any]] = None
tags: Optional[Dict[str, Any]] = None,
) -> None:
if code is None:
code = -32000
if message is None:
message = 'Server error'
message = "Server error"
if data is None:
data = {}

super().__init__(code=code, message=message, data=data)
if logs is not None:
self.logs = logs
self.error.data['tags'] = tags
self.error.data["tags"] = tags

def __str__(self):
return (
'RPCException({0.code}, {0.message}, {0.data}, {1.logs})'
.format(self.error, self)
return "RPCException({0.code}, {0.message}, {0.data}, {1.logs})".format(
self.error, self
)

@property
def logs(self) -> List[Dict[str, Any]]:
return self.error.data.get('logs')
return self.error.data.get("logs")

@logs.setter
def logs(self, value):
if value is None:
return
self.error.data['logs'] = value
self.error.data["logs"] = value

@property
def tags(self):
return self.error.data.get('tags')
return self.error.data.get("tags")

@tags.setter
def tags(self, value):
if value is None:
return
self.error.data['tags'] = value
self.error.data["tags"] = value

@classmethod
def from_error(cls, err):
@@ -58,16 +57,14 @@ class RPCException(JSONRPCDispatchException):
code=err.code,
message=err.message,
data=err.data,
logs=err.data.get('logs'),
tags=err.data.get('tags'),
logs=err.data.get("logs"),
tags=err.data.get("tags"),
)


def invalid_params(data):
return RPCException(
code=JSONRPCInvalidParams.CODE,
message=JSONRPCInvalidParams.MESSAGE,
data=data
code=JSONRPCInvalidParams.CODE, message=JSONRPCInvalidParams.MESSAGE, data=data
)


@@ -82,6 +79,7 @@ def timeout_error(timeout_value, logs=None, tags=None):


def dbt_error(exc, logs=None, tags=None):
exc = RPCException(code=exc.CODE, message=exc.MESSAGE, data=exc.data(),
logs=logs, tags=tags)
exc = RPCException(
code=exc.CODE, message=exc.MESSAGE, data=exc.data(), logs=logs, tags=tags
)
return exc

Some files were not shown because too many files have changed in this diff.