Mirror of https://github.com/dbt-labs/dbt-core, synced 2025-12-20 00:01:28 +00:00

Compare commits: 1 commit on branch jerco/sql-...testing-pr (SHA 42058de028)
@@ -41,4 +41,3 @@ first_value = 1
 [bumpversion:file:plugins/snowflake/dbt/adapters/snowflake/__version__.py]

 [bumpversion:file:plugins/bigquery/dbt/adapters/bigquery/__version__.py]
-
.pre-commit-config.yaml (new file, 20 lines)
@@ -0,0 +1,20 @@
+repos:
+-   repo: https://github.com/pre-commit/pre-commit-hooks
+    rev: v2.3.0
+    hooks:
+    -   id: check-yaml
+    -   id: end-of-file-fixer
+    -   id: trailing-whitespace
+-   repo: https://github.com/psf/black
+    rev: 20.8b1
+    hooks:
+    -   id: black
+-   repo: https://gitlab.com/PyCQA/flake8
+    rev: 3.9.0
+    hooks:
+    -   id: flake8
+-   repo: https://github.com/pre-commit/mirrors-mypy
+    rev: v0.812
+    hooks:
+    -   id: mypy
+        files: ^core/dbt/
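The hooks defined above run on `git commit` once installed. A typical local workflow (standard pre-commit usage, not part of this diff) looks like:

```
pip install pre-commit
pre-commit install            # register the git hook in this clone
pre-commit run --all-files    # run check-yaml, black, flake8, and mypy across the repo
```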
@@ -139,7 +139,7 @@ brew install postgresql
 First make sure that you set up your `virtualenv` as described in section _Setting up an environment_. Next, install dbt (and its dependencies) with:

 ```
-pip install -r editable_requirements.txt
+pip install -r requirements-editable.txt
 ```

 When dbt is installed from source in this way, any changes you make to the dbt source code will be reflected immediately in your next `dbt` run.
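To confirm the editable install is picking up your working copy (an optional sanity check, not part of this change), ask Python where it resolves `dbt` from:

```
pip install -r requirements-editable.txt
python -c "import dbt; print(dbt.__file__)"   # should point into your local checkout
```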
Makefile
@@ -1,10 +1,5 @@
 .PHONY: install test test-unit test-integration

-changed_tests := `git status --porcelain | grep '^\(M\| M\|A\| A\)' | awk '{ print $$2 }' | grep '\/test_[a-zA-Z_\-\.]\+.py'`
-
-install:
-	pip install -e .
-
 test: .env
 	@echo "Full test run starting..."
 	@time docker-compose run --rm test tox
@@ -18,7 +13,7 @@ test-integration: .env
 	@time docker-compose run --rm test tox -e integration-postgres-py36,integration-redshift-py36,integration-snowflake-py36,integration-bigquery-py36

 test-quick: .env
-	@echo "Integration test run starting..."
+	@echo "Integration test run starting, will exit on first failure..."
 	@time docker-compose run --rm test tox -e integration-postgres-py36 -- -x

 # This rule creates a file named .env that is used by docker-compose for passing
@@ -139,7 +139,7 @@ jobs:
       inputs:
         versionSpec: '3.7'
         architecture: 'x64'
-  - script: python -m pip install --upgrade pip setuptools && python -m pip install -r requirements.txt && python -m pip install -r dev_requirements.txt
+  - script: python -m pip install --upgrade pip setuptools && python -m pip install -r requirements.txt && python -m pip install -r requirements-dev.txt
     displayName: Install dependencies
   - task: ShellScript@2
     inputs:
@@ -63,11 +63,13 @@ def main():
     packages = registry.packages()
     project_json = init_project_in_packages(args, packages)
     if args.project["version"] in project_json["versions"]:
-        raise Exception("Version {} already in packages JSON"
-                        .format(args.project["version"]),
-                        file=sys.stderr)
+        raise Exception(
+            "Version {} already in packages JSON".format(args.project["version"]),
+            file=sys.stderr,
+        )
     add_version_to_package(args, project_json)
     print(json.dumps(packages, indent=2))
+

 if __name__ == "__main__":
     main()
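Worth noting: both before and after this reformatting, the call passes `file=sys.stderr` to `Exception`, which does not accept keyword arguments; it looks like a leftover from a `print(..., file=sys.stderr)` call and would raise a `TypeError` if this branch were ever hit. A minimal sketch of the more conventional pattern (illustrative only, with hypothetical values; not part of this commit):

```python
import sys

version = "0.19.0"      # hypothetical values for illustration
versions = {"0.19.0"}

try:
    if version in versions:
        # The message travels with the exception itself; no `file=` argument here.
        raise Exception("Version {} already in packages JSON".format(version))
except Exception as exc:
    # If the intent was a warning rather than a hard failure, stderr is the place:
    print(exc, file=sys.stderr)
```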
@@ -8,10 +8,10 @@ from dbt.exceptions import RuntimeException
 @dataclass
 class Column:
     TYPE_LABELS: ClassVar[Dict[str, str]] = {
-        'STRING': 'TEXT',
-        'TIMESTAMP': 'TIMESTAMP',
-        'FLOAT': 'FLOAT',
-        'INTEGER': 'INT'
+        "STRING": "TEXT",
+        "TIMESTAMP": "TIMESTAMP",
+        "FLOAT": "FLOAT",
+        "INTEGER": "INT",
     }
     column: str
     dtype: str
@@ -24,7 +24,7 @@ class Column:
         return cls.TYPE_LABELS.get(dtype.upper(), dtype)

     @classmethod
-    def create(cls, name, label_or_dtype: str) -> 'Column':
+    def create(cls, name, label_or_dtype: str) -> "Column":
         column_type = cls.translate_type(label_or_dtype)
         return cls(name, column_type)

@@ -41,14 +41,19 @@ class Column:
         if self.is_string():
             return Column.string_type(self.string_size())
         elif self.is_numeric():
-            return Column.numeric_type(self.dtype, self.numeric_precision,
-                                       self.numeric_scale)
+            return Column.numeric_type(
+                self.dtype, self.numeric_precision, self.numeric_scale
+            )
         else:
             return self.dtype

     def is_string(self) -> bool:
-        return self.dtype.lower() in ['text', 'character varying', 'character',
-                                      'varchar']
+        return self.dtype.lower() in [
+            "text",
+            "character varying",
+            "character",
+            "varchar",
+        ]

     def is_number(self):
         return any([self.is_integer(), self.is_numeric(), self.is_float()])
@@ -56,33 +61,45 @@ class Column:
     def is_float(self):
         return self.dtype.lower() in [
             # floats
-            'real', 'float4', 'float', 'double precision', 'float8'
+            "real",
+            "float4",
+            "float",
+            "double precision",
+            "float8",
         ]

     def is_integer(self) -> bool:
         return self.dtype.lower() in [
             # real types
-            'smallint', 'integer', 'bigint',
-            'smallserial', 'serial', 'bigserial',
+            "smallint",
+            "integer",
+            "bigint",
+            "smallserial",
+            "serial",
+            "bigserial",
             # aliases
-            'int2', 'int4', 'int8',
-            'serial2', 'serial4', 'serial8',
+            "int2",
+            "int4",
+            "int8",
+            "serial2",
+            "serial4",
+            "serial8",
         ]

     def is_numeric(self) -> bool:
-        return self.dtype.lower() in ['numeric', 'decimal']
+        return self.dtype.lower() in ["numeric", "decimal"]

     def string_size(self) -> int:
         if not self.is_string():
             raise RuntimeException("Called string_size() on non-string field!")

-        if self.dtype == 'text' or self.char_size is None:
+        if self.dtype == "text" or self.char_size is None:
             # char_size should never be None. Handle it reasonably just in case
             return 256
         else:
             return int(self.char_size)

-    def can_expand_to(self, other_column: 'Column') -> bool:
+    def can_expand_to(self, other_column: "Column") -> bool:
         """returns True if this column can be expanded to the size of the
         other column"""
         if not self.is_string() or not other_column.is_string():
@@ -110,12 +127,10 @@ class Column:
         return "<Column {} ({})>".format(self.name, self.data_type)

     @classmethod
-    def from_description(cls, name: str, raw_data_type: str) -> 'Column':
-        match = re.match(r'([^(]+)(\([^)]+\))?', raw_data_type)
+    def from_description(cls, name: str, raw_data_type: str) -> "Column":
+        match = re.match(r"([^(]+)(\([^)]+\))?", raw_data_type)
         if match is None:
-            raise RuntimeException(
-                f'Could not interpret data type "{raw_data_type}"'
-            )
+            raise RuntimeException(f'Could not interpret data type "{raw_data_type}"')
         data_type, size_info = match.groups()
         char_size = None
         numeric_precision = None
@@ -123,7 +138,7 @@ class Column:
         if size_info is not None:
             # strip out the parentheses
             size_info = size_info[1:-1]
-            parts = size_info.split(',')
+            parts = size_info.split(",")
             if len(parts) == 1:
                 try:
                     char_size = int(parts[0])
@@ -148,6 +163,4 @@ class Column:
                     f'could not convert "{parts[1]}" to an integer'
                 )

-        return cls(
-            name, data_type, char_size, numeric_precision, numeric_scale
-        )
+        return cls(name, data_type, char_size, numeric_precision, numeric_scale)
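All of the changes above are mechanical reformatting (double quotes, one element per line, collapsed call sites); behavior is unchanged. For readers unfamiliar with this class, the following standalone sketch illustrates the dtype-translation idea that the `TYPE_LABELS` table implements; it is an illustration only, not the dbt implementation, with the label names taken from the diff above.

```python
# Minimal sketch of the TYPE_LABELS translation shown above (illustration only).
TYPE_LABELS = {
    "STRING": "TEXT",
    "TIMESTAMP": "TIMESTAMP",
    "FLOAT": "FLOAT",
    "INTEGER": "INT",
}


def translate_type(dtype: str) -> str:
    # Unknown labels pass through unchanged, mirroring TYPE_LABELS.get(..., dtype).
    return TYPE_LABELS.get(dtype.upper(), dtype)


print(translate_type("string"))   # TEXT
print(translate_type("integer"))  # INT
print(translate_type("boolean"))  # BOOLEAN (unmapped, passed through)
```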
@@ -1,18 +1,21 @@
 import abc
 import os

 # multiprocessing.RLock is a function returning this type
 from multiprocessing.synchronize import RLock
 from threading import get_ident
-from typing import (
-    Dict, Tuple, Hashable, Optional, ContextManager, List, Union
-)
+from typing import Dict, Tuple, Hashable, Optional, ContextManager, List, Union

 import agate

 import dbt.exceptions
 from dbt.contracts.connection import (
-    Connection, Identifier, ConnectionState,
-    AdapterRequiredConfig, LazyHandle, AdapterResponse
+    Connection,
+    Identifier,
+    ConnectionState,
+    AdapterRequiredConfig,
+    LazyHandle,
+    AdapterResponse,
 )
 from dbt.contracts.graph.manifest import Manifest
 from dbt.adapters.base.query_headers import (
@@ -35,6 +38,7 @@ class BaseConnectionManager(metaclass=abc.ABCMeta):
     You must also set the 'TYPE' class attribute with a class-unique constant
     string.
     """
+
     TYPE: str = NotImplemented

     def __init__(self, profile: AdapterRequiredConfig):
@@ -65,7 +69,7 @@ class BaseConnectionManager(metaclass=abc.ABCMeta):
         key = self.get_thread_identifier()
         if key in self.thread_connections:
             raise dbt.exceptions.InternalException(
-                'In set_thread_connection, existing connection exists for {}'
+                "In set_thread_connection, existing connection exists for {}"
             )
         self.thread_connections[key] = conn

@@ -105,18 +109,19 @@ class BaseConnectionManager(metaclass=abc.ABCMeta):
         underlying database.
         """
         raise dbt.exceptions.NotImplementedException(
-            '`exception_handler` is not implemented for this adapter!')
+            "`exception_handler` is not implemented for this adapter!"
+        )

     def set_connection_name(self, name: Optional[str] = None) -> Connection:
         conn_name: str
         if name is None:
             # if a name isn't specified, we'll re-use a single handle
             # named 'master'
-            conn_name = 'master'
+            conn_name = "master"
         else:
             if not isinstance(name, str):
                 raise dbt.exceptions.CompilerException(
-                    f'For connection name, got {name} - not a string!'
+                    f"For connection name, got {name} - not a string!"
                 )
             assert isinstance(name, str)
             conn_name = name
@@ -129,20 +134,20 @@ class BaseConnectionManager(metaclass=abc.ABCMeta):
             state=ConnectionState.INIT,
             transaction_open=False,
             handle=None,
-            credentials=self.profile.credentials
+            credentials=self.profile.credentials,
         )
         self.set_thread_connection(conn)

-        if conn.name == conn_name and conn.state == 'open':
+        if conn.name == conn_name and conn.state == "open":
             return conn

-        logger.debug(
-            'Acquiring new {} connection "{}".'.format(self.TYPE, conn_name))
+        logger.debug('Acquiring new {} connection "{}".'.format(self.TYPE, conn_name))

-        if conn.state == 'open':
+        if conn.state == "open":
             logger.debug(
-                'Re-using an available connection from the pool (formerly {}).'
-                .format(conn.name)
+                "Re-using an available connection from the pool (formerly {}).".format(
+                    conn.name
+                )
             )
         else:
             conn.handle = LazyHandle(self.open)
@@ -154,7 +159,7 @@ class BaseConnectionManager(metaclass=abc.ABCMeta):
     def cancel_open(self) -> Optional[List[str]]:
         """Cancel all open connections on the adapter. (passable)"""
         raise dbt.exceptions.NotImplementedException(
-            '`cancel_open` is not implemented for this adapter!'
+            "`cancel_open` is not implemented for this adapter!"
         )

     @abc.abstractclassmethod
@@ -168,7 +173,7 @@ class BaseConnectionManager(metaclass=abc.ABCMeta):
         connection should not be in either in_use or available.
         """
         raise dbt.exceptions.NotImplementedException(
-            '`open` is not implemented for this adapter!'
+            "`open` is not implemented for this adapter!"
         )

     def release(self) -> None:
@@ -189,12 +194,14 @@ class BaseConnectionManager(metaclass=abc.ABCMeta):
     def cleanup_all(self) -> None:
         with self.lock:
             for connection in self.thread_connections.values():
-                if connection.state not in {'closed', 'init'}:
-                    logger.debug("Connection '{}' was left open."
-                                 .format(connection.name))
+                if connection.state not in {"closed", "init"}:
+                    logger.debug(
+                        "Connection '{}' was left open.".format(connection.name)
+                    )
                 else:
-                    logger.debug("Connection '{}' was properly closed."
-                                 .format(connection.name))
+                    logger.debug(
+                        "Connection '{}' was properly closed.".format(connection.name)
+                    )
                 self.close(connection)

             # garbage collect these connections
@@ -204,14 +211,14 @@ class BaseConnectionManager(metaclass=abc.ABCMeta):
     def begin(self) -> None:
         """Begin a transaction. (passable)"""
         raise dbt.exceptions.NotImplementedException(
-            '`begin` is not implemented for this adapter!'
+            "`begin` is not implemented for this adapter!"
         )

     @abc.abstractmethod
     def commit(self) -> None:
         """Commit a transaction. (passable)"""
         raise dbt.exceptions.NotImplementedException(
-            '`commit` is not implemented for this adapter!'
+            "`commit` is not implemented for this adapter!"
         )

     @classmethod
@@ -220,20 +227,17 @@ class BaseConnectionManager(metaclass=abc.ABCMeta):
         try:
             connection.handle.rollback()
         except Exception:
-            logger.debug(
-                'Failed to rollback {}'.format(connection.name),
-                exc_info=True
-            )
+            logger.debug("Failed to rollback {}".format(connection.name), exc_info=True)

     @classmethod
     def _close_handle(cls, connection: Connection) -> None:
         """Perform the actual close operation."""
         # On windows, sometimes connection handles don't have a close() attr.
-        if hasattr(connection.handle, 'close'):
-            logger.debug(f'On {connection.name}: Close')
+        if hasattr(connection.handle, "close"):
+            logger.debug(f"On {connection.name}: Close")
             connection.handle.close()
         else:
-            logger.debug(f'On {connection.name}: No close available on handle')
+            logger.debug(f"On {connection.name}: No close available on handle")

     @classmethod
     def _rollback(cls, connection: Connection) -> None:
@@ -241,16 +245,16 @@ class BaseConnectionManager(metaclass=abc.ABCMeta):
         if flags.STRICT_MODE:
             if not isinstance(connection, Connection):
                 raise dbt.exceptions.CompilerException(
-                    f'In _rollback, got {connection} - not a Connection!'
+                    f"In _rollback, got {connection} - not a Connection!"
                 )

         if connection.transaction_open is False:
             raise dbt.exceptions.InternalException(
-                f'Tried to rollback transaction on connection '
+                f"Tried to rollback transaction on connection "
                 f'"{connection.name}", but it does not have one open!'
             )

-        logger.debug(f'On {connection.name}: ROLLBACK')
+        logger.debug(f"On {connection.name}: ROLLBACK")
         cls._rollback_handle(connection)

         connection.transaction_open = False
@@ -260,7 +264,7 @@ class BaseConnectionManager(metaclass=abc.ABCMeta):
         if flags.STRICT_MODE:
             if not isinstance(connection, Connection):
                 raise dbt.exceptions.CompilerException(
-                    f'In close, got {connection} - not a Connection!'
+                    f"In close, got {connection} - not a Connection!"
                 )

         # if the connection is in closed or init, there's nothing to do
@@ -268,7 +272,7 @@ class BaseConnectionManager(metaclass=abc.ABCMeta):
             return connection

         if connection.transaction_open and connection.handle:
-            logger.debug('On {}: ROLLBACK'.format(connection.name))
+            logger.debug("On {}: ROLLBACK".format(connection.name))
             cls._rollback_handle(connection)
             connection.transaction_open = False

@@ -302,5 +306,5 @@ class BaseConnectionManager(metaclass=abc.ABCMeta):
         :rtype: Tuple[Union[str, AdapterResponse], agate.Table]
         """
         raise dbt.exceptions.NotImplementedException(
-            '`execute` is not implemented for this adapter!'
+            "`execute` is not implemented for this adapter!"
         )
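The base class above raises `NotImplementedException` from every backend-specific hook (`open`, `begin`, `commit`, `execute`, ...) and concrete adapters override them. The standalone sketch below shows that general pattern only; the names are illustrative and this is not dbt's API.

```python
# Sketch of the "abstract base raises, concrete adapter overrides" pattern
# used by BaseConnectionManager above. Names are illustrative, not dbt's API.
import abc
import sqlite3


class BaseManager(metaclass=abc.ABCMeta):
    TYPE: str = NotImplemented

    def begin(self) -> None:
        raise NotImplementedError("`begin` is not implemented for this adapter!")

    def commit(self) -> None:
        raise NotImplementedError("`commit` is not implemented for this adapter!")


class SQLiteManager(BaseManager):
    TYPE = "sqlite"

    def __init__(self, conn: sqlite3.Connection):
        self.conn = conn

    def begin(self) -> None:
        self.conn.execute("BEGIN")

    def commit(self) -> None:
        self.conn.execute("COMMIT")


if __name__ == "__main__":
    # isolation_level=None puts sqlite3 in autocommit mode, so the explicit
    # BEGIN/COMMIT above control the transaction.
    mgr = SQLiteManager(sqlite3.connect(":memory:", isolation_level=None))
    mgr.begin()
    mgr.conn.execute("CREATE TABLE t (x int)")
    mgr.commit()
```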
@@ -4,17 +4,31 @@ from contextlib import contextmanager
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from itertools import chain
|
from itertools import chain
|
||||||
from typing import (
|
from typing import (
|
||||||
Optional, Tuple, Callable, Iterable, Type, Dict, Any, List, Mapping,
|
Optional,
|
||||||
Iterator, Union, Set
|
Tuple,
|
||||||
|
Callable,
|
||||||
|
Iterable,
|
||||||
|
Type,
|
||||||
|
Dict,
|
||||||
|
Any,
|
||||||
|
List,
|
||||||
|
Mapping,
|
||||||
|
Iterator,
|
||||||
|
Union,
|
||||||
|
Set,
|
||||||
)
|
)
|
||||||
|
|
||||||
import agate
|
import agate
|
||||||
import pytz
|
import pytz
|
||||||
|
|
||||||
from dbt.exceptions import (
|
from dbt.exceptions import (
|
||||||
raise_database_error, raise_compiler_error, invalid_type_error,
|
raise_database_error,
|
||||||
|
raise_compiler_error,
|
||||||
|
invalid_type_error,
|
||||||
get_relation_returned_multiple_results,
|
get_relation_returned_multiple_results,
|
||||||
InternalException, NotImplementedException, RuntimeException,
|
InternalException,
|
||||||
|
NotImplementedException,
|
||||||
|
RuntimeException,
|
||||||
)
|
)
|
||||||
from dbt import flags
|
from dbt import flags
|
||||||
|
|
||||||
@@ -25,9 +39,7 @@ from dbt.adapters.protocol import (
|
|||||||
)
|
)
|
||||||
from dbt.clients.agate_helper import empty_table, merge_tables, table_from_rows
|
from dbt.clients.agate_helper import empty_table, merge_tables, table_from_rows
|
||||||
from dbt.clients.jinja import MacroGenerator
|
from dbt.clients.jinja import MacroGenerator
|
||||||
from dbt.contracts.graph.compiled import (
|
from dbt.contracts.graph.compiled import CompileResultNode, CompiledSeedNode
|
||||||
CompileResultNode, CompiledSeedNode
|
|
||||||
)
|
|
||||||
from dbt.contracts.graph.manifest import Manifest, MacroManifest
|
from dbt.contracts.graph.manifest import Manifest, MacroManifest
|
||||||
from dbt.contracts.graph.parsed import ParsedSeedNode
|
from dbt.contracts.graph.parsed import ParsedSeedNode
|
||||||
from dbt.exceptions import warn_or_error
|
from dbt.exceptions import warn_or_error
|
||||||
@@ -38,7 +50,10 @@ from dbt.utils import filter_null_values, executor
|
|||||||
from dbt.adapters.base.connections import Connection, AdapterResponse
|
from dbt.adapters.base.connections import Connection, AdapterResponse
|
||||||
from dbt.adapters.base.meta import AdapterMeta, available
|
from dbt.adapters.base.meta import AdapterMeta, available
|
||||||
from dbt.adapters.base.relation import (
|
from dbt.adapters.base.relation import (
|
||||||
ComponentName, BaseRelation, InformationSchema, SchemaSearchMap
|
ComponentName,
|
||||||
|
BaseRelation,
|
||||||
|
InformationSchema,
|
||||||
|
SchemaSearchMap,
|
||||||
)
|
)
|
||||||
from dbt.adapters.base import Column as BaseColumn
|
from dbt.adapters.base import Column as BaseColumn
|
||||||
from dbt.adapters.cache import RelationsCache
|
from dbt.adapters.cache import RelationsCache
|
||||||
@@ -47,15 +62,14 @@ from dbt.adapters.cache import RelationsCache
|
|||||||
SeedModel = Union[ParsedSeedNode, CompiledSeedNode]
|
SeedModel = Union[ParsedSeedNode, CompiledSeedNode]
|
||||||
|
|
||||||
|
|
||||||
GET_CATALOG_MACRO_NAME = 'get_catalog'
|
GET_CATALOG_MACRO_NAME = "get_catalog"
|
||||||
FRESHNESS_MACRO_NAME = 'collect_freshness'
|
FRESHNESS_MACRO_NAME = "collect_freshness"
|
||||||
|
|
||||||
|
|
||||||
def _expect_row_value(key: str, row: agate.Row):
|
def _expect_row_value(key: str, row: agate.Row):
|
||||||
if key not in row.keys():
|
if key not in row.keys():
|
||||||
raise InternalException(
|
raise InternalException(
|
||||||
'Got a row without "{}" column, columns: {}'
|
'Got a row without "{}" column, columns: {}'.format(key, row.keys())
|
||||||
.format(key, row.keys())
|
|
||||||
)
|
)
|
||||||
return row[key]
|
return row[key]
|
||||||
|
|
||||||
@@ -64,40 +78,37 @@ def _catalog_filter_schemas(manifest: Manifest) -> Callable[[agate.Row], bool]:
|
|||||||
"""Return a function that takes a row and decides if the row should be
|
"""Return a function that takes a row and decides if the row should be
|
||||||
included in the catalog output.
|
included in the catalog output.
|
||||||
"""
|
"""
|
||||||
schemas = frozenset((d.lower(), s.lower())
|
schemas = frozenset((d.lower(), s.lower()) for d, s in manifest.get_used_schemas())
|
||||||
for d, s in manifest.get_used_schemas())
|
|
||||||
|
|
||||||
def test(row: agate.Row) -> bool:
|
def test(row: agate.Row) -> bool:
|
||||||
table_database = _expect_row_value('table_database', row)
|
table_database = _expect_row_value("table_database", row)
|
||||||
table_schema = _expect_row_value('table_schema', row)
|
table_schema = _expect_row_value("table_schema", row)
|
||||||
# the schema may be present but None, which is not an error and should
|
# the schema may be present but None, which is not an error and should
|
||||||
# be filtered out
|
# be filtered out
|
||||||
if table_schema is None:
|
if table_schema is None:
|
||||||
return False
|
return False
|
||||||
return (table_database.lower(), table_schema.lower()) in schemas
|
return (table_database.lower(), table_schema.lower()) in schemas
|
||||||
|
|
||||||
return test
|
return test
|
||||||
|
|
||||||
|
|
||||||
def _utc(
|
def _utc(dt: Optional[datetime], source: BaseRelation, field_name: str) -> datetime:
|
||||||
dt: Optional[datetime], source: BaseRelation, field_name: str
|
|
||||||
) -> datetime:
|
|
||||||
"""If dt has a timezone, return a new datetime that's in UTC. Otherwise,
|
"""If dt has a timezone, return a new datetime that's in UTC. Otherwise,
|
||||||
assume the datetime is already for UTC and add the timezone.
|
assume the datetime is already for UTC and add the timezone.
|
||||||
"""
|
"""
|
||||||
if dt is None:
|
if dt is None:
|
||||||
raise raise_database_error(
|
raise raise_database_error(
|
||||||
"Expected a non-null value when querying field '{}' of table "
|
"Expected a non-null value when querying field '{}' of table "
|
||||||
" {} but received value 'null' instead".format(
|
" {} but received value 'null' instead".format(field_name, source)
|
||||||
field_name,
|
)
|
||||||
source))
|
|
||||||
|
|
||||||
elif not hasattr(dt, 'tzinfo'):
|
elif not hasattr(dt, "tzinfo"):
|
||||||
raise raise_database_error(
|
raise raise_database_error(
|
||||||
"Expected a timestamp value when querying field '{}' of table "
|
"Expected a timestamp value when querying field '{}' of table "
|
||||||
"{} but received value of type '{}' instead".format(
|
"{} but received value of type '{}' instead".format(
|
||||||
field_name,
|
field_name, source, type(dt).__name__
|
||||||
source,
|
)
|
||||||
type(dt).__name__))
|
)
|
||||||
|
|
||||||
elif dt.tzinfo:
|
elif dt.tzinfo:
|
||||||
return dt.astimezone(pytz.UTC)
|
return dt.astimezone(pytz.UTC)
|
||||||
@@ -107,7 +118,7 @@ def _utc(
|
|||||||
|
|
||||||
def _relation_name(rel: Optional[BaseRelation]) -> str:
|
def _relation_name(rel: Optional[BaseRelation]) -> str:
|
||||||
if rel is None:
|
if rel is None:
|
||||||
return 'null relation'
|
return "null relation"
|
||||||
else:
|
else:
|
||||||
return str(rel)
|
return str(rel)
|
||||||
|
|
||||||
@@ -148,6 +159,7 @@ class BaseAdapter(metaclass=AdapterMeta):
|
|||||||
Macros:
|
Macros:
|
||||||
- get_catalog
|
- get_catalog
|
||||||
"""
|
"""
|
||||||
|
|
||||||
Relation: Type[BaseRelation] = BaseRelation
|
Relation: Type[BaseRelation] = BaseRelation
|
||||||
Column: Type[BaseColumn] = BaseColumn
|
Column: Type[BaseColumn] = BaseColumn
|
||||||
ConnectionManager: Type[ConnectionManagerProtocol]
|
ConnectionManager: Type[ConnectionManagerProtocol]
|
||||||
@@ -181,12 +193,12 @@ class BaseAdapter(metaclass=AdapterMeta):
|
|||||||
self.connections.commit_if_has_connection()
|
self.connections.commit_if_has_connection()
|
||||||
|
|
||||||
def debug_query(self) -> None:
|
def debug_query(self) -> None:
|
||||||
self.execute('select 1 as id')
|
self.execute("select 1 as id")
|
||||||
|
|
||||||
def nice_connection_name(self) -> str:
|
def nice_connection_name(self) -> str:
|
||||||
conn = self.connections.get_if_exists()
|
conn = self.connections.get_if_exists()
|
||||||
if conn is None or conn.name is None:
|
if conn is None or conn.name is None:
|
||||||
return '<None>'
|
return "<None>"
|
||||||
return conn.name
|
return conn.name
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
@@ -204,13 +216,11 @@ class BaseAdapter(metaclass=AdapterMeta):
|
|||||||
self.connections.query_header.reset()
|
self.connections.query_header.reset()
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def connection_for(
|
def connection_for(self, node: CompileResultNode) -> Iterator[None]:
|
||||||
self, node: CompileResultNode
|
|
||||||
) -> Iterator[None]:
|
|
||||||
with self.connection_named(node.unique_id, node):
|
with self.connection_named(node.unique_id, node):
|
||||||
yield
|
yield
|
||||||
|
|
||||||
@available.parse(lambda *a, **k: ('', empty_table()))
|
@available.parse(lambda *a, **k: ("", empty_table()))
|
||||||
def execute(
|
def execute(
|
||||||
self, sql: str, auto_begin: bool = False, fetch: bool = False
|
self, sql: str, auto_begin: bool = False, fetch: bool = False
|
||||||
) -> Tuple[Union[str, AdapterResponse], agate.Table]:
|
) -> Tuple[Union[str, AdapterResponse], agate.Table]:
|
||||||
@@ -224,16 +234,10 @@ class BaseAdapter(metaclass=AdapterMeta):
|
|||||||
:return: A tuple of the status and the results (empty if fetch=False).
|
:return: A tuple of the status and the results (empty if fetch=False).
|
||||||
:rtype: Tuple[Union[str, AdapterResponse], agate.Table]
|
:rtype: Tuple[Union[str, AdapterResponse], agate.Table]
|
||||||
"""
|
"""
|
||||||
return self.connections.execute(
|
return self.connections.execute(sql=sql, auto_begin=auto_begin, fetch=fetch)
|
||||||
sql=sql,
|
|
||||||
auto_begin=auto_begin,
|
|
||||||
fetch=fetch
|
|
||||||
)
|
|
||||||
|
|
||||||
@available.parse(lambda *a, **k: ('', empty_table()))
|
@available.parse(lambda *a, **k: ("", empty_table()))
|
||||||
def get_partitions_metadata(
|
def get_partitions_metadata(self, table: str) -> Tuple[agate.Table]:
|
||||||
self, table: str
|
|
||||||
) -> Tuple[agate.Table]:
|
|
||||||
"""Obtain partitions metadata for a BigQuery partitioned table.
|
"""Obtain partitions metadata for a BigQuery partitioned table.
|
||||||
|
|
||||||
:param str table_id: a partitioned table id, in standard SQL format.
|
:param str table_id: a partitioned table id, in standard SQL format.
|
||||||
@@ -241,9 +245,7 @@ class BaseAdapter(metaclass=AdapterMeta):
|
|||||||
https://cloud.google.com/bigquery/docs/creating-partitioned-tables#getting_partition_metadata_using_meta_tables.
|
https://cloud.google.com/bigquery/docs/creating-partitioned-tables#getting_partition_metadata_using_meta_tables.
|
||||||
:rtype: agate.Table
|
:rtype: agate.Table
|
||||||
"""
|
"""
|
||||||
return self.connections.get_partitions_metadata(
|
return self.connections.get_partitions_metadata(table=table)
|
||||||
table=table
|
|
||||||
)
|
|
||||||
|
|
||||||
###
|
###
|
||||||
# Methods that should never be overridden
|
# Methods that should never be overridden
|
||||||
@@ -274,6 +276,7 @@ class BaseAdapter(metaclass=AdapterMeta):
|
|||||||
if self._macro_manifest_lazy is None:
|
if self._macro_manifest_lazy is None:
|
||||||
# avoid a circular import
|
# avoid a circular import
|
||||||
from dbt.parser.manifest import load_macro_manifest
|
from dbt.parser.manifest import load_macro_manifest
|
||||||
|
|
||||||
manifest = load_macro_manifest(
|
manifest = load_macro_manifest(
|
||||||
self.config, self.connections.set_query_header
|
self.config, self.connections.set_query_header
|
||||||
)
|
)
|
||||||
@@ -294,8 +297,9 @@ class BaseAdapter(metaclass=AdapterMeta):
|
|||||||
return False
|
return False
|
||||||
elif (database, schema) not in self.cache:
|
elif (database, schema) not in self.cache:
|
||||||
logger.debug(
|
logger.debug(
|
||||||
'On "{}": cache miss for schema "{}.{}", this is inefficient'
|
'On "{}": cache miss for schema "{}.{}", this is inefficient'.format(
|
||||||
.format(self.nice_connection_name(), database, schema)
|
self.nice_connection_name(), database, schema
|
||||||
|
)
|
||||||
)
|
)
|
||||||
return False
|
return False
|
||||||
else:
|
else:
|
||||||
@@ -310,8 +314,8 @@ class BaseAdapter(metaclass=AdapterMeta):
|
|||||||
self.Relation.create_from(self.config, node).without_identifier()
|
self.Relation.create_from(self.config, node).without_identifier()
|
||||||
for node in manifest.nodes.values()
|
for node in manifest.nodes.values()
|
||||||
if (
|
if (
|
||||||
node.resource_type in NodeType.executable() and
|
node.resource_type in NodeType.executable()
|
||||||
not node.is_ephemeral_model
|
and not node.is_ephemeral_model
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -351,9 +355,9 @@ class BaseAdapter(metaclass=AdapterMeta):
|
|||||||
for cache_schema in cache_schemas:
|
for cache_schema in cache_schemas:
|
||||||
fut = tpe.submit_connected(
|
fut = tpe.submit_connected(
|
||||||
self,
|
self,
|
||||||
f'list_{cache_schema.database}_{cache_schema.schema}',
|
f"list_{cache_schema.database}_{cache_schema.schema}",
|
||||||
self.list_relations_without_caching,
|
self.list_relations_without_caching,
|
||||||
cache_schema
|
cache_schema,
|
||||||
)
|
)
|
||||||
futures.append(fut)
|
futures.append(fut)
|
||||||
|
|
||||||
@@ -371,9 +375,7 @@ class BaseAdapter(metaclass=AdapterMeta):
|
|||||||
cache_update.add((relation.database, relation.schema))
|
cache_update.add((relation.database, relation.schema))
|
||||||
self.cache.update_schemas(cache_update)
|
self.cache.update_schemas(cache_update)
|
||||||
|
|
||||||
def set_relations_cache(
|
def set_relations_cache(self, manifest: Manifest, clear: bool = False) -> None:
|
||||||
self, manifest: Manifest, clear: bool = False
|
|
||||||
) -> None:
|
|
||||||
"""Run a query that gets a populated cache of the relations in the
|
"""Run a query that gets a populated cache of the relations in the
|
||||||
database and set the cache on this adapter.
|
database and set the cache on this adapter.
|
||||||
"""
|
"""
|
||||||
@@ -391,12 +393,12 @@ class BaseAdapter(metaclass=AdapterMeta):
|
|||||||
if relation is None:
|
if relation is None:
|
||||||
name = self.nice_connection_name()
|
name = self.nice_connection_name()
|
||||||
raise_compiler_error(
|
raise_compiler_error(
|
||||||
'Attempted to cache a null relation for {}'.format(name)
|
"Attempted to cache a null relation for {}".format(name)
|
||||||
)
|
)
|
||||||
if flags.USE_CACHE:
|
if flags.USE_CACHE:
|
||||||
self.cache.add(relation)
|
self.cache.add(relation)
|
||||||
# so jinja doesn't render things
|
# so jinja doesn't render things
|
||||||
return ''
|
return ""
|
||||||
|
|
||||||
@available
|
@available
|
||||||
def cache_dropped(self, relation: Optional[BaseRelation]) -> str:
|
def cache_dropped(self, relation: Optional[BaseRelation]) -> str:
|
||||||
@@ -406,11 +408,11 @@ class BaseAdapter(metaclass=AdapterMeta):
|
|||||||
if relation is None:
|
if relation is None:
|
||||||
name = self.nice_connection_name()
|
name = self.nice_connection_name()
|
||||||
raise_compiler_error(
|
raise_compiler_error(
|
||||||
'Attempted to drop a null relation for {}'.format(name)
|
"Attempted to drop a null relation for {}".format(name)
|
||||||
)
|
)
|
||||||
if flags.USE_CACHE:
|
if flags.USE_CACHE:
|
||||||
self.cache.drop(relation)
|
self.cache.drop(relation)
|
||||||
return ''
|
return ""
|
||||||
|
|
||||||
@available
|
@available
|
||||||
def cache_renamed(
|
def cache_renamed(
|
||||||
@@ -426,13 +428,12 @@ class BaseAdapter(metaclass=AdapterMeta):
|
|||||||
src_name = _relation_name(from_relation)
|
src_name = _relation_name(from_relation)
|
||||||
dst_name = _relation_name(to_relation)
|
dst_name = _relation_name(to_relation)
|
||||||
raise_compiler_error(
|
raise_compiler_error(
|
||||||
'Attempted to rename {} to {} for {}'
|
"Attempted to rename {} to {} for {}".format(src_name, dst_name, name)
|
||||||
.format(src_name, dst_name, name)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if flags.USE_CACHE:
|
if flags.USE_CACHE:
|
||||||
self.cache.rename(from_relation, to_relation)
|
self.cache.rename(from_relation, to_relation)
|
||||||
return ''
|
return ""
|
||||||
|
|
||||||
###
|
###
|
||||||
# Abstract methods for database-specific values, attributes, and types
|
# Abstract methods for database-specific values, attributes, and types
|
||||||
@@ -441,12 +442,13 @@ class BaseAdapter(metaclass=AdapterMeta):
|
|||||||
def date_function(cls) -> str:
|
def date_function(cls) -> str:
|
||||||
"""Get the date function used by this adapter's database."""
|
"""Get the date function used by this adapter's database."""
|
||||||
raise NotImplementedException(
|
raise NotImplementedException(
|
||||||
'`date_function` is not implemented for this adapter!')
|
"`date_function` is not implemented for this adapter!"
|
||||||
|
)
|
||||||
|
|
||||||
@abc.abstractclassmethod
|
@abc.abstractclassmethod
|
||||||
def is_cancelable(cls) -> bool:
|
def is_cancelable(cls) -> bool:
|
||||||
raise NotImplementedException(
|
raise NotImplementedException(
|
||||||
'`is_cancelable` is not implemented for this adapter!'
|
"`is_cancelable` is not implemented for this adapter!"
|
||||||
)
|
)
|
||||||
|
|
||||||
###
|
###
|
||||||
@@ -456,7 +458,7 @@ class BaseAdapter(metaclass=AdapterMeta):
|
|||||||
def list_schemas(self, database: str) -> List[str]:
|
def list_schemas(self, database: str) -> List[str]:
|
||||||
"""Get a list of existing schemas in database"""
|
"""Get a list of existing schemas in database"""
|
||||||
raise NotImplementedException(
|
raise NotImplementedException(
|
||||||
'`list_schemas` is not implemented for this adapter!'
|
"`list_schemas` is not implemented for this adapter!"
|
||||||
)
|
)
|
||||||
|
|
||||||
@available.parse(lambda *a, **k: False)
|
@available.parse(lambda *a, **k: False)
|
||||||
@@ -467,10 +469,7 @@ class BaseAdapter(metaclass=AdapterMeta):
|
|||||||
and adapters should implement it if there is an optimized path (and
|
and adapters should implement it if there is an optimized path (and
|
||||||
there probably is)
|
there probably is)
|
||||||
"""
|
"""
|
||||||
search = (
|
search = (s.lower() for s in self.list_schemas(database=database))
|
||||||
s.lower() for s in
|
|
||||||
self.list_schemas(database=database)
|
|
||||||
)
|
|
||||||
return schema.lower() in search
|
return schema.lower() in search
|
||||||
|
|
||||||
###
|
###
|
||||||
@@ -484,7 +483,7 @@ class BaseAdapter(metaclass=AdapterMeta):
|
|||||||
*Implementors must call self.cache.drop() to preserve cache state!*
|
*Implementors must call self.cache.drop() to preserve cache state!*
|
||||||
"""
|
"""
|
||||||
raise NotImplementedException(
|
raise NotImplementedException(
|
||||||
'`drop_relation` is not implemented for this adapter!'
|
"`drop_relation` is not implemented for this adapter!"
|
||||||
)
|
)
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
@@ -492,7 +491,7 @@ class BaseAdapter(metaclass=AdapterMeta):
|
|||||||
def truncate_relation(self, relation: BaseRelation) -> None:
|
def truncate_relation(self, relation: BaseRelation) -> None:
|
||||||
"""Truncate the given relation."""
|
"""Truncate the given relation."""
|
||||||
raise NotImplementedException(
|
raise NotImplementedException(
|
||||||
'`truncate_relation` is not implemented for this adapter!'
|
"`truncate_relation` is not implemented for this adapter!"
|
||||||
)
|
)
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
@@ -505,36 +504,30 @@ class BaseAdapter(metaclass=AdapterMeta):
|
|||||||
Implementors must call self.cache.rename() to preserve cache state.
|
Implementors must call self.cache.rename() to preserve cache state.
|
||||||
"""
|
"""
|
||||||
raise NotImplementedException(
|
raise NotImplementedException(
|
||||||
'`rename_relation` is not implemented for this adapter!'
|
"`rename_relation` is not implemented for this adapter!"
|
||||||
)
|
)
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
@available.parse_list
|
@available.parse_list
|
||||||
def get_columns_in_relation(
|
def get_columns_in_relation(self, relation: BaseRelation) -> List[BaseColumn]:
|
||||||
self, relation: BaseRelation
|
|
||||||
) -> List[BaseColumn]:
|
|
||||||
"""Get a list of the columns in the given Relation."""
|
"""Get a list of the columns in the given Relation."""
|
||||||
raise NotImplementedException(
|
raise NotImplementedException(
|
||||||
'`get_columns_in_relation` is not implemented for this adapter!'
|
"`get_columns_in_relation` is not implemented for this adapter!"
|
||||||
)
|
)
|
||||||
|
|
||||||
@available.deprecated('get_columns_in_relation', lambda *a, **k: [])
|
@available.deprecated("get_columns_in_relation", lambda *a, **k: [])
|
||||||
def get_columns_in_table(
|
def get_columns_in_table(self, schema: str, identifier: str) -> List[BaseColumn]:
|
||||||
self, schema: str, identifier: str
|
|
||||||
) -> List[BaseColumn]:
|
|
||||||
"""DEPRECATED: Get a list of the columns in the given table."""
|
"""DEPRECATED: Get a list of the columns in the given table."""
|
||||||
relation = self.Relation.create(
|
relation = self.Relation.create(
|
||||||
database=self.config.credentials.database,
|
database=self.config.credentials.database,
|
||||||
schema=schema,
|
schema=schema,
|
||||||
identifier=identifier,
|
identifier=identifier,
|
||||||
quote_policy=self.config.quoting
|
quote_policy=self.config.quoting,
|
||||||
)
|
)
|
||||||
return self.get_columns_in_relation(relation)
|
return self.get_columns_in_relation(relation)
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def expand_column_types(
|
def expand_column_types(self, goal: BaseRelation, current: BaseRelation) -> None:
|
||||||
self, goal: BaseRelation, current: BaseRelation
|
|
||||||
) -> None:
|
|
||||||
"""Expand the current table's types to match the goal table. (passable)
|
"""Expand the current table's types to match the goal table. (passable)
|
||||||
|
|
||||||
:param self.Relation goal: A relation that currently exists in the
|
:param self.Relation goal: A relation that currently exists in the
|
||||||
@@ -543,7 +536,7 @@ class BaseAdapter(metaclass=AdapterMeta):
|
|||||||
database with columns of unspecified types.
|
database with columns of unspecified types.
|
||||||
"""
|
"""
|
||||||
raise NotImplementedException(
|
raise NotImplementedException(
|
||||||
'`expand_target_column_types` is not implemented for this adapter!'
|
"`expand_target_column_types` is not implemented for this adapter!"
|
||||||
)
|
)
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
@@ -560,8 +553,7 @@ class BaseAdapter(metaclass=AdapterMeta):
|
|||||||
:rtype: List[self.Relation]
|
:rtype: List[self.Relation]
|
||||||
"""
|
"""
|
||||||
raise NotImplementedException(
|
raise NotImplementedException(
|
||||||
'`list_relations_without_caching` is not implemented for this '
|
"`list_relations_without_caching` is not implemented for this " "adapter!"
|
||||||
'adapter!'
|
|
||||||
)
|
)
|
||||||
|
|
||||||
###
|
###
|
||||||
@@ -576,32 +568,33 @@ class BaseAdapter(metaclass=AdapterMeta):
|
|||||||
"""
|
"""
|
||||||
if not isinstance(from_relation, self.Relation):
|
if not isinstance(from_relation, self.Relation):
|
||||||
invalid_type_error(
|
invalid_type_error(
|
||||||
method_name='get_missing_columns',
|
method_name="get_missing_columns",
|
||||||
arg_name='from_relation',
|
arg_name="from_relation",
|
||||||
got_value=from_relation,
|
got_value=from_relation,
|
||||||
expected_type=self.Relation)
|
expected_type=self.Relation,
|
||||||
|
)
|
||||||
|
|
||||||
if not isinstance(to_relation, self.Relation):
|
if not isinstance(to_relation, self.Relation):
|
||||||
invalid_type_error(
|
invalid_type_error(
|
||||||
method_name='get_missing_columns',
|
method_name="get_missing_columns",
|
||||||
arg_name='to_relation',
|
arg_name="to_relation",
|
||||||
got_value=to_relation,
|
got_value=to_relation,
|
||||||
expected_type=self.Relation)
|
expected_type=self.Relation,
|
||||||
|
)
|
||||||
|
|
||||||
from_columns = {
|
from_columns = {
|
||||||
col.name: col for col in
|
col.name: col for col in self.get_columns_in_relation(from_relation)
|
||||||
self.get_columns_in_relation(from_relation)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
to_columns = {
|
to_columns = {
|
||||||
col.name: col for col in
|
col.name: col for col in self.get_columns_in_relation(to_relation)
|
||||||
self.get_columns_in_relation(to_relation)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
missing_columns = set(from_columns.keys()) - set(to_columns.keys())
|
missing_columns = set(from_columns.keys()) - set(to_columns.keys())
|
||||||
|
|
||||||
return [
|
return [
|
||||||
col for (col_name, col) in from_columns.items()
|
col
|
||||||
|
for (col_name, col) in from_columns.items()
|
||||||
if col_name in missing_columns
|
if col_name in missing_columns
|
||||||
]
|
]
|
||||||
|
|
||||||
@@ -616,18 +609,19 @@ class BaseAdapter(metaclass=AdapterMeta):
|
|||||||
"""
|
"""
|
||||||
if not isinstance(relation, self.Relation):
|
if not isinstance(relation, self.Relation):
|
||||||
invalid_type_error(
|
invalid_type_error(
|
||||||
method_name='valid_snapshot_target',
|
method_name="valid_snapshot_target",
|
||||||
arg_name='relation',
|
arg_name="relation",
|
||||||
got_value=relation,
|
got_value=relation,
|
||||||
expected_type=self.Relation)
|
expected_type=self.Relation,
|
||||||
|
)
|
||||||
|
|
||||||
columns = self.get_columns_in_relation(relation)
|
columns = self.get_columns_in_relation(relation)
|
||||||
names = set(c.name.lower() for c in columns)
|
names = set(c.name.lower() for c in columns)
|
||||||
expanded_keys = ('scd_id', 'valid_from', 'valid_to')
|
expanded_keys = ("scd_id", "valid_from", "valid_to")
|
||||||
extra = []
|
extra = []
|
||||||
missing = []
|
missing = []
|
||||||
for legacy in expanded_keys:
|
for legacy in expanded_keys:
|
||||||
desired = 'dbt_' + legacy
|
desired = "dbt_" + legacy
|
||||||
if desired not in names:
|
if desired not in names:
|
||||||
missing.append(desired)
|
missing.append(desired)
|
||||||
if legacy in names:
|
if legacy in names:
|
||||||
@@ -637,13 +631,13 @@ class BaseAdapter(metaclass=AdapterMeta):
|
|||||||
if extra:
|
if extra:
|
||||||
msg = (
|
msg = (
|
||||||
'Snapshot target has ("{}") but not ("{}") - is it an '
|
'Snapshot target has ("{}") but not ("{}") - is it an '
|
||||||
'unmigrated previous version archive?'
|
"unmigrated previous version archive?".format(
|
||||||
.format('", "'.join(extra), '", "'.join(missing))
|
'", "'.join(extra), '", "'.join(missing)
|
||||||
|
)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
msg = (
|
msg = 'Snapshot target is not a snapshot table (missing "{}")'.format(
|
||||||
'Snapshot target is not a snapshot table (missing "{}")'
|
'", "'.join(missing)
|
||||||
.format('", "'.join(missing))
|
|
||||||
)
|
)
|
||||||
raise_compiler_error(msg)
|
raise_compiler_error(msg)
|
||||||
|
|
||||||
@@ -653,17 +647,19 @@ class BaseAdapter(metaclass=AdapterMeta):
|
|||||||
) -> None:
|
) -> None:
|
||||||
if not isinstance(from_relation, self.Relation):
|
if not isinstance(from_relation, self.Relation):
|
||||||
invalid_type_error(
|
invalid_type_error(
|
||||||
method_name='expand_target_column_types',
|
method_name="expand_target_column_types",
|
||||||
arg_name='from_relation',
|
arg_name="from_relation",
|
||||||
got_value=from_relation,
|
got_value=from_relation,
|
||||||
expected_type=self.Relation)
|
expected_type=self.Relation,
|
||||||
|
)
|
||||||
|
|
||||||
if not isinstance(to_relation, self.Relation):
|
if not isinstance(to_relation, self.Relation):
|
||||||
invalid_type_error(
|
invalid_type_error(
|
||||||
method_name='expand_target_column_types',
|
method_name="expand_target_column_types",
|
||||||
arg_name='to_relation',
|
arg_name="to_relation",
|
||||||
got_value=to_relation,
|
got_value=to_relation,
|
||||||
expected_type=self.Relation)
|
expected_type=self.Relation,
|
||||||
|
)
|
||||||
|
|
||||||
self.expand_column_types(from_relation, to_relation)
|
self.expand_column_types(from_relation, to_relation)
|
||||||
|
|
||||||
@@ -676,38 +672,41 @@ class BaseAdapter(metaclass=AdapterMeta):
|
|||||||
schema_relation = self.Relation.create(
|
schema_relation = self.Relation.create(
|
||||||
database=database,
|
database=database,
|
||||||
schema=schema,
|
schema=schema,
|
||||||
identifier='',
|
identifier="",
|
||||||
quote_policy=self.config.quoting
|
quote_policy=self.config.quoting,
|
||||||
).without_identifier()
|
).without_identifier()
|
||||||
|
|
||||||
# we can't build the relations cache because we don't have a
|
# we can't build the relations cache because we don't have a
|
||||||
# manifest so we can't run any operations.
|
# manifest so we can't run any operations.
|
||||||
relations = self.list_relations_without_caching(
|
relations = self.list_relations_without_caching(schema_relation)
|
||||||
schema_relation
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.debug('with database={}, schema={}, relations={}'
|
logger.debug(
|
||||||
.format(database, schema, relations))
|
"with database={}, schema={}, relations={}".format(
|
||||||
|
database, schema, relations
|
||||||
|
)
|
||||||
|
)
|
||||||
return relations
|
return relations
|
||||||
|
|
||||||
def _make_match_kwargs(
|
def _make_match_kwargs(
|
||||||
self, database: str, schema: str, identifier: str
|
self, database: str, schema: str, identifier: str
|
||||||
) -> Dict[str, str]:
|
) -> Dict[str, str]:
|
||||||
quoting = self.config.quoting
|
quoting = self.config.quoting
|
||||||
if identifier is not None and quoting['identifier'] is False:
|
if identifier is not None and quoting["identifier"] is False:
|
||||||
identifier = identifier.lower()
|
identifier = identifier.lower()
|
||||||
|
|
||||||
if schema is not None and quoting['schema'] is False:
|
if schema is not None and quoting["schema"] is False:
|
||||||
schema = schema.lower()
|
schema = schema.lower()
|
||||||
|
|
||||||
if database is not None and quoting['database'] is False:
|
if database is not None and quoting["database"] is False:
|
||||||
database = database.lower()
|
database = database.lower()
|
||||||
|
|
||||||
return filter_null_values({
|
return filter_null_values(
|
||||||
'database': database,
|
{
|
||||||
'identifier': identifier,
|
"database": database,
|
||||||
'schema': schema,
|
"identifier": identifier,
|
||||||
})
|
"schema": schema,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
def _make_match(
|
def _make_match(
|
||||||
self,
|
self,
|
||||||
@@ -733,25 +732,22 @@ class BaseAdapter(metaclass=AdapterMeta):
|
|||||||
) -> Optional[BaseRelation]:
|
) -> Optional[BaseRelation]:
|
||||||
relations_list = self.list_relations(database, schema)
|
relations_list = self.list_relations(database, schema)
|
||||||
|
|
||||||
matches = self._make_match(relations_list, database, schema,
|
matches = self._make_match(relations_list, database, schema, identifier)
|
||||||
identifier)
|
|
||||||
|
|
||||||
if len(matches) > 1:
|
if len(matches) > 1:
|
||||||
kwargs = {
|
kwargs = {
|
||||||
'identifier': identifier,
|
"identifier": identifier,
|
||||||
'schema': schema,
|
"schema": schema,
|
||||||
'database': database,
|
"database": database,
|
||||||
}
|
}
|
||||||
get_relation_returned_multiple_results(
|
get_relation_returned_multiple_results(kwargs, matches)
|
||||||
kwargs, matches
|
|
||||||
)
|
|
||||||
|
|
||||||
elif matches:
|
elif matches:
|
||||||
return matches[0]
|
return matches[0]
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@available.deprecated('get_relation', lambda *a, **k: False)
|
@available.deprecated("get_relation", lambda *a, **k: False)
|
||||||
def already_exists(self, schema: str, name: str) -> bool:
|
def already_exists(self, schema: str, name: str) -> bool:
|
||||||
"""DEPRECATED: Return if a model already exists in the database"""
|
"""DEPRECATED: Return if a model already exists in the database"""
|
||||||
database = self.config.credentials.database
|
database = self.config.credentials.database
|
||||||
@@ -767,7 +763,7 @@ class BaseAdapter(metaclass=AdapterMeta):
|
|||||||
def create_schema(self, relation: BaseRelation):
|
def create_schema(self, relation: BaseRelation):
|
||||||
"""Create the given schema if it does not exist."""
|
"""Create the given schema if it does not exist."""
|
||||||
raise NotImplementedException(
|
raise NotImplementedException(
|
||||||
'`create_schema` is not implemented for this adapter!'
|
"`create_schema` is not implemented for this adapter!"
|
||||||
)
|
)
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
@@ -775,16 +771,14 @@ class BaseAdapter(metaclass=AdapterMeta):
     def drop_schema(self, relation: BaseRelation):
         """Drop the given schema (and everything in it) if it exists."""
         raise NotImplementedException(
-            '`drop_schema` is not implemented for this adapter!'
+            "`drop_schema` is not implemented for this adapter!"
        )
 
     @available
     @abc.abstractclassmethod
     def quote(cls, identifier: str) -> str:
         """Quote the given identifier, as appropriate for the database."""
-        raise NotImplementedException(
-            '`quote` is not implemented for this adapter!'
-        )
+        raise NotImplementedException("`quote` is not implemented for this adapter!")
 
     @available
     def quote_as_configured(self, identifier: str, quote_key: str) -> str:
@@ -806,19 +800,17 @@ class BaseAdapter(metaclass=AdapterMeta):
         return identifier
 
     @available
-    def quote_seed_column(
-        self, column: str, quote_config: Optional[bool]
-    ) -> str:
+    def quote_seed_column(self, column: str, quote_config: Optional[bool]) -> str:
         # this is the default for now
         quote_columns: bool = False
         if isinstance(quote_config, bool):
             quote_columns = quote_config
         elif quote_config is None:
-            deprecations.warn('column-quoting-unset')
+            deprecations.warn("column-quoting-unset")
         else:
             raise_compiler_error(
                 f'The seed configuration value of "quote_columns" has an '
-                f'invalid type {type(quote_config)}'
+                f"invalid type {type(quote_config)}"
             )
 
         if quote_columns:
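
The reformatted `quote_seed_column` above keeps its original behaviour: seed columns are only quoted when the project opts in. The following is a minimal standalone sketch of that decision logic, not dbt's own module; `print` and `TypeError` stand in for dbt's deprecation warning and compiler error.

```
from typing import Optional


def quote_seed_column(column: str, quote_config: Optional[bool]) -> str:
    # default for now: seed columns are not quoted unless configured
    quote_columns = False
    if isinstance(quote_config, bool):
        quote_columns = quote_config
    elif quote_config is None:
        print("warning: column-quoting-unset")  # stand-in for deprecations.warn
    else:
        raise TypeError(f'invalid "quote_columns" type: {type(quote_config)}')
    return f'"{column}"' if quote_columns else column


assert quote_seed_column("id", True) == '"id"'
assert quote_seed_column("id", False) == "id"
```
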
@@ -831,9 +823,7 @@ class BaseAdapter(metaclass=AdapterMeta):
     # converting agate types into their sql equivalents.
     ###
     @abc.abstractclassmethod
-    def convert_text_type(
-        cls, agate_table: agate.Table, col_idx: int
-    ) -> str:
+    def convert_text_type(cls, agate_table: agate.Table, col_idx: int) -> str:
         """Return the type in the database that best maps to the agate.Text
         type for the given agate table and column index.
 
@@ -842,12 +832,11 @@ class BaseAdapter(metaclass=AdapterMeta):
         :return: The name of the type in the database
         """
         raise NotImplementedException(
-            '`convert_text_type` is not implemented for this adapter!')
+            "`convert_text_type` is not implemented for this adapter!"
+        )
 
     @abc.abstractclassmethod
-    def convert_number_type(
-        cls, agate_table: agate.Table, col_idx: int
-    ) -> str:
+    def convert_number_type(cls, agate_table: agate.Table, col_idx: int) -> str:
         """Return the type in the database that best maps to the agate.Number
         type for the given agate table and column index.
 
@@ -856,12 +845,11 @@ class BaseAdapter(metaclass=AdapterMeta):
         :return: The name of the type in the database
         """
         raise NotImplementedException(
-            '`convert_number_type` is not implemented for this adapter!')
+            "`convert_number_type` is not implemented for this adapter!"
+        )
 
     @abc.abstractclassmethod
-    def convert_boolean_type(
-        cls, agate_table: agate.Table, col_idx: int
-    ) -> str:
+    def convert_boolean_type(cls, agate_table: agate.Table, col_idx: int) -> str:
         """Return the type in the database that best maps to the agate.Boolean
         type for the given agate table and column index.
 
@@ -870,12 +858,11 @@ class BaseAdapter(metaclass=AdapterMeta):
         :return: The name of the type in the database
         """
         raise NotImplementedException(
-            '`convert_boolean_type` is not implemented for this adapter!')
+            "`convert_boolean_type` is not implemented for this adapter!"
+        )
 
     @abc.abstractclassmethod
-    def convert_datetime_type(
-        cls, agate_table: agate.Table, col_idx: int
-    ) -> str:
+    def convert_datetime_type(cls, agate_table: agate.Table, col_idx: int) -> str:
         """Return the type in the database that best maps to the agate.DateTime
         type for the given agate table and column index.
 
@@ -884,7 +871,8 @@ class BaseAdapter(metaclass=AdapterMeta):
         :return: The name of the type in the database
         """
         raise NotImplementedException(
-            '`convert_datetime_type` is not implemented for this adapter!')
+            "`convert_datetime_type` is not implemented for this adapter!"
+        )
 
     @abc.abstractclassmethod
     def convert_date_type(cls, agate_table: agate.Table, col_idx: int) -> str:
@@ -896,7 +884,8 @@ class BaseAdapter(metaclass=AdapterMeta):
         :return: The name of the type in the database
         """
         raise NotImplementedException(
-            '`convert_date_type` is not implemented for this adapter!')
+            "`convert_date_type` is not implemented for this adapter!"
+        )
 
     @abc.abstractclassmethod
     def convert_time_type(cls, agate_table: agate.Table, col_idx: int) -> str:
@@ -908,13 +897,12 @@ class BaseAdapter(metaclass=AdapterMeta):
         :return: The name of the type in the database
         """
         raise NotImplementedException(
-            '`convert_time_type` is not implemented for this adapter!')
+            "`convert_time_type` is not implemented for this adapter!"
+        )
 
     @available
     @classmethod
-    def convert_type(
-        cls, agate_table: agate.Table, col_idx: int
-    ) -> Optional[str]:
+    def convert_type(cls, agate_table: agate.Table, col_idx: int) -> Optional[str]:
         return cls.convert_agate_type(agate_table, col_idx)
 
     @classmethod
@@ -963,7 +951,7 @@ class BaseAdapter(metaclass=AdapterMeta):
|
|||||||
:param release: Ignored.
|
:param release: Ignored.
|
||||||
"""
|
"""
|
||||||
if release is not False:
|
if release is not False:
|
||||||
deprecations.warn('execute-macro-release')
|
deprecations.warn("execute-macro-release")
|
||||||
if kwargs is None:
|
if kwargs is None:
|
||||||
kwargs = {}
|
kwargs = {}
|
||||||
if context_override is None:
|
if context_override is None:
|
||||||
@@ -977,28 +965,27 @@ class BaseAdapter(metaclass=AdapterMeta):
|
|||||||
)
|
)
|
||||||
if macro is None:
|
if macro is None:
|
||||||
if project is None:
|
if project is None:
|
||||||
package_name = 'any package'
|
package_name = "any package"
|
||||||
else:
|
else:
|
||||||
package_name = 'the "{}" package'.format(project)
|
package_name = 'the "{}" package'.format(project)
|
||||||
|
|
||||||
raise RuntimeException(
|
raise RuntimeException(
|
||||||
'dbt could not find a macro with the name "{}" in {}'
|
'dbt could not find a macro with the name "{}" in {}'.format(
|
||||||
.format(macro_name, package_name)
|
macro_name, package_name
|
||||||
|
)
|
||||||
)
|
)
|
||||||
# This causes a reference cycle, as generate_runtime_macro()
|
# This causes a reference cycle, as generate_runtime_macro()
|
||||||
# ends up calling get_adapter, so the import has to be here.
|
# ends up calling get_adapter, so the import has to be here.
|
||||||
from dbt.context.providers import generate_runtime_macro
|
from dbt.context.providers import generate_runtime_macro
|
||||||
|
|
||||||
macro_context = generate_runtime_macro(
|
macro_context = generate_runtime_macro(
|
||||||
macro=macro,
|
macro=macro, config=self.config, manifest=manifest, package_name=project
|
||||||
config=self.config,
|
|
||||||
manifest=manifest,
|
|
||||||
package_name=project
|
|
||||||
)
|
)
|
||||||
macro_context.update(context_override)
|
macro_context.update(context_override)
|
||||||
|
|
||||||
macro_function = MacroGenerator(macro, macro_context)
|
macro_function = MacroGenerator(macro, macro_context)
|
||||||
|
|
||||||
with self.connections.exception_handler(f'macro {macro_name}'):
|
with self.connections.exception_handler(f"macro {macro_name}"):
|
||||||
result = macro_function(**kwargs)
|
result = macro_function(**kwargs)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
@@ -1013,7 +1000,7 @@ class BaseAdapter(metaclass=AdapterMeta):
|
|||||||
table = table_from_rows(
|
table = table_from_rows(
|
||||||
table.rows,
|
table.rows,
|
||||||
table.column_names,
|
table.column_names,
|
||||||
text_only_columns=['table_database', 'table_schema', 'table_name']
|
text_only_columns=["table_database", "table_schema", "table_name"],
|
||||||
)
|
)
|
||||||
return table.where(_catalog_filter_schemas(manifest))
|
return table.where(_catalog_filter_schemas(manifest))
|
||||||
|
|
||||||
@@ -1024,10 +1011,7 @@ class BaseAdapter(metaclass=AdapterMeta):
|
|||||||
manifest: Manifest,
|
manifest: Manifest,
|
||||||
) -> agate.Table:
|
) -> agate.Table:
|
||||||
|
|
||||||
kwargs = {
|
kwargs = {"information_schema": information_schema, "schemas": schemas}
|
||||||
'information_schema': information_schema,
|
|
||||||
'schemas': schemas
|
|
||||||
}
|
|
||||||
table = self.execute_macro(
|
table = self.execute_macro(
|
||||||
GET_CATALOG_MACRO_NAME,
|
GET_CATALOG_MACRO_NAME,
|
||||||
kwargs=kwargs,
|
kwargs=kwargs,
|
||||||
@@ -1039,9 +1023,7 @@ class BaseAdapter(metaclass=AdapterMeta):
|
|||||||
results = self._catalog_filter_table(table, manifest)
|
results = self._catalog_filter_table(table, manifest)
|
||||||
return results
|
return results
|
||||||
|
|
||||||
def get_catalog(
|
def get_catalog(self, manifest: Manifest) -> Tuple[agate.Table, List[Exception]]:
|
||||||
self, manifest: Manifest
|
|
||||||
) -> Tuple[agate.Table, List[Exception]]:
|
|
||||||
schema_map = self._get_catalog_schemas(manifest)
|
schema_map = self._get_catalog_schemas(manifest)
|
||||||
|
|
||||||
with executor(self.config) as tpe:
|
with executor(self.config) as tpe:
|
||||||
@@ -1049,14 +1031,10 @@ class BaseAdapter(metaclass=AdapterMeta):
|
|||||||
for info, schemas in schema_map.items():
|
for info, schemas in schema_map.items():
|
||||||
if len(schemas) == 0:
|
if len(schemas) == 0:
|
||||||
continue
|
continue
|
||||||
name = '.'.join([
|
name = ".".join([str(info.database), "information_schema"])
|
||||||
str(info.database),
|
|
||||||
'information_schema'
|
|
||||||
])
|
|
||||||
|
|
||||||
fut = tpe.submit_connected(
|
fut = tpe.submit_connected(
|
||||||
self, name,
|
self, name, self._get_one_catalog, info, schemas, manifest
|
||||||
self._get_one_catalog, info, schemas, manifest
|
|
||||||
)
|
)
|
||||||
futures.append(fut)
|
futures.append(fut)
|
||||||
|
|
||||||
@@ -1073,20 +1051,18 @@ class BaseAdapter(metaclass=AdapterMeta):
|
|||||||
source: BaseRelation,
|
source: BaseRelation,
|
||||||
loaded_at_field: str,
|
loaded_at_field: str,
|
||||||
filter: Optional[str],
|
filter: Optional[str],
|
||||||
manifest: Optional[Manifest] = None
|
manifest: Optional[Manifest] = None,
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""Calculate the freshness of sources in dbt, and return it"""
|
"""Calculate the freshness of sources in dbt, and return it"""
|
||||||
kwargs: Dict[str, Any] = {
|
kwargs: Dict[str, Any] = {
|
||||||
'source': source,
|
"source": source,
|
||||||
'loaded_at_field': loaded_at_field,
|
"loaded_at_field": loaded_at_field,
|
||||||
'filter': filter,
|
"filter": filter,
|
||||||
}
|
}
|
||||||
|
|
||||||
# run the macro
|
# run the macro
|
||||||
table = self.execute_macro(
|
table = self.execute_macro(
|
||||||
FRESHNESS_MACRO_NAME,
|
FRESHNESS_MACRO_NAME, kwargs=kwargs, manifest=manifest
|
||||||
kwargs=kwargs,
|
|
||||||
manifest=manifest
|
|
||||||
)
|
)
|
||||||
# now we have a 1-row table of the maximum `loaded_at_field` value and
|
# now we have a 1-row table of the maximum `loaded_at_field` value and
|
||||||
# the current time according to the db.
|
# the current time according to the db.
|
||||||
@@ -1106,9 +1082,9 @@ class BaseAdapter(metaclass=AdapterMeta):
|
|||||||
snapshotted_at = _utc(table[0][1], source, loaded_at_field)
|
snapshotted_at = _utc(table[0][1], source, loaded_at_field)
|
||||||
age = (snapshotted_at - max_loaded_at).total_seconds()
|
age = (snapshotted_at - max_loaded_at).total_seconds()
|
||||||
return {
|
return {
|
||||||
'max_loaded_at': max_loaded_at,
|
"max_loaded_at": max_loaded_at,
|
||||||
'snapshotted_at': snapshotted_at,
|
"snapshotted_at": snapshotted_at,
|
||||||
'age': age,
|
"age": age,
|
||||||
}
|
}
|
||||||
|
|
||||||
def pre_model_hook(self, config: Mapping[str, Any]) -> Any:
|
def pre_model_hook(self, config: Mapping[str, Any]) -> Any:
|
||||||
@@ -1138,6 +1114,7 @@ class BaseAdapter(metaclass=AdapterMeta):
|
|||||||
|
|
||||||
def get_compiler(self):
|
def get_compiler(self):
|
||||||
from dbt.compilation import Compiler
|
from dbt.compilation import Compiler
|
||||||
|
|
||||||
return Compiler(self.config)
|
return Compiler(self.config)
|
||||||
|
|
||||||
# Methods used in adapter tests
|
# Methods used in adapter tests
|
||||||
@@ -1148,13 +1125,13 @@ class BaseAdapter(metaclass=AdapterMeta):
|
|||||||
clause: str,
|
clause: str,
|
||||||
where_clause: Optional[str] = None,
|
where_clause: Optional[str] = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
clause = f'update {dst_name} set {dst_column} = {clause}'
|
clause = f"update {dst_name} set {dst_column} = {clause}"
|
||||||
if where_clause is not None:
|
if where_clause is not None:
|
||||||
clause += f' where {where_clause}'
|
clause += f" where {where_clause}"
|
||||||
return clause
|
return clause
|
||||||
|
|
||||||
def timestamp_add_sql(
|
def timestamp_add_sql(
|
||||||
self, add_to: str, number: int = 1, interval: str = 'hour'
|
self, add_to: str, number: int = 1, interval: str = "hour"
|
||||||
) -> str:
|
) -> str:
|
||||||
# for backwards compatibility, we're compelled to set some sort of
|
# for backwards compatibility, we're compelled to set some sort of
|
||||||
# default. A lot of searching has lead me to believe that the
|
# default. A lot of searching has lead me to believe that the
|
||||||
@@ -1163,23 +1140,24 @@ class BaseAdapter(metaclass=AdapterMeta):
         return f"{add_to} + interval '{number} {interval}'"
 
     def string_add_sql(
-        self, add_to: str, value: str, location='append',
+        self,
+        add_to: str,
+        value: str,
+        location="append",
     ) -> str:
-        if location == 'append':
+        if location == "append":
             return f"{add_to} || '{value}'"
-        elif location == 'prepend':
+        elif location == "prepend":
             return f"'{value}' || {add_to}"
         else:
-            raise RuntimeException(
-                f'Got an unexpected location value of "{location}"'
-            )
+            raise RuntimeException(f'Got an unexpected location value of "{location}"')
 
     def get_rows_different_sql(
         self,
         relation_a: BaseRelation,
         relation_b: BaseRelation,
         column_names: Optional[List[str]] = None,
-        except_operator: str = 'EXCEPT',
+        except_operator: str = "EXCEPT",
     ) -> str:
         """Generate SQL for a query that returns a single row with a two
         columns: the number of rows that are different between the two
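
For reference, these are the SQL fragments the two test helpers above render. This is a standalone sketch mirroring the defaults shown in the diff, not an import of the adapter class itself.

```
def timestamp_add_sql(add_to: str, number: int = 1, interval: str = "hour") -> str:
    # Postgres-style interval arithmetic, matching the default shown above
    return f"{add_to} + interval '{number} {interval}'"


def string_add_sql(add_to: str, value: str, location: str = "append") -> str:
    if location == "append":
        return f"{add_to} || '{value}'"
    elif location == "prepend":
        return f"'{value}' || {add_to}"
    raise ValueError(f'Got an unexpected location value of "{location}"')


print(timestamp_add_sql("loaded_at", 2))        # loaded_at + interval '2 hour'
print(string_add_sql("name", "_x"))             # name || '_x'
print(string_add_sql("name", "x_", "prepend"))  # 'x_' || name
```
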
@@ -1192,7 +1170,7 @@ class BaseAdapter(metaclass=AdapterMeta):
             names = sorted((self.quote(c.name) for c in columns))
         else:
             names = sorted((self.quote(n) for n in column_names))
-        columns_csv = ', '.join(names)
+        columns_csv = ", ".join(names)
 
         sql = COLUMNS_EQUAL_SQL.format(
             columns=columns_csv,
@@ -1204,7 +1182,7 @@ class BaseAdapter(metaclass=AdapterMeta):
         return sql
 
 
-COLUMNS_EQUAL_SQL = '''
+COLUMNS_EQUAL_SQL = """
 with diff_count as (
     SELECT
         1 as id,
@@ -1230,11 +1208,11 @@ select
     diff_count.num_missing as num_mismatched
 from row_count_diff
 join diff_count using (id)
-'''.strip()
+""".strip()
 
 
 def catch_as_completed(
-    futures  # typing: List[Future[agate.Table]]
+    futures,  # typing: List[Future[agate.Table]]
 ) -> Tuple[agate.Table, List[Exception]]:
 
     # catalogs: agate.Table = agate.Table(rows=[])
@@ -1247,15 +1225,10 @@ def catch_as_completed(
         if exc is None:
             catalog = future.result()
             tables.append(catalog)
-        elif (
-            isinstance(exc, KeyboardInterrupt) or
-            not isinstance(exc, Exception)
-        ):
+        elif isinstance(exc, KeyboardInterrupt) or not isinstance(exc, Exception):
             raise exc
         else:
-            warn_or_error(
-                f'Encountered an error while generating catalog: {str(exc)}'
-            )
+            warn_or_error(f"Encountered an error while generating catalog: {str(exc)}")
             # exc is not None, derives from Exception, and isn't ctrl+c
             exceptions.append(exc)
     return merge_tables(tables), exceptions
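
`catch_as_completed` drains a pool of futures, keeping results, re-raising interrupts, and collecting recoverable errors. Below is a self-contained sketch of the same pattern using `concurrent.futures` directly; dbt's `executor` and `merge_tables` are not involved here.

```
from concurrent.futures import ThreadPoolExecutor, as_completed


def drain(futures):
    results, exceptions = [], []
    for future in as_completed(futures):
        exc = future.exception()
        if exc is None:
            results.append(future.result())
        elif isinstance(exc, KeyboardInterrupt) or not isinstance(exc, Exception):
            raise exc  # never swallow ctrl+c or other non-Exception signals
        else:
            exceptions.append(exc)  # recoverable: record it and keep draining
    return results, exceptions


with ThreadPoolExecutor() as tpe:
    futs = [tpe.submit(lambda x=x: 10 // x) for x in (5, 2, 0)]
print(drain(futs))  # e.g. ([2, 5], [ZeroDivisionError(...)])
```
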
@@ -30,9 +30,11 @@ class _Available:
             x.update(big_expensive_db_query())
             return x
         """
+
         def inner(func):
             func._parse_replacement_ = parse_replacement
             return self(func)
+
         return inner
 
     def deprecated(
@@ -57,13 +59,14 @@ class _Available:
         The optional parse_replacement, if provided, will provide a parse-time
         replacement for the actual method (see `available.parse`).
         """
+
         def wrapper(func):
             func_name = func.__name__
             renamed_method(func_name, supported_name)
 
             @wraps(func)
             def inner(*args, **kwargs):
-                warn('adapter:{}'.format(func_name))
+                warn("adapter:{}".format(func_name))
                 return func(*args, **kwargs)
 
             if parse_replacement:
@@ -71,6 +74,7 @@ class _Available:
             else:
                 available_function = self
             return available_function(inner)
+
         return wrapper
 
     def parse_none(self, func: Callable) -> Callable:
@@ -109,14 +113,14 @@ class AdapterMeta(abc.ABCMeta):
 
         # collect base class data first
         for base in bases:
-            available.update(getattr(base, '_available_', set()))
-            replacements.update(getattr(base, '_parse_replacements_', set()))
+            available.update(getattr(base, "_available_", set()))
+            replacements.update(getattr(base, "_parse_replacements_", set()))
 
         # override with local data if it exists
         for name, value in namespace.items():
-            if getattr(value, '_is_available_', False):
+            if getattr(value, "_is_available_", False):
                 available.add(name)
-                parse_replacement = getattr(value, '_parse_replacement_', None)
+                parse_replacement = getattr(value, "_parse_replacement_", None)
                 if parse_replacement is not None:
                     replacements[name] = parse_replacement
 
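
The `_available_` bookkeeping above follows a common marker-attribute pattern: a decorator flags methods, and the metaclass collects the flagged names when the class is created. A standalone, simplified sketch of that pattern (not the dbt classes themselves):

```
def available(func):
    func._is_available_ = True  # marker attribute, collected at class creation
    return func


class CollectingMeta(type):
    def __new__(mcls, name, bases, namespace):
        collected = set()
        for base in bases:
            collected.update(getattr(base, "_available_", set()))
        for attr, value in namespace.items():
            if getattr(value, "_is_available_", False):
                collected.add(attr)
        cls = super().__new__(mcls, name, bases, namespace)
        cls._available_ = frozenset(collected)
        return cls


class Demo(metaclass=CollectingMeta):
    @available
    def quote(self, identifier):
        return f'"{identifier}"'


print(Demo._available_)  # frozenset({'quote'})
```
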
@@ -8,11 +8,10 @@ from dbt.adapters.protocol import AdapterProtocol
 def project_name_from_path(include_path: str) -> str:
     # avoid an import cycle
     from dbt.config.project import Project
 
     partial = Project.partial_load(include_path)
     if partial.project_name is None:
-        raise CompilationException(
-            f'Invalid project at {include_path}: name not set!'
-        )
+        raise CompilationException(f"Invalid project at {include_path}: name not set!")
     return partial.project_name
 
+
@@ -23,12 +22,13 @@ class AdapterPlugin:
     :param dependencies: A list of adapter names that this adapter depends
         upon.
     """
+
     def __init__(
         self,
         adapter: Type[AdapterProtocol],
         credentials: Type[Credentials],
         include_path: str,
-        dependencies: Optional[List[str]] = None
+        dependencies: Optional[List[str]] = None,
     ):
 
         self.adapter: Type[AdapterProtocol] = adapter
@@ -15,7 +15,7 @@ class NodeWrapper:
         self._inner_node = node
 
     def __getattr__(self, name):
-        return getattr(self._inner_node, name, '')
+        return getattr(self._inner_node, name, "")
 
 
 class _QueryComment(local):
@@ -24,6 +24,7 @@ class _QueryComment(local):
     - the current thread's query comment.
     - a source_name indicating what set the current thread's query comment
     """
+
     def __init__(self, initial):
         self.query_comment: Optional[str] = initial
         self.append = False
@@ -35,16 +36,16 @@ class _QueryComment(local):
         if self.append:
             # replace last ';' with '<comment>;'
             sql = sql.rstrip()
-            if sql[-1] == ';':
+            if sql[-1] == ";":
                 sql = sql[:-1]
-                return '{}\n/* {} */;'.format(sql, self.query_comment.strip())
+                return "{}\n/* {} */;".format(sql, self.query_comment.strip())
 
-            return '{}\n/* {} */'.format(sql, self.query_comment.strip())
+            return "{}\n/* {} */".format(sql, self.query_comment.strip())
 
-        return '/* {} */\n{}'.format(self.query_comment.strip(), sql)
+        return "/* {} */\n{}".format(self.query_comment.strip(), sql)
 
     def set(self, comment: Optional[str], append: bool):
-        if isinstance(comment, str) and '*/' in comment:
+        if isinstance(comment, str) and "*/" in comment:
             # tell the user "no" so they don't hurt themselves by writing
             # garbage
             raise RuntimeException(
@@ -63,15 +64,17 @@ class MacroQueryStringSetter:
         self.config = config
 
         comment_macro = self._get_comment_macro()
-        self.generator: QueryStringFunc = lambda name, model: ''
+        self.generator: QueryStringFunc = lambda name, model: ""
         # if the comment value was None or the empty string, just skip it
         if comment_macro:
             assert isinstance(comment_macro, str)
-            macro = '\n'.join((
-                '{%- macro query_comment_macro(connection_name, node) -%}',
-                comment_macro,
-                '{% endmacro %}'
-            ))
+            macro = "\n".join(
+                (
+                    "{%- macro query_comment_macro(connection_name, node) -%}",
+                    comment_macro,
+                    "{% endmacro %}",
+                )
+            )
             ctx = self._get_context()
             self.generator = QueryStringGenerator(macro, ctx)
         self.comment = _QueryComment(None)
@@ -87,7 +90,7 @@ class MacroQueryStringSetter:
         return self.comment.add(sql)
 
     def reset(self):
-        self.set('master', None)
+        self.set("master", None)
 
     def set(self, name: str, node: Optional[CompileResultNode]):
         wrapped: Optional[NodeWrapper] = None
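
The `_QueryComment.add` logic above either prefixes the comment or, in append mode, re-attaches it just before a trailing semicolon. A standalone sketch of that behaviour:

```
def add_comment(sql: str, comment: str, append: bool = False) -> str:
    if "*/" in comment:
        raise ValueError("query comment must not contain '*/'")
    if append:
        sql = sql.rstrip()
        if sql.endswith(";"):
            return "{}\n/* {} */;".format(sql[:-1], comment.strip())
        return "{}\n/* {} */".format(sql, comment.strip())
    return "/* {} */\n{}".format(comment.strip(), sql)


print(add_comment("select 1;", "run by dbt", append=True))
# select 1
# /* run by dbt */;
```
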
@@ -1,13 +1,16 @@
 from collections.abc import Hashable
 from dataclasses import dataclass
-from typing import (
-    Optional, TypeVar, Any, Type, Dict, Union, Iterator, Tuple, Set
-)
+from typing import Optional, TypeVar, Any, Type, Dict, Union, Iterator, Tuple, Set
 
 from dbt.contracts.graph.compiled import CompiledNode
 from dbt.contracts.graph.parsed import ParsedSourceDefinition, ParsedNode
 from dbt.contracts.relation import (
-    RelationType, ComponentName, HasQuoting, FakeAPIObject, Policy, Path
+    RelationType,
+    ComponentName,
+    HasQuoting,
+    FakeAPIObject,
+    Policy,
+    Path,
 )
 from dbt.exceptions import InternalException
 from dbt.node_types import NodeType
@@ -16,7 +19,7 @@ from dbt.utils import filter_null_values, deep_merge, classproperty
 import dbt.exceptions
 
 
-Self = TypeVar('Self', bound='BaseRelation')
+Self = TypeVar("Self", bound="BaseRelation")
 
 
 @dataclass(frozen=True, eq=False, repr=False)
@@ -40,7 +43,7 @@ class BaseRelation(FakeAPIObject, Hashable):
             if field.name == field_name:
                 return field
         # this should be unreachable
-        raise ValueError(f'BaseRelation has no {field_name} field!')
+        raise ValueError(f"BaseRelation has no {field_name} field!")
 
     def __eq__(self, other):
         if not isinstance(other, self.__class__):
@@ -49,20 +52,18 @@ class BaseRelation(FakeAPIObject, Hashable):
 
     @classmethod
     def get_default_quote_policy(cls) -> Policy:
-        return cls._get_field_named('quote_policy').default
+        return cls._get_field_named("quote_policy").default
 
     @classmethod
     def get_default_include_policy(cls) -> Policy:
-        return cls._get_field_named('include_policy').default
+        return cls._get_field_named("include_policy").default
 
     def get(self, key, default=None):
         """Override `.get` to return a metadata object so we don't break
         dbt_utils.
         """
-        if key == 'metadata':
-            return {
-                'type': self.__class__.__name__
-            }
+        if key == "metadata":
+            return {"type": self.__class__.__name__}
         return super().get(key, default)
 
     def matches(
@@ -71,16 +72,19 @@ class BaseRelation(FakeAPIObject, Hashable):
         schema: Optional[str] = None,
         identifier: Optional[str] = None,
     ) -> bool:
-        search = filter_null_values({
-            ComponentName.Database: database,
-            ComponentName.Schema: schema,
-            ComponentName.Identifier: identifier
-        })
+        search = filter_null_values(
+            {
+                ComponentName.Database: database,
+                ComponentName.Schema: schema,
+                ComponentName.Identifier: identifier,
+            }
+        )
 
         if not search:
             # nothing was passed in
             raise dbt.exceptions.RuntimeException(
-                "Tried to match relation, but no search path was passed!")
+                "Tried to match relation, but no search path was passed!"
+            )
 
         exact_match = True
         approximate_match = True
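
`matches` relies on `filter_null_values` to drop the components the caller did not pass, so only explicit database/schema/identifier values take part in the comparison. A minimal sketch of that helper and the empty-search guard (a hypothetical standalone version, not `dbt.utils` itself):

```
from typing import Dict, Optional


def filter_null_values(values: Dict[str, Optional[str]]) -> Dict[str, str]:
    # keep only the components that were actually provided
    return {key: value for key, value in values.items() if value is not None}


search = filter_null_values({"database": None, "schema": "analytics", "identifier": "users"})
print(search)  # {'schema': 'analytics', 'identifier': 'users'}
if not search:
    raise RuntimeError("Tried to match relation, but no search path was passed!")
```
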
@@ -109,11 +113,13 @@ class BaseRelation(FakeAPIObject, Hashable):
|
|||||||
schema: Optional[bool] = None,
|
schema: Optional[bool] = None,
|
||||||
identifier: Optional[bool] = None,
|
identifier: Optional[bool] = None,
|
||||||
) -> Self:
|
) -> Self:
|
||||||
policy = filter_null_values({
|
policy = filter_null_values(
|
||||||
|
{
|
||||||
ComponentName.Database: database,
|
ComponentName.Database: database,
|
||||||
ComponentName.Schema: schema,
|
ComponentName.Schema: schema,
|
||||||
ComponentName.Identifier: identifier
|
ComponentName.Identifier: identifier,
|
||||||
})
|
}
|
||||||
|
)
|
||||||
|
|
||||||
new_quote_policy = self.quote_policy.replace_dict(policy)
|
new_quote_policy = self.quote_policy.replace_dict(policy)
|
||||||
return self.replace(quote_policy=new_quote_policy)
|
return self.replace(quote_policy=new_quote_policy)
|
||||||
@@ -124,16 +130,18 @@ class BaseRelation(FakeAPIObject, Hashable):
|
|||||||
schema: Optional[bool] = None,
|
schema: Optional[bool] = None,
|
||||||
identifier: Optional[bool] = None,
|
identifier: Optional[bool] = None,
|
||||||
) -> Self:
|
) -> Self:
|
||||||
policy = filter_null_values({
|
policy = filter_null_values(
|
||||||
|
{
|
||||||
ComponentName.Database: database,
|
ComponentName.Database: database,
|
||||||
ComponentName.Schema: schema,
|
ComponentName.Schema: schema,
|
||||||
ComponentName.Identifier: identifier
|
ComponentName.Identifier: identifier,
|
||||||
})
|
}
|
||||||
|
)
|
||||||
|
|
||||||
new_include_policy = self.include_policy.replace_dict(policy)
|
new_include_policy = self.include_policy.replace_dict(policy)
|
||||||
return self.replace(include_policy=new_include_policy)
|
return self.replace(include_policy=new_include_policy)
|
||||||
|
|
||||||
def information_schema(self, view_name=None) -> 'InformationSchema':
|
def information_schema(self, view_name=None) -> "InformationSchema":
|
||||||
# some of our data comes from jinja, where things can be `Undefined`.
|
# some of our data comes from jinja, where things can be `Undefined`.
|
||||||
if not isinstance(view_name, str):
|
if not isinstance(view_name, str):
|
||||||
view_name = None
|
view_name = None
|
||||||
@@ -143,10 +151,10 @@ class BaseRelation(FakeAPIObject, Hashable):
|
|||||||
info_schema = InformationSchema.from_relation(self, view_name)
|
info_schema = InformationSchema.from_relation(self, view_name)
|
||||||
return info_schema.incorporate(path={"schema": None})
|
return info_schema.incorporate(path={"schema": None})
|
||||||
|
|
||||||
def information_schema_only(self) -> 'InformationSchema':
|
def information_schema_only(self) -> "InformationSchema":
|
||||||
return self.information_schema()
|
return self.information_schema()
|
||||||
|
|
||||||
def without_identifier(self) -> 'BaseRelation':
|
def without_identifier(self) -> "BaseRelation":
|
||||||
"""Return a form of this relation that only has the database and schema
|
"""Return a form of this relation that only has the database and schema
|
||||||
set to included. To get the appropriately-quoted form the schema out of
|
set to included. To get the appropriately-quoted form the schema out of
|
||||||
the result (for use as part of a query), use `.render()`. To get the
|
the result (for use as part of a query), use `.render()`. To get the
|
||||||
@@ -157,7 +165,7 @@ class BaseRelation(FakeAPIObject, Hashable):
|
|||||||
return self.include(identifier=False).replace_path(identifier=None)
|
return self.include(identifier=False).replace_path(identifier=None)
|
||||||
|
|
||||||
def _render_iterator(
|
def _render_iterator(
|
||||||
self
|
self,
|
||||||
) -> Iterator[Tuple[Optional[ComponentName], Optional[str]]]:
|
) -> Iterator[Tuple[Optional[ComponentName], Optional[str]]]:
|
||||||
|
|
||||||
for key in ComponentName:
|
for key in ComponentName:
|
||||||
@@ -170,13 +178,10 @@ class BaseRelation(FakeAPIObject, Hashable):
|
|||||||
|
|
||||||
def render(self) -> str:
|
def render(self) -> str:
|
||||||
# if there is nothing set, this will return the empty string.
|
# if there is nothing set, this will return the empty string.
|
||||||
return '.'.join(
|
return ".".join(part for _, part in self._render_iterator() if part is not None)
|
||||||
part for _, part in self._render_iterator()
|
|
||||||
if part is not None
|
|
||||||
)
|
|
||||||
|
|
||||||
def quoted(self, identifier):
|
def quoted(self, identifier):
|
||||||
return '{quote_char}{identifier}{quote_char}'.format(
|
return "{quote_char}{identifier}{quote_char}".format(
|
||||||
quote_char=self.quote_character,
|
quote_char=self.quote_character,
|
||||||
identifier=identifier,
|
identifier=identifier,
|
||||||
)
|
)
|
||||||
@@ -186,11 +191,11 @@ class BaseRelation(FakeAPIObject, Hashable):
|
|||||||
cls: Type[Self], source: ParsedSourceDefinition, **kwargs: Any
|
cls: Type[Self], source: ParsedSourceDefinition, **kwargs: Any
|
||||||
) -> Self:
|
) -> Self:
|
||||||
source_quoting = source.quoting.to_dict(omit_none=True)
|
source_quoting = source.quoting.to_dict(omit_none=True)
|
||||||
source_quoting.pop('column', None)
|
source_quoting.pop("column", None)
|
||||||
quote_policy = deep_merge(
|
quote_policy = deep_merge(
|
||||||
cls.get_default_quote_policy().to_dict(omit_none=True),
|
cls.get_default_quote_policy().to_dict(omit_none=True),
|
||||||
source_quoting,
|
source_quoting,
|
||||||
kwargs.get('quote_policy', {}),
|
kwargs.get("quote_policy", {}),
|
||||||
)
|
)
|
||||||
|
|
||||||
return cls.create(
|
return cls.create(
|
||||||
@@ -198,12 +203,12 @@ class BaseRelation(FakeAPIObject, Hashable):
|
|||||||
schema=source.schema,
|
schema=source.schema,
|
||||||
identifier=source.identifier,
|
identifier=source.identifier,
|
||||||
quote_policy=quote_policy,
|
quote_policy=quote_policy,
|
||||||
**kwargs
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def add_ephemeral_prefix(name: str):
|
def add_ephemeral_prefix(name: str):
|
||||||
return f'__dbt__cte__{name}'
|
return f"__dbt__cte__{name}"
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def create_ephemeral_from_node(
|
def create_ephemeral_from_node(
|
||||||
@@ -236,7 +241,8 @@ class BaseRelation(FakeAPIObject, Hashable):
|
|||||||
schema=node.schema,
|
schema=node.schema,
|
||||||
identifier=node.alias,
|
identifier=node.alias,
|
||||||
quote_policy=quote_policy,
|
quote_policy=quote_policy,
|
||||||
**kwargs)
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def create_from(
|
def create_from(
|
||||||
@@ -248,15 +254,16 @@ class BaseRelation(FakeAPIObject, Hashable):
|
|||||||
if node.resource_type == NodeType.Source:
|
if node.resource_type == NodeType.Source:
|
||||||
if not isinstance(node, ParsedSourceDefinition):
|
if not isinstance(node, ParsedSourceDefinition):
|
||||||
raise InternalException(
|
raise InternalException(
|
||||||
'type mismatch, expected ParsedSourceDefinition but got {}'
|
"type mismatch, expected ParsedSourceDefinition but got {}".format(
|
||||||
.format(type(node))
|
type(node)
|
||||||
|
)
|
||||||
)
|
)
|
||||||
return cls.create_from_source(node, **kwargs)
|
return cls.create_from_source(node, **kwargs)
|
||||||
else:
|
else:
|
||||||
if not isinstance(node, (ParsedNode, CompiledNode)):
|
if not isinstance(node, (ParsedNode, CompiledNode)):
|
||||||
raise InternalException(
|
raise InternalException(
|
||||||
'type mismatch, expected ParsedNode or CompiledNode but '
|
"type mismatch, expected ParsedNode or CompiledNode but "
|
||||||
'got {}'.format(type(node))
|
"got {}".format(type(node))
|
||||||
)
|
)
|
||||||
return cls.create_from_node(config, node, **kwargs)
|
return cls.create_from_node(config, node, **kwargs)
|
||||||
|
|
||||||
@@ -269,14 +276,16 @@ class BaseRelation(FakeAPIObject, Hashable):
|
|||||||
type: Optional[RelationType] = None,
|
type: Optional[RelationType] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> Self:
|
) -> Self:
|
||||||
kwargs.update({
|
kwargs.update(
|
||||||
'path': {
|
{
|
||||||
'database': database,
|
"path": {
|
||||||
'schema': schema,
|
"database": database,
|
||||||
'identifier': identifier,
|
"schema": schema,
|
||||||
|
"identifier": identifier,
|
||||||
},
|
},
|
||||||
'type': type,
|
"type": type,
|
||||||
})
|
}
|
||||||
|
)
|
||||||
return cls.from_dict(kwargs)
|
return cls.from_dict(kwargs)
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
@@ -342,7 +351,7 @@ class BaseRelation(FakeAPIObject, Hashable):
|
|||||||
return RelationType
|
return RelationType
|
||||||
|
|
||||||
|
|
||||||
Info = TypeVar('Info', bound='InformationSchema')
|
Info = TypeVar("Info", bound="InformationSchema")
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True, eq=False, repr=False)
|
@dataclass(frozen=True, eq=False, repr=False)
|
||||||
@@ -352,7 +361,7 @@ class InformationSchema(BaseRelation):
|
|||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
if not isinstance(self.information_schema_view, (type(None), str)):
|
if not isinstance(self.information_schema_view, (type(None), str)):
|
||||||
raise dbt.exceptions.CompilationException(
|
raise dbt.exceptions.CompilationException(
|
||||||
'Got an invalid name: {}'.format(self.information_schema_view)
|
"Got an invalid name: {}".format(self.information_schema_view)
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -362,7 +371,7 @@ class InformationSchema(BaseRelation):
|
|||||||
return Path(
|
return Path(
|
||||||
database=relation.database,
|
database=relation.database,
|
||||||
schema=relation.schema,
|
schema=relation.schema,
|
||||||
identifier='INFORMATION_SCHEMA',
|
identifier="INFORMATION_SCHEMA",
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -393,9 +402,7 @@ class InformationSchema(BaseRelation):
|
|||||||
relation: BaseRelation,
|
relation: BaseRelation,
|
||||||
information_schema_view: Optional[str],
|
information_schema_view: Optional[str],
|
||||||
) -> Info:
|
) -> Info:
|
||||||
include_policy = cls.get_include_policy(
|
include_policy = cls.get_include_policy(relation, information_schema_view)
|
||||||
relation, information_schema_view
|
|
||||||
)
|
|
||||||
quote_policy = cls.get_quote_policy(relation, information_schema_view)
|
quote_policy = cls.get_quote_policy(relation, information_schema_view)
|
||||||
path = cls.get_path(relation, information_schema_view)
|
path = cls.get_path(relation, information_schema_view)
|
||||||
return cls(
|
return cls(
|
||||||
@@ -417,6 +424,7 @@ class SchemaSearchMap(Dict[InformationSchema, Set[Optional[str]]]):
|
|||||||
search for what schemas. The schema values are all lowercased to avoid
|
search for what schemas. The schema values are all lowercased to avoid
|
||||||
duplication.
|
duplication.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def add(self, relation: BaseRelation):
|
def add(self, relation: BaseRelation):
|
||||||
key = relation.information_schema_only()
|
key = relation.information_schema_only()
|
||||||
if key not in self:
|
if key not in self:
|
||||||
@@ -426,9 +434,7 @@ class SchemaSearchMap(Dict[InformationSchema, Set[Optional[str]]]):
|
|||||||
schema = relation.schema.lower()
|
schema = relation.schema.lower()
|
||||||
self[key].add(schema)
|
self[key].add(schema)
|
||||||
|
|
||||||
def search(
|
def search(self) -> Iterator[Tuple[InformationSchema, Optional[str]]]:
|
||||||
self
|
|
||||||
) -> Iterator[Tuple[InformationSchema, Optional[str]]]:
|
|
||||||
for information_schema_name, schemas in self.items():
|
for information_schema_name, schemas in self.items():
|
||||||
for schema in schemas:
|
for schema in schemas:
|
||||||
yield information_schema_name, schema
|
yield information_schema_name, schema
|
||||||
@@ -442,14 +448,13 @@ class SchemaSearchMap(Dict[InformationSchema, Set[Optional[str]]]):
|
|||||||
dbt.exceptions.raise_compiler_error(str(seen))
|
dbt.exceptions.raise_compiler_error(str(seen))
|
||||||
|
|
||||||
for information_schema_name, schema in self.search():
|
for information_schema_name, schema in self.search():
|
||||||
path = {
|
path = {"database": information_schema_name.database, "schema": schema}
|
||||||
'database': information_schema_name.database,
|
new.add(
|
||||||
'schema': schema
|
information_schema_name.incorporate(
|
||||||
}
|
|
||||||
new.add(information_schema_name.incorporate(
|
|
||||||
path=path,
|
path=path,
|
||||||
quote_policy={'database': False},
|
quote_policy={"database": False},
|
||||||
include_policy={'database': False},
|
include_policy={"database": False},
|
||||||
))
|
)
|
||||||
|
)
|
||||||
|
|
||||||
return new
|
return new
|
||||||
@@ -7,7 +7,7 @@ from dbt.logger import CACHE_LOGGER as logger
 from dbt.utils import lowercase
 import dbt.exceptions
 
-_ReferenceKey = namedtuple('_ReferenceKey', 'database schema identifier')
+_ReferenceKey = namedtuple("_ReferenceKey", "database schema identifier")
 
 
 def _make_key(relation) -> _ReferenceKey:
@@ -15,9 +15,11 @@ def _make_key(relation) -> _ReferenceKey:
     to keep track of quoting
     """
     # databases and schemas can both be None
-    return _ReferenceKey(lowercase(relation.database),
-                         lowercase(relation.schema),
-                         lowercase(relation.identifier))
+    return _ReferenceKey(
+        lowercase(relation.database),
+        lowercase(relation.schema),
+        lowercase(relation.identifier),
+    )
 
 
 def dot_separated(key: _ReferenceKey) -> str:
@@ -25,7 +27,7 @@ def dot_separated(key: _ReferenceKey) -> str:
 
     :param _ReferenceKey key: The key to stringify.
     """
-    return '.'.join(map(str, key))
+    return ".".join(map(str, key))
 
 
 class _CachedRelation:
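
The cache keys built above are case-insensitive: `_make_key` lowercases each component (database and schema may be `None`) and `dot_separated` renders the key for log messages. A standalone sketch of that behaviour:

```
from collections import namedtuple

_ReferenceKey = namedtuple("_ReferenceKey", "database schema identifier")


def lowercase(value):
    # databases and schemas can both be None
    return value.lower() if value is not None else None


def make_key(database, schema, identifier) -> _ReferenceKey:
    return _ReferenceKey(lowercase(database), lowercase(schema), lowercase(identifier))


key = make_key("Analytics", "Jaffle_Shop", "Customers")
print(".".join(map(str, key)))  # analytics.jaffle_shop.customers
```
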
@@ -37,13 +39,14 @@ class _CachedRelation:
|
|||||||
that refer to this relation.
|
that refer to this relation.
|
||||||
:attr BaseRelation inner: The underlying dbt relation.
|
:attr BaseRelation inner: The underlying dbt relation.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, inner):
|
def __init__(self, inner):
|
||||||
self.referenced_by = {}
|
self.referenced_by = {}
|
||||||
self.inner = inner
|
self.inner = inner
|
||||||
|
|
||||||
def __str__(self) -> str:
|
def __str__(self) -> str:
|
||||||
return (
|
return (
|
||||||
'_CachedRelation(database={}, schema={}, identifier={}, inner={})'
|
"_CachedRelation(database={}, schema={}, identifier={}, inner={})"
|
||||||
).format(self.database, self.schema, self.identifier, self.inner)
|
).format(self.database, self.schema, self.identifier, self.inner)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@@ -78,7 +81,7 @@ class _CachedRelation:
|
|||||||
"""
|
"""
|
||||||
return _make_key(self)
|
return _make_key(self)
|
||||||
|
|
||||||
def add_reference(self, referrer: '_CachedRelation'):
|
def add_reference(self, referrer: "_CachedRelation"):
|
||||||
"""Add a reference from referrer to self, indicating that if this node
|
"""Add a reference from referrer to self, indicating that if this node
|
||||||
were drop...cascaded, the referrer would be dropped as well.
|
were drop...cascaded, the referrer would be dropped as well.
|
||||||
|
|
||||||
@@ -122,9 +125,9 @@ class _CachedRelation:
|
|||||||
# table_name is ever anything but the identifier (via .create())
|
# table_name is ever anything but the identifier (via .create())
|
||||||
self.inner = self.inner.incorporate(
|
self.inner = self.inner.incorporate(
|
||||||
path={
|
path={
|
||||||
'database': new_relation.inner.database,
|
"database": new_relation.inner.database,
|
||||||
'schema': new_relation.inner.schema,
|
"schema": new_relation.inner.schema,
|
||||||
'identifier': new_relation.inner.identifier
|
"identifier": new_relation.inner.identifier,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -140,8 +143,9 @@ class _CachedRelation:
|
|||||||
"""
|
"""
|
||||||
if new_key in self.referenced_by:
|
if new_key in self.referenced_by:
|
||||||
dbt.exceptions.raise_cache_inconsistent(
|
dbt.exceptions.raise_cache_inconsistent(
|
||||||
'in rename of "{}" -> "{}", new name is in the cache already'
|
'in rename of "{}" -> "{}", new name is in the cache already'.format(
|
||||||
.format(old_key, new_key)
|
old_key, new_key
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
if old_key not in self.referenced_by:
|
if old_key not in self.referenced_by:
|
||||||
@@ -172,13 +176,16 @@ class RelationsCache:
|
|||||||
The adapters also hold this lock while filling the cache.
|
The adapters also hold this lock while filling the cache.
|
||||||
:attr Set[str] schemas: The set of known/cached schemas, all lowercased.
|
:attr Set[str] schemas: The set of known/cached schemas, all lowercased.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
self.relations: Dict[_ReferenceKey, _CachedRelation] = {}
|
self.relations: Dict[_ReferenceKey, _CachedRelation] = {}
|
||||||
self.lock = threading.RLock()
|
self.lock = threading.RLock()
|
||||||
self.schemas: Set[Tuple[Optional[str], Optional[str]]] = set()
|
self.schemas: Set[Tuple[Optional[str], Optional[str]]] = set()
|
||||||
|
|
||||||
def add_schema(
|
def add_schema(
|
||||||
self, database: Optional[str], schema: Optional[str],
|
self,
|
||||||
|
database: Optional[str],
|
||||||
|
schema: Optional[str],
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Add a schema to the set of known schemas (case-insensitive)
|
"""Add a schema to the set of known schemas (case-insensitive)
|
||||||
|
|
||||||
@@ -188,7 +195,9 @@ class RelationsCache:
|
|||||||
self.schemas.add((lowercase(database), lowercase(schema)))
|
self.schemas.add((lowercase(database), lowercase(schema)))
|
||||||
|
|
||||||
def drop_schema(
|
def drop_schema(
|
||||||
self, database: Optional[str], schema: Optional[str],
|
self,
|
||||||
|
database: Optional[str],
|
||||||
|
schema: Optional[str],
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Drop the given schema and remove it from the set of known schemas.
|
"""Drop the given schema and remove it from the set of known schemas.
|
||||||
|
|
||||||
@@ -263,15 +272,15 @@ class RelationsCache:
|
|||||||
return
|
return
|
||||||
if referenced is None:
|
if referenced is None:
|
||||||
dbt.exceptions.raise_cache_inconsistent(
|
dbt.exceptions.raise_cache_inconsistent(
|
||||||
'in add_link, referenced link key {} not in cache!'
|
"in add_link, referenced link key {} not in cache!".format(
|
||||||
.format(referenced_key)
|
referenced_key
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
dependent = self.relations.get(dependent_key)
|
dependent = self.relations.get(dependent_key)
|
||||||
if dependent is None:
|
if dependent is None:
|
||||||
dbt.exceptions.raise_cache_inconsistent(
|
dbt.exceptions.raise_cache_inconsistent(
|
||||||
'in add_link, dependent link key {} not in cache!'
|
"in add_link, dependent link key {} not in cache!".format(dependent_key)
|
||||||
.format(dependent_key)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
assert dependent is not None # we just raised!
|
assert dependent is not None # we just raised!
|
||||||
@@ -298,28 +307,23 @@ class RelationsCache:
|
|||||||
# referring to a table outside our control. There's no need to make
|
# referring to a table outside our control. There's no need to make
|
||||||
# a link - we will never drop the referenced relation during a run.
|
# a link - we will never drop the referenced relation during a run.
|
||||||
logger.debug(
|
logger.debug(
|
||||||
'{dep!s} references {ref!s} but {ref.database}.{ref.schema} '
|
"{dep!s} references {ref!s} but {ref.database}.{ref.schema} "
|
||||||
'is not in the cache, skipping assumed external relation'
|
"is not in the cache, skipping assumed external relation".format(
|
||||||
.format(dep=dependent, ref=ref_key)
|
dep=dependent, ref=ref_key
|
||||||
|
)
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
if ref_key not in self.relations:
|
if ref_key not in self.relations:
|
||||||
# Insert a dummy "external" relation.
|
# Insert a dummy "external" relation.
|
||||||
referenced = referenced.replace(
|
referenced = referenced.replace(type=referenced.External)
|
||||||
type=referenced.External
|
|
||||||
)
|
|
||||||
self.add(referenced)
|
self.add(referenced)
|
||||||
|
|
||||||
dep_key = _make_key(dependent)
|
dep_key = _make_key(dependent)
|
||||||
if dep_key not in self.relations:
|
if dep_key not in self.relations:
|
||||||
# Insert a dummy "external" relation.
|
# Insert a dummy "external" relation.
|
||||||
dependent = dependent.replace(
|
+ dependent = dependent.replace(type=referenced.External)
-     type=referenced.External
- )
  self.add(dependent)
- logger.debug(
-     'adding link, {!s} references {!s}'.format(dep_key, ref_key)
- )
+ logger.debug("adding link, {!s} references {!s}".format(dep_key, ref_key))
  with self.lock:
  self._add_link(ref_key, dep_key)

@@ -330,14 +334,14 @@ class RelationsCache:
  :param BaseRelation relation: The underlying relation.
  """
  cached = _CachedRelation(relation)
- logger.debug('Adding relation: {!s}'.format(cached))
+ logger.debug("Adding relation: {!s}".format(cached))

- lazy_log('before adding: {!s}', self.dump_graph)
+ lazy_log("before adding: {!s}", self.dump_graph)

  with self.lock:
  self._setdefault(cached)

- lazy_log('after adding: {!s}', self.dump_graph)
+ lazy_log("after adding: {!s}", self.dump_graph)

  def _remove_refs(self, keys):
  """Removes all references to all entries in keys. This does not
@@ -359,13 +363,10 @@ class RelationsCache:
  :param _CachedRelation dropped: An existing _CachedRelation to drop.
  """
  if dropped not in self.relations:
- logger.debug('dropped a nonexistent relationship: {!s}'
- .format(dropped))
+ logger.debug("dropped a nonexistent relationship: {!s}".format(dropped))
  return
  consequences = self.relations[dropped].collect_consequences()
- logger.debug(
-     'drop {} is cascading to {}'.format(dropped, consequences)
- )
+ logger.debug("drop {} is cascading to {}".format(dropped, consequences))
  self._remove_refs(consequences)

  def drop(self, relation):
@@ -380,7 +381,7 @@ class RelationsCache:
  :param str identifier: The identifier of the relation to drop.
  """
  dropped = _make_key(relation)
- logger.debug('Dropping relation: {!s}'.format(dropped))
+ logger.debug("Dropping relation: {!s}".format(dropped))
  with self.lock:
  self._drop_cascade_relation(dropped)

@@ -404,8 +405,9 @@ class RelationsCache:
  for cached in self.relations.values():
  if cached.is_referenced_by(old_key):
  logger.debug(
-     'updated reference from {0} -> {2} to {1} -> {2}'
-     .format(old_key, new_key, cached.key())
+     "updated reference from {0} -> {2} to {1} -> {2}".format(
+         old_key, new_key, cached.key()
+     )
  )
  cached.rename_key(old_key, new_key)

@@ -430,14 +432,16 @@ class RelationsCache:
  """
  if new_key in self.relations:
  dbt.exceptions.raise_cache_inconsistent(
-     'in rename, new key {} already in cache: {}'
-     .format(new_key, list(self.relations.keys()))
+     "in rename, new key {} already in cache: {}".format(
+         new_key, list(self.relations.keys())
+     )
  )

  if old_key not in self.relations:
  logger.debug(
-     'old key {} not found in self.relations, assuming temporary'
-     .format(old_key)
+     "old key {} not found in self.relations, assuming temporary".format(
+         old_key
+     )
  )
  return False
  return True
@@ -456,11 +460,9 @@ class RelationsCache:
  """
  old_key = _make_key(old)
  new_key = _make_key(new)
- logger.debug('Renaming relation {!s} to {!s}'.format(
-     old_key, new_key
- ))
+ logger.debug("Renaming relation {!s} to {!s}".format(old_key, new_key))

- lazy_log('before rename: {!s}', self.dump_graph)
+ lazy_log("before rename: {!s}", self.dump_graph)

  with self.lock:
  if self._check_rename_constraints(old_key, new_key):
@@ -468,7 +470,7 @@ class RelationsCache:
  else:
  self._setdefault(_CachedRelation(new))

- lazy_log('after rename: {!s}', self.dump_graph)
+ lazy_log("after rename: {!s}", self.dump_graph)

  def get_relations(
  self, database: Optional[str], schema: Optional[str]
@@ -483,14 +485,14 @@ class RelationsCache:
  schema = lowercase(schema)
  with self.lock:
  results = [
-     r.inner for r in self.relations.values()
-     if (lowercase(r.schema) == schema and
-         lowercase(r.database) == database)
+     r.inner
+     for r in self.relations.values()
+     if (lowercase(r.schema) == schema and lowercase(r.database) == database)
  ]

  if None in results:
  dbt.exceptions.raise_cache_inconsistent(
-     'in get_relations, a None relation was found in the cache!'
+     "in get_relations, a None relation was found in the cache!"
  )
  return results

@@ -50,9 +50,7 @@ class AdapterContainer:
  adapter = self.get_adapter_class_by_name(name)
  return adapter.Relation

- def get_config_class_by_name(
-     self, name: str
- ) -> Type[AdapterConfig]:
+ def get_config_class_by_name(self, name: str) -> Type[AdapterConfig]:
  adapter = self.get_adapter_class_by_name(name)
  return adapter.AdapterSpecificConfigs

@@ -62,24 +60,24 @@ class AdapterContainer:
  # singletons
  try:
  # mypy doesn't think modules have any attributes.
- mod: Any = import_module('.' + name, 'dbt.adapters')
+ mod: Any = import_module("." + name, "dbt.adapters")
  except ModuleNotFoundError as exc:
  # if we failed to import the target module in particular, inform
  # the user about it via a runtime error
- if exc.name == 'dbt.adapters.' + name:
+ if exc.name == "dbt.adapters." + name:
- raise RuntimeException(f'Could not find adapter type {name}!')
+ raise RuntimeException(f"Could not find adapter type {name}!")
- logger.info(f'Error importing adapter: {exc}')
+ logger.info(f"Error importing adapter: {exc}")
  # otherwise, the error had to have come from some underlying
  # library. Log the stack trace.
- logger.debug('', exc_info=True)
+ logger.debug("", exc_info=True)
  raise
  plugin: AdapterPlugin = mod.Plugin
  plugin_type = plugin.adapter.type()

  if plugin_type != name:
  raise RuntimeException(
-     f'Expected to find adapter with type named {name}, got '
+     f"Expected to find adapter with type named {name}, got "
-     f'adapter with type {plugin_type}'
+     f"adapter with type {plugin_type}"
  )

  with self.lock:
@@ -109,8 +107,7 @@ class AdapterContainer:
  return self.adapters[adapter_name]

  def reset_adapters(self):
- """Clear the adapters. This is useful for tests, which change configs.
+ """Clear the adapters. This is useful for tests, which change configs."""
- """
  with self.lock:
  for adapter in self.adapters.values():
  adapter.cleanup_connections()
@@ -140,9 +137,7 @@ class AdapterContainer:
  try:
  plugin = self.plugins[plugin_name]
  except KeyError:
- raise InternalException(
-     f'No plugin found for {plugin_name}'
- ) from None
+ raise InternalException(f"No plugin found for {plugin_name}") from None
  plugins.append(plugin)
  seen.add(plugin_name)
  if plugin.dependencies is None:
@@ -166,7 +161,7 @@ class AdapterContainer:
  path = self.packages[package_name]
  except KeyError:
  raise InternalException(
-     f'No internal package listing found for {package_name}'
+     f"No internal package listing found for {package_name}"
  )
  paths.append(path)
  return paths
@@ -187,8 +182,7 @@ def get_adapter(config: AdapterRequiredConfig):


  def reset_adapters():
- """Clear the adapters. This is useful for tests, which change configs.
+ """Clear the adapters. This is useful for tests, which change configs."""
- """
  FACTORY.reset_adapters()

@@ -1,17 +1,27 @@
  from dataclasses import dataclass
  from typing import (
-     Type, Hashable, Optional, ContextManager, List, Generic, TypeVar, ClassVar,
+     Type,
-     Tuple, Union, Dict, Any
+     Hashable,
+     Optional,
+     ContextManager,
+     List,
+     Generic,
+     TypeVar,
+     ClassVar,
+     Tuple,
+     Union,
+     Dict,
+     Any,
  )
  from typing_extensions import Protocol

  import agate

- from dbt.contracts.connection import (
+ from dbt.contracts.connection import Connection, AdapterRequiredConfig, AdapterResponse
-     Connection, AdapterRequiredConfig, AdapterResponse
- )
  from dbt.contracts.graph.compiled import (
-     CompiledNode, ManifestNode, NonSourceCompiledNode
+     CompiledNode,
+     ManifestNode,
+     NonSourceCompiledNode,
  )
  from dbt.contracts.graph.parsed import ParsedNode, ParsedSourceDefinition
  from dbt.contracts.graph.model_config import BaseConfig
@@ -34,7 +44,7 @@ class ColumnProtocol(Protocol):
  pass


- Self = TypeVar('Self', bound='RelationProtocol')
+ Self = TypeVar("Self", bound="RelationProtocol")


  class RelationProtocol(Protocol):
@@ -64,19 +74,11 @@ class CompilerProtocol(Protocol):
  ...


- AdapterConfig_T = TypeVar(
+ AdapterConfig_T = TypeVar("AdapterConfig_T", bound=AdapterConfig)
-     'AdapterConfig_T', bound=AdapterConfig
+ ConnectionManager_T = TypeVar("ConnectionManager_T", bound=ConnectionManagerProtocol)
- )
+ Relation_T = TypeVar("Relation_T", bound=RelationProtocol)
- ConnectionManager_T = TypeVar(
+ Column_T = TypeVar("Column_T", bound=ColumnProtocol)
-     'ConnectionManager_T', bound=ConnectionManagerProtocol
+ Compiler_T = TypeVar("Compiler_T", bound=CompilerProtocol)
- )
- Relation_T = TypeVar(
-     'Relation_T', bound=RelationProtocol
- )
- Column_T = TypeVar(
-     'Column_T', bound=ColumnProtocol
- )
- Compiler_T = TypeVar('Compiler_T', bound=CompilerProtocol)


  class AdapterProtocol(
@@ -87,7 +89,7 @@ class AdapterProtocol(
  Relation_T,
  Column_T,
  Compiler_T,
- ]
+ ],
  ):
  AdapterSpecificConfigs: ClassVar[Type[AdapterConfig_T]]
  Column: ClassVar[Type[Column_T]]

@@ -7,9 +7,7 @@ import agate
  import dbt.clients.agate_helper
  import dbt.exceptions
  from dbt.adapters.base import BaseConnectionManager
- from dbt.contracts.connection import (
+ from dbt.contracts.connection import Connection, ConnectionState, AdapterResponse
-     Connection, ConnectionState, AdapterResponse
- )
  from dbt.logger import GLOBAL_LOGGER as logger
  from dbt import flags

@@ -23,11 +21,12 @@ class SQLConnectionManager(BaseConnectionManager):
  - get_response
  - open
  """

  @abc.abstractmethod
  def cancel(self, connection: Connection):
  """Cancel the given connection."""
  raise dbt.exceptions.NotImplementedException(
-     '`cancel` is not implemented for this adapter!'
+     "`cancel` is not implemented for this adapter!"
  )

  def cancel_open(self) -> List[str]:
@@ -41,8 +40,8 @@ class SQLConnectionManager(BaseConnectionManager):
  # if the connection failed, the handle will be None so we have
  # nothing to cancel.
  if (
-     connection.handle is not None and
+     connection.handle is not None
-     connection.state == ConnectionState.OPEN
+     and connection.state == ConnectionState.OPEN
  ):
  self.cancel(connection)
  if connection.name is not None:
@@ -54,23 +53,22 @@ class SQLConnectionManager(BaseConnectionManager):
  sql: str,
  auto_begin: bool = True,
  bindings: Optional[Any] = None,
- abridge_sql_log: bool = False
+ abridge_sql_log: bool = False,
  ) -> Tuple[Connection, Any]:
  connection = self.get_thread_connection()
  if auto_begin and connection.transaction_open is False:
  self.begin()

- logger.debug('Using {} connection "{}".'
- .format(self.TYPE, connection.name))
+ logger.debug('Using {} connection "{}".'.format(self.TYPE, connection.name))

  with self.exception_handler(sql):
  if abridge_sql_log:
- log_sql = '{}...'.format(sql[:512])
+ log_sql = "{}...".format(sql[:512])
  else:
  log_sql = sql

  logger.debug(
-     'On {connection_name}: {sql}',
+     "On {connection_name}: {sql}",
  connection_name=connection.name,
  sql=log_sql,
  )
@@ -81,7 +79,7 @@ class SQLConnectionManager(BaseConnectionManager):
  logger.debug(
  "SQL status: {status} in {elapsed:0.2f} seconds",
  status=self.get_response(cursor),
- elapsed=(time.time() - pre)
+ elapsed=(time.time() - pre),
  )

  return connection, cursor
@@ -90,14 +88,12 @@ class SQLConnectionManager(BaseConnectionManager):
  def get_response(cls, cursor: Any) -> Union[AdapterResponse, str]:
  """Get the status of the cursor."""
  raise dbt.exceptions.NotImplementedException(
-     '`get_response` is not implemented for this adapter!'
+     "`get_response` is not implemented for this adapter!"
  )

  @classmethod
  def process_results(
- cls,
+ cls, column_names: Iterable[str], rows: Iterable[Any]
- column_names: Iterable[str],
- rows: Iterable[Any]
  ) -> List[Dict[str, Any]]:

  return [dict(zip(column_names, row)) for row in rows]
@@ -112,10 +108,7 @@ class SQLConnectionManager(BaseConnectionManager):
  rows = cursor.fetchall()
  data = cls.process_results(column_names, rows)

- return dbt.clients.agate_helper.table_from_data_flat(
+ return dbt.clients.agate_helper.table_from_data_flat(data, column_names)
-     data,
-     column_names
- )

  def execute(
  self, sql: str, auto_begin: bool = False, fetch: bool = False
@@ -130,10 +123,10 @@ class SQLConnectionManager(BaseConnectionManager):
  return response, table

  def add_begin_query(self):
- return self.add_query('BEGIN', auto_begin=False)
+ return self.add_query("BEGIN", auto_begin=False)

  def add_commit_query(self):
- return self.add_query('COMMIT', auto_begin=False)
+ return self.add_query("COMMIT", auto_begin=False)

  def begin(self):
  connection = self.get_thread_connection()
@@ -141,13 +134,14 @@ class SQLConnectionManager(BaseConnectionManager):
  if flags.STRICT_MODE:
  if not isinstance(connection, Connection):
  raise dbt.exceptions.CompilerException(
-     f'In begin, got {connection} - not a Connection!'
+     f"In begin, got {connection} - not a Connection!"
  )

  if connection.transaction_open is True:
  raise dbt.exceptions.InternalException(
  'Tried to begin a new transaction on connection "{}", but '
-     'it already had one open!'.format(connection.name))
+     "it already had one open!".format(connection.name)
+ )

  self.add_begin_query()

@@ -159,15 +153,16 @@ class SQLConnectionManager(BaseConnectionManager):
  if flags.STRICT_MODE:
  if not isinstance(connection, Connection):
  raise dbt.exceptions.CompilerException(
-     f'In commit, got {connection} - not a Connection!'
+     f"In commit, got {connection} - not a Connection!"
  )

  if connection.transaction_open is False:
  raise dbt.exceptions.InternalException(
  'Tried to commit transaction on connection "{}", but '
-     'it does not have one open!'.format(connection.name))
+     "it does not have one open!".format(connection.name)
+ )

- logger.debug('On {}: COMMIT'.format(connection.name))
+ logger.debug("On {}: COMMIT".format(connection.name))
  self.add_commit_query()

  connection.transaction_open = False

@@ -10,16 +10,16 @@ from dbt.logger import GLOBAL_LOGGER as logger

  from dbt.adapters.base.relation import BaseRelation

- LIST_RELATIONS_MACRO_NAME = 'list_relations_without_caching'
+ LIST_RELATIONS_MACRO_NAME = "list_relations_without_caching"
- GET_COLUMNS_IN_RELATION_MACRO_NAME = 'get_columns_in_relation'
+ GET_COLUMNS_IN_RELATION_MACRO_NAME = "get_columns_in_relation"
- LIST_SCHEMAS_MACRO_NAME = 'list_schemas'
+ LIST_SCHEMAS_MACRO_NAME = "list_schemas"
- CHECK_SCHEMA_EXISTS_MACRO_NAME = 'check_schema_exists'
+ CHECK_SCHEMA_EXISTS_MACRO_NAME = "check_schema_exists"
- CREATE_SCHEMA_MACRO_NAME = 'create_schema'
+ CREATE_SCHEMA_MACRO_NAME = "create_schema"
- DROP_SCHEMA_MACRO_NAME = 'drop_schema'
+ DROP_SCHEMA_MACRO_NAME = "drop_schema"
- RENAME_RELATION_MACRO_NAME = 'rename_relation'
+ RENAME_RELATION_MACRO_NAME = "rename_relation"
- TRUNCATE_RELATION_MACRO_NAME = 'truncate_relation'
+ TRUNCATE_RELATION_MACRO_NAME = "truncate_relation"
- DROP_RELATION_MACRO_NAME = 'drop_relation'
+ DROP_RELATION_MACRO_NAME = "drop_relation"
- ALTER_COLUMN_TYPE_MACRO_NAME = 'alter_column_type'
+ ALTER_COLUMN_TYPE_MACRO_NAME = "alter_column_type"


  class SQLAdapter(BaseAdapter):
@@ -60,30 +60,23 @@ class SQLAdapter(BaseAdapter):
  :param abridge_sql_log: If set, limit the raw sql logged to 512
  characters
  """
- return self.connections.add_query(sql, auto_begin, bindings,
- abridge_sql_log)
+ return self.connections.add_query(sql, auto_begin, bindings, abridge_sql_log)

  @classmethod
  def convert_text_type(cls, agate_table: agate.Table, col_idx: int) -> str:
  return "text"

  @classmethod
- def convert_number_type(
+ def convert_number_type(cls, agate_table: agate.Table, col_idx: int) -> str:
-     cls, agate_table: agate.Table, col_idx: int
+     decimals = agate_table.aggregate(agate.MaxPrecision(col_idx))  # type: ignore
- ) -> str:
-     decimals = agate_table.aggregate(agate.MaxPrecision(col_idx))
  return "float8" if decimals else "integer"

  @classmethod
- def convert_boolean_type(
+ def convert_boolean_type(cls, agate_table: agate.Table, col_idx: int) -> str:
-     cls, agate_table: agate.Table, col_idx: int
- ) -> str:
  return "boolean"

  @classmethod
- def convert_datetime_type(
+ def convert_datetime_type(cls, agate_table: agate.Table, col_idx: int) -> str:
-     cls, agate_table: agate.Table, col_idx: int
- ) -> str:
  return "timestamp without time zone"

  @classmethod
@@ -99,31 +92,28 @@ class SQLAdapter(BaseAdapter):
  return True

  def expand_column_types(self, goal, current):
- reference_columns = {
+ reference_columns = {c.name: c for c in self.get_columns_in_relation(goal)}
-     c.name: c for c in
-     self.get_columns_in_relation(goal)
- }

- target_columns = {
+ target_columns = {c.name: c for c in self.get_columns_in_relation(current)}
-     c.name: c for c
-     in self.get_columns_in_relation(current)
- }

  for column_name, reference_column in reference_columns.items():
  target_column = target_columns.get(column_name)

- if target_column is not None and \
+ if target_column is not None and target_column.can_expand_to(
-     target_column.can_expand_to(reference_column):
+     reference_column
+ ):
  col_string_size = reference_column.string_size()
  new_type = self.Column.string_type(col_string_size)
- logger.debug("Changing col type from {} to {} in table {}",
+ logger.debug(
-     target_column.data_type, new_type, current)
+     "Changing col type from {} to {} in table {}",
+     target_column.data_type,
+     new_type,
+     current,
+ )

  self.alter_column_type(current, column_name, new_type)

- def alter_column_type(
+ def alter_column_type(self, relation, column_name, new_column_type) -> None:
-     self, relation, column_name, new_column_type
- ) -> None:
  """
  1. Create a new column (w/ temp name and correct type)
  2. Copy data over to it
@@ -131,53 +121,40 @@ class SQLAdapter(BaseAdapter):
  4. Rename the new column to existing column
  """
  kwargs = {
-     'relation': relation,
+     "relation": relation,
-     'column_name': column_name,
+     "column_name": column_name,
-     'new_column_type': new_column_type,
+     "new_column_type": new_column_type,
  }
- self.execute_macro(
+ self.execute_macro(ALTER_COLUMN_TYPE_MACRO_NAME, kwargs=kwargs)
-     ALTER_COLUMN_TYPE_MACRO_NAME,
-     kwargs=kwargs
- )

  def drop_relation(self, relation):
  if relation.type is None:
  dbt.exceptions.raise_compiler_error(
-     'Tried to drop relation {}, but its type is null.'
+     "Tried to drop relation {}, but its type is null.".format(relation)
-     .format(relation))
+ )

  self.cache_dropped(relation)
- self.execute_macro(
+ self.execute_macro(DROP_RELATION_MACRO_NAME, kwargs={"relation": relation})
-     DROP_RELATION_MACRO_NAME,
-     kwargs={'relation': relation}
- )

  def truncate_relation(self, relation):
- self.execute_macro(
+ self.execute_macro(TRUNCATE_RELATION_MACRO_NAME, kwargs={"relation": relation})
-     TRUNCATE_RELATION_MACRO_NAME,
-     kwargs={'relation': relation}
- )

  def rename_relation(self, from_relation, to_relation):
  self.cache_renamed(from_relation, to_relation)

- kwargs = {'from_relation': from_relation, 'to_relation': to_relation}
+ kwargs = {"from_relation": from_relation, "to_relation": to_relation}
- self.execute_macro(
+ self.execute_macro(RENAME_RELATION_MACRO_NAME, kwargs=kwargs)
-     RENAME_RELATION_MACRO_NAME,
-     kwargs=kwargs
- )

  def get_columns_in_relation(self, relation):
  return self.execute_macro(
-     GET_COLUMNS_IN_RELATION_MACRO_NAME,
+     GET_COLUMNS_IN_RELATION_MACRO_NAME, kwargs={"relation": relation}
-     kwargs={'relation': relation}
  )

  def create_schema(self, relation: BaseRelation) -> None:
  relation = relation.without_identifier()
  logger.debug('Creating schema "{}"', relation)
  kwargs = {
-     'relation': relation,
+     "relation": relation,
  }
  self.execute_macro(CREATE_SCHEMA_MACRO_NAME, kwargs=kwargs)
  self.commit_if_has_connection()
@@ -188,39 +165,35 @@ class SQLAdapter(BaseAdapter):
  relation = relation.without_identifier()
  logger.debug('Dropping schema "{}".', relation)
  kwargs = {
-     'relation': relation,
+     "relation": relation,
  }
  self.execute_macro(DROP_SCHEMA_MACRO_NAME, kwargs=kwargs)
  # we can update the cache here
  self.cache.drop_schema(relation.database, relation.schema)

  def list_relations_without_caching(
- self, schema_relation: BaseRelation,
+ self,
+ schema_relation: BaseRelation,
  ) -> List[BaseRelation]:
- kwargs = {'schema_relation': schema_relation}
+ kwargs = {"schema_relation": schema_relation}
- results = self.execute_macro(
+ results = self.execute_macro(LIST_RELATIONS_MACRO_NAME, kwargs=kwargs)
-     LIST_RELATIONS_MACRO_NAME,
-     kwargs=kwargs
- )

  relations = []
- quote_policy = {
+ quote_policy = {"database": True, "schema": True, "identifier": True}
-     'database': True,
-     'schema': True,
-     'identifier': True
- }
  for _database, name, _schema, _type in results:
  try:
  _type = self.Relation.get_relation_type(_type)
  except ValueError:
  _type = self.Relation.External
- relations.append(self.Relation.create(
+ relations.append(
+     self.Relation.create(
  database=_database,
  schema=_schema,
  identifier=name,
  quote_policy=quote_policy,
- type=_type
+ type=_type,
- ))
+     )
+ )
  return relations

  def quote(self, identifier):
@@ -228,8 +201,7 @@ class SQLAdapter(BaseAdapter):

  def list_schemas(self, database: str) -> List[str]:
  results = self.execute_macro(
-     LIST_SCHEMAS_MACRO_NAME,
+     LIST_SCHEMAS_MACRO_NAME, kwargs={"database": database}
-     kwargs={'database': database}
  )

  return [row[0] for row in results]
@@ -238,13 +210,10 @@ class SQLAdapter(BaseAdapter):
  information_schema = self.Relation.create(
  database=database,
  schema=schema,
- identifier='INFORMATION_SCHEMA',
+ identifier="INFORMATION_SCHEMA",
- quote_policy=self.config.quoting
+ quote_policy=self.config.quoting,
  ).information_schema()

- kwargs = {'information_schema': information_schema, 'schema': schema}
+ kwargs = {"information_schema": information_schema, "schema": schema}
- results = self.execute_macro(
+ results = self.execute_macro(CHECK_SCHEMA_EXISTS_MACRO_NAME, kwargs=kwargs)
-     CHECK_SCHEMA_EXISTS_MACRO_NAME,
-     kwargs=kwargs
- )
  return results[0][0] > 0

@@ -10,79 +10,89 @@ def regex(pat):

  class BlockData:
  """raw plaintext data from the top level of the file."""

  def __init__(self, contents):
- self.block_type_name = '__dbt__data'
+ self.block_type_name = "__dbt__data"
  self.contents = contents
  self.full_block = contents


  class BlockTag:
- def __init__(self, block_type_name, block_name, contents=None,
+ def __init__(
- full_block=None, **kw):
+     self, block_type_name, block_name, contents=None, full_block=None, **kw
+ ):
  self.block_type_name = block_type_name
  self.block_name = block_name
  self.contents = contents
  self.full_block = full_block

  def __str__(self):
- return 'BlockTag({!r}, {!r})'.format(self.block_type_name,
- self.block_name)
+ return "BlockTag({!r}, {!r})".format(self.block_type_name, self.block_name)

  def __repr__(self):
  return str(self)

  @property
  def end_block_type_name(self):
- return 'end{}'.format(self.block_type_name)
+ return "end{}".format(self.block_type_name)

  def end_pat(self):
  # we don't want to use string formatting here because jinja uses most
  # of the string formatting operators in its syntax...
- pattern = ''.join((
+ pattern = "".join(
-     r'(?P<endblock>((?:\s*\{\%\-|\{\%)\s*',
+     (
+         r"(?P<endblock>((?:\s*\{\%\-|\{\%)\s*",
  self.end_block_type_name,
- r'\s*(?:\-\%\}\s*|\%\})))',
+ r"\s*(?:\-\%\}\s*|\%\})))",
- ))
+     )
+ )
  return regex(pattern)


- Tag = namedtuple('Tag', 'block_type_name block_name start end')
+ Tag = namedtuple("Tag", "block_type_name block_name start end")


- _NAME_PATTERN = r'[A-Za-z_][A-Za-z_0-9]*'
+ _NAME_PATTERN = r"[A-Za-z_][A-Za-z_0-9]*"

- COMMENT_START_PATTERN = regex(r'(?:(?P<comment_start>(\s*\{\#)))')
+ COMMENT_START_PATTERN = regex(r"(?:(?P<comment_start>(\s*\{\#)))")
- COMMENT_END_PATTERN = regex(r'(.*?)(\s*\#\})')
+ COMMENT_END_PATTERN = regex(r"(.*?)(\s*\#\})")
  RAW_START_PATTERN = regex(
- r'(?:\s*\{\%\-|\{\%)\s*(?P<raw_start>(raw))\s*(?:\-\%\}\s*|\%\})'
+ r"(?:\s*\{\%\-|\{\%)\s*(?P<raw_start>(raw))\s*(?:\-\%\}\s*|\%\})"
  )
- EXPR_START_PATTERN = regex(r'(?P<expr_start>(\{\{\s*))')
+ EXPR_START_PATTERN = regex(r"(?P<expr_start>(\{\{\s*))")
- EXPR_END_PATTERN = regex(r'(?P<expr_end>(\s*\}\}))')
+ EXPR_END_PATTERN = regex(r"(?P<expr_end>(\s*\}\}))")

- BLOCK_START_PATTERN = regex(''.join((
+ BLOCK_START_PATTERN = regex(
-     r'(?:\s*\{\%\-|\{\%)\s*',
+     "".join(
-     r'(?P<block_type_name>({}))'.format(_NAME_PATTERN),
+         (
+             r"(?:\s*\{\%\-|\{\%)\s*",
+             r"(?P<block_type_name>({}))".format(_NAME_PATTERN),
  # some blocks have a 'block name'.
- r'(?:\s+(?P<block_name>({})))?'.format(_NAME_PATTERN),
+ r"(?:\s+(?P<block_name>({})))?".format(_NAME_PATTERN),
- )))
+         )
+     )
+ )


- RAW_BLOCK_PATTERN = regex(''.join((
+ RAW_BLOCK_PATTERN = regex(
-     r'(?:\s*\{\%\-|\{\%)\s*raw\s*(?:\-\%\}\s*|\%\})',
+     "".join(
-     r'(?:.*?)',
+         (
-     r'(?:\s*\{\%\-|\{\%)\s*endraw\s*(?:\-\%\}\s*|\%\})',
+             r"(?:\s*\{\%\-|\{\%)\s*raw\s*(?:\-\%\}\s*|\%\})",
- )))
+             r"(?:.*?)",
+             r"(?:\s*\{\%\-|\{\%)\s*endraw\s*(?:\-\%\}\s*|\%\})",
+         )
+     )
+ )

- TAG_CLOSE_PATTERN = regex(r'(?:(?P<tag_close>(\-\%\}\s*|\%\})))')
+ TAG_CLOSE_PATTERN = regex(r"(?:(?P<tag_close>(\-\%\}\s*|\%\})))")

  # stolen from jinja's lexer. Note that we've consumed all prefix whitespace by
  # the time we want to use this.
  STRING_PATTERN = regex(
- r"(?P<string>('([^'\\]*(?:\\.[^'\\]*)*)'|"
+ r"(?P<string>('([^'\\]*(?:\\.[^'\\]*)*)'|" r'"([^"\\]*(?:\\.[^"\\]*)*)"))'
- r'"([^"\\]*(?:\\.[^"\\]*)*)"))'
  )

- QUOTE_START_PATTERN = regex(r'''(?P<quote>(['"]))''')
+ QUOTE_START_PATTERN = regex(r"""(?P<quote>(['"]))""")


  class TagIterator:
@@ -99,10 +109,10 @@ class TagIterator:
  end_val: int = self.pos if end is None else end
  data = self.data[:end_val]
  # if not found, rfind returns -1, and -1+1=0, which is perfect!
- last_line_start = data.rfind('\n') + 1
+ last_line_start = data.rfind("\n") + 1
  # it's easy to forget this, but line numbers are 1-indexed
- line_number = data.count('\n') + 1
+ line_number = data.count("\n") + 1
- return f'{line_number}:{end_val - last_line_start}'
+ return f"{line_number}:{end_val - last_line_start}"

  def advance(self, new_position):
  self.pos = new_position
@@ -120,7 +130,7 @@ class TagIterator:
  matches = []
  for pattern in patterns:
  # default to 'search', but sometimes we want to 'match'.
- if kwargs.get('method', 'search') == 'search':
+ if kwargs.get("method", "search") == "search":
  match = self._search(pattern)
  else:
  match = self._match(pattern)
@@ -156,22 +166,20 @@ class TagIterator:
  """
  self.advance(match.end())
  while True:
- match = self._expect_match('}}',
+ match = self._expect_match("}}", EXPR_END_PATTERN, QUOTE_START_PATTERN)
- EXPR_END_PATTERN,
+ if match.groupdict().get("expr_end") is not None:
- QUOTE_START_PATTERN)
- if match.groupdict().get('expr_end') is not None:
  break
  else:
  # it's a quote. we haven't advanced for this match yet, so
  # just slurp up the whole string, no need to rewind.
- match = self._expect_match('string', STRING_PATTERN)
+ match = self._expect_match("string", STRING_PATTERN)
  self.advance(match.end())

  self.advance(match.end())

  def handle_comment(self, match):
  self.advance(match.end())
- match = self._expect_match('#}', COMMENT_END_PATTERN)
+ match = self._expect_match("#}", COMMENT_END_PATTERN)
  self.advance(match.end())

  def _expect_block_close(self):
@@ -188,22 +196,19 @@ class TagIterator:
  """
  while True:
  end_match = self._expect_match(
-     'tag close ("%}")',
+     'tag close ("%}")', QUOTE_START_PATTERN, TAG_CLOSE_PATTERN
-     QUOTE_START_PATTERN,
-     TAG_CLOSE_PATTERN
  )
  self.advance(end_match.end())
- if end_match.groupdict().get('tag_close') is not None:
+ if end_match.groupdict().get("tag_close") is not None:
  return
  # must be a string. Rewind to its start and advance past it.
  self.rewind()
- string_match = self._expect_match('string', STRING_PATTERN)
+ string_match = self._expect_match("string", STRING_PATTERN)
  self.advance(string_match.end())

  def handle_raw(self):
  # raw blocks are super special, they are a single complete regex
- match = self._expect_match('{% raw %}...{% endraw %}',
- RAW_BLOCK_PATTERN)
+ match = self._expect_match("{% raw %}...{% endraw %}", RAW_BLOCK_PATTERN)
  self.advance(match.end())
  return match.end()

@@ -220,13 +225,12 @@ class TagIterator:
  """
  groups = match.groupdict()
  # always a value
- block_type_name = groups['block_type_name']
+ block_type_name = groups["block_type_name"]
  # might be None
- block_name = groups.get('block_name')
+ block_name = groups.get("block_name")
  start_pos = self.pos
- if block_type_name == 'raw':
+ if block_type_name == "raw":
- match = self._expect_match('{% raw %}...{% endraw %}',
- RAW_BLOCK_PATTERN)
+ match = self._expect_match("{% raw %}...{% endraw %}", RAW_BLOCK_PATTERN)
  self.advance(match.end())
  else:
  self.advance(match.end())
@@ -235,15 +239,13 @@ class TagIterator:
  block_type_name=block_type_name,
  block_name=block_name,
  start=start_pos,
- end=self.pos
+ end=self.pos,
  )

  def find_tags(self):
  while True:
  match = self._first_match(
-     BLOCK_START_PATTERN,
+     BLOCK_START_PATTERN, COMMENT_START_PATTERN, EXPR_START_PATTERN
-     COMMENT_START_PATTERN,
-     EXPR_START_PATTERN
  )
  if match is None:
  break
@@ -252,9 +254,9 @@ class TagIterator:
  # start = self.pos

  groups = match.groupdict()
- comment_start = groups.get('comment_start')
+ comment_start = groups.get("comment_start")
- expr_start = groups.get('expr_start')
+ expr_start = groups.get("expr_start")
- block_type_name = groups.get('block_type_name')
+ block_type_name = groups.get("block_type_name")

  if comment_start is not None:
  self.handle_comment(match)
@@ -264,8 +266,8 @@ class TagIterator:
  yield self.handle_tag(match)
  else:
  raise dbt.exceptions.InternalException(
-     'Invalid regex match in next_block, expected block start, '
+     "Invalid regex match in next_block, expected block start, "
-     'expr start, or comment start'
+     "expr start, or comment start"
  )

  def __iter__(self):
@@ -273,21 +275,18 @@ class TagIterator:


  duplicate_tags = (
- 'Got nested tags: {outer.block_type_name} (started at {outer.start}) did '
+ "Got nested tags: {outer.block_type_name} (started at {outer.start}) did "
- 'not have a matching {{% end{outer.block_type_name} %}} before a '
+ "not have a matching {{% end{outer.block_type_name} %}} before a "
- 'subsequent {inner.block_type_name} was found (started at {inner.start})'
+ "subsequent {inner.block_type_name} was found (started at {inner.start})"
  )


  _CONTROL_FLOW_TAGS = {
- 'if': 'endif',
+ "if": "endif",
- 'for': 'endfor',
+ "for": "endfor",
  }

- _CONTROL_FLOW_END_TAGS = {
+ _CONTROL_FLOW_END_TAGS = {v: k for k, v in _CONTROL_FLOW_TAGS.items()}
-     v: k
-     for k, v in _CONTROL_FLOW_TAGS.items()
- }


  class BlockIterator:
@@ -310,15 +309,15 @@ class BlockIterator:

  def is_current_end(self, tag):
  return (
-     tag.block_type_name.startswith('end') and
+     tag.block_type_name.startswith("end")
-     self.current is not None and
+     and self.current is not None
-     tag.block_type_name[3:] == self.current.block_type_name
+     and tag.block_type_name[3:] == self.current.block_type_name
  )

  def find_blocks(self, allowed_blocks=None, collect_raw_data=True):
  """Find all top-level blocks in the data."""
  if allowed_blocks is None:
- allowed_blocks = {'snapshot', 'macro', 'materialization', 'docs'}
+ allowed_blocks = {"snapshot", "macro", "materialization", "docs"}

  for tag in self.tag_parser.find_tags():
  if tag.block_type_name in _CONTROL_FLOW_TAGS:
@@ -329,31 +328,37 @@ class BlockIterator:
  found = self.stack.pop()
  else:
  expected = _CONTROL_FLOW_END_TAGS[tag.block_type_name]
- dbt.exceptions.raise_compiler_error((
+ dbt.exceptions.raise_compiler_error(
-     'Got an unexpected control flow end tag, got {} but '
+     (
-     'never saw a preceeding {} (@ {})'
+         "Got an unexpected control flow end tag, got {} but "
+         "never saw a preceeding {} (@ {})"
  ).format(
  tag.block_type_name,
  expected,
- self.tag_parser.linepos(tag.start)
+ self.tag_parser.linepos(tag.start),
- ))
+ )
+ )
  expected = _CONTROL_FLOW_TAGS[found]
  if expected != tag.block_type_name:
- dbt.exceptions.raise_compiler_error((
+ dbt.exceptions.raise_compiler_error(
-     'Got an unexpected control flow end tag, got {} but '
+     (
-     'expected {} next (@ {})'
+         "Got an unexpected control flow end tag, got {} but "
+         "expected {} next (@ {})"
  ).format(
  tag.block_type_name,
  expected,
- self.tag_parser.linepos(tag.start)
+ self.tag_parser.linepos(tag.start),
- ))
+ )
+ )

  if tag.block_type_name in allowed_blocks:
  if self.stack:
- dbt.exceptions.raise_compiler_error((
+ dbt.exceptions.raise_compiler_error(
-     'Got a block definition inside control flow at {}. '
+     (
-     'All dbt block definitions must be at the top level'
+         "Got a block definition inside control flow at {}. "
- ).format(self.tag_parser.linepos(tag.start)))
+         "All dbt block definitions must be at the top level"
+     ).format(self.tag_parser.linepos(tag.start))
+ )
  if self.current is not None:
  dbt.exceptions.raise_compiler_error(
  duplicate_tags.format(outer=self.current, inner=tag)
@@ -372,16 +377,18 @@ class BlockIterator:
  block_type_name=self.current.block_type_name,
  block_name=self.current.block_name,
  contents=self.data[self.current.end : tag.start],
- full_block=self.data[self.current.start:tag.end]
+ full_block=self.data[self.current.start : tag.end],
  )
  self.current = None

  if self.current:
- linecount = self.data[:self.current.end].count('\n') + 1
+ linecount = self.data[: self.current.end].count("\n") + 1
- dbt.exceptions.raise_compiler_error((
+ dbt.exceptions.raise_compiler_error(
-     'Reached EOF without finding a close tag for '
+     (
-     '{} (searched from line {})'
+         "Reached EOF without finding a close tag for "
+         "{} (searched from line {})"
- ).format(self.current.block_type_name, linecount))
+     ).format(self.current.block_type_name, linecount)
+ )

  if collect_raw_data:
  raw_data = self.data[self.last_position :]
@@ -389,5 +396,8 @@ class BlockIterator:
  yield BlockData(raw_data)

  def lex_for_blocks(self, allowed_blocks=None, collect_raw_data=True):
- return list(self.find_blocks(allowed_blocks=allowed_blocks,
- collect_raw_data=collect_raw_data))
+ return list(
+     self.find_blocks(
+         allowed_blocks=allowed_blocks, collect_raw_data=collect_raw_data
+     )
+ )

@@ -10,7 +10,7 @@ from typing import Iterable, List, Dict, Union, Optional, Any
  from dbt.exceptions import RuntimeException


- BOM = BOM_UTF8.decode('utf-8')  # '\ufeff'
+ BOM = BOM_UTF8.decode("utf-8")  # '\ufeff'


  class ISODateTime(agate.data_types.DateTime):
@@ -30,28 +30,23 @@ class ISODateTime(agate.data_types.DateTime):
  except:  # noqa
  pass

- raise agate.exceptions.CastError(
+ raise agate.exceptions.CastError('Can not parse value "%s" as datetime.' % d)
-     'Can not parse value "%s" as datetime.' % d
- )


  def build_type_tester(text_columns: Iterable[str]) -> agate.TypeTester:
  types = [
- agate.data_types.Number(null_values=('null', '')),
+ agate.data_types.Number(null_values=("null", "")),
- agate.data_types.Date(null_values=('null', ''),
+ agate.data_types.Date(null_values=("null", ""), date_format="%Y-%m-%d"),
- date_format='%Y-%m-%d'),
+ agate.data_types.DateTime(
- agate.data_types.DateTime(null_values=('null', ''),
+     null_values=("null", ""), datetime_format="%Y-%m-%d %H:%M:%S"
- datetime_format='%Y-%m-%d %H:%M:%S'),
+ ),
- ISODateTime(null_values=('null', '')),
+ ISODateTime(null_values=("null", "")),
- agate.data_types.Boolean(true_values=('true',),
+ agate.data_types.Boolean(
- false_values=('false',),
+     true_values=("true",), false_values=("false",), null_values=("null", "")
- null_values=('null', '')),
+ ),
- agate.data_types.Text(null_values=('null', ''))
+ agate.data_types.Text(null_values=("null", "")),
  ]
- force = {
+ force = {k: agate.data_types.Text(null_values=("null", "")) for k in text_columns}
-     k: agate.data_types.Text(null_values=('null', ''))
-     for k in text_columns
- }
  return agate.TypeTester(force=force, types=types)


@@ -115,7 +110,7 @@ def as_matrix(table):

  def from_csv(abspath, text_columns):
  type_tester = build_type_tester(text_columns=text_columns)
- with open(abspath, encoding='utf-8') as fp:
+ with open(abspath, encoding="utf-8") as fp:
  if fp.read(1) != BOM:
  fp.seek(0)
  return agate.Table.from_csv(fp, column_types=type_tester)
@@ -147,8 +142,8 @@ class ColumnTypeBuilder(Dict[str, NullableAgateType]):
  elif not isinstance(value, type(existing_type)):
  # actual type mismatch!
  raise RuntimeException(
-     f'Tables contain columns with the same names ({key}), '
+     f"Tables contain columns with the same names ({key}), "
-     f'but different types ({value} vs {existing_type})'
+     f"but different types ({value} vs {existing_type})"
  )

  def finalize(self) -> Dict[str, agate.data_types.DataType]:
@@ -163,7 +158,7 @@ class ColumnTypeBuilder(Dict[str, NullableAgateType]):


  def _merged_column_types(
- tables: List[agate.Table]
+ tables: List[agate.Table],
  ) -> Dict[str, agate.data_types.DataType]:
  # this is a lot like agate.Table.merge, but with handling for all-null
  # rows being "any type".
@@ -190,10 +185,7 @@ def merge_tables(tables: List[agate.Table]) -> agate.Table:

  rows: List[agate.Row] = []
  for table in tables:
- if (
+ if table.column_names == column_names and table.column_types == column_types:
-     table.column_names == column_names and
-     table.column_types == column_types
- ):
  rows.extend(table.rows)
  else:
  for row in table.rows:

@@ -12,7 +12,7 @@ https://cloud.google.com/sdk/

  def gcloud_installed():
  try:
- run_cmd('.', ['gcloud', '--version'])
+ run_cmd(".", ["gcloud", "--version"])
  return True
  except OSError as e:
  logger.debug(e)
@@ -21,6 +21,6 @@ def gcloud_installed():

  def setup_default_credentials():
  if gcloud_installed():
- run_cmd('.', ["gcloud", "auth", "application-default", "login"])
+ run_cmd(".", ["gcloud", "auth", "application-default", "login"])
  else:
  raise dbt.exceptions.RuntimeException(NOT_INSTALLED_MSG)

@@ -7,77 +7,74 @@ import dbt.exceptions


def clone(repo, cwd, dirname=None, remove_git_dir=False, branch=None):
-clone_cmd = ['git', 'clone', '--depth', '1']
+clone_cmd = ["git", "clone", "--depth", "1"]

if branch is not None:
-clone_cmd.extend(['--branch', branch])
+clone_cmd.extend(["--branch", branch])

clone_cmd.append(repo)

if dirname is not None:
clone_cmd.append(dirname)

-result = run_cmd(cwd, clone_cmd, env={'LC_ALL': 'C'})
+result = run_cmd(cwd, clone_cmd, env={"LC_ALL": "C"})

if remove_git_dir:
-rmdir(os.path.join(dirname, '.git'))
+rmdir(os.path.join(dirname, ".git"))

return result


def list_tags(cwd):
-out, err = run_cmd(cwd, ['git', 'tag', '--list'], env={'LC_ALL': 'C'})
+out, err = run_cmd(cwd, ["git", "tag", "--list"], env={"LC_ALL": "C"})
-tags = out.decode('utf-8').strip().split("\n")
+tags = out.decode("utf-8").strip().split("\n")
return tags


def _checkout(cwd, repo, branch):
-logger.debug(' Checking out branch {}.'.format(branch))
+logger.debug(" Checking out branch {}.".format(branch))

-run_cmd(cwd, ['git', 'remote', 'set-branches', 'origin', branch])
+run_cmd(cwd, ["git", "remote", "set-branches", "origin", branch])
-run_cmd(cwd, ['git', 'fetch', '--tags', '--depth', '1', 'origin', branch])
+run_cmd(cwd, ["git", "fetch", "--tags", "--depth", "1", "origin", branch])

tags = list_tags(cwd)

# Prefer tags to branches if one exists
if branch in tags:
-spec = 'tags/{}'.format(branch)
+spec = "tags/{}".format(branch)
else:
-spec = 'origin/{}'.format(branch)
+spec = "origin/{}".format(branch)

-out, err = run_cmd(cwd, ['git', 'reset', '--hard', spec],
-env={'LC_ALL': 'C'})
+out, err = run_cmd(cwd, ["git", "reset", "--hard", spec], env={"LC_ALL": "C"})
return out, err


def checkout(cwd, repo, branch=None):
if branch is None:
-branch = 'HEAD'
+branch = "HEAD"
try:
return _checkout(cwd, repo, branch)
except dbt.exceptions.CommandResultError as exc:
-stderr = exc.stderr.decode('utf-8').strip()
+stderr = exc.stderr.decode("utf-8").strip()
dbt.exceptions.bad_package_spec(repo, branch, stderr)


def get_current_sha(cwd):
-out, err = run_cmd(cwd, ['git', 'rev-parse', 'HEAD'], env={'LC_ALL': 'C'})
+out, err = run_cmd(cwd, ["git", "rev-parse", "HEAD"], env={"LC_ALL": "C"})

-return out.decode('utf-8')
+return out.decode("utf-8")


def remove_remote(cwd):
-return run_cmd(cwd, ['git', 'remote', 'rm', 'origin'], env={'LC_ALL': 'C'})
+return run_cmd(cwd, ["git", "remote", "rm", "origin"], env={"LC_ALL": "C"})


-def clone_and_checkout(repo, cwd, dirname=None, remove_git_dir=False,
-branch=None):
+def clone_and_checkout(repo, cwd, dirname=None, remove_git_dir=False, branch=None):
exists = None
try:
-_, err = clone(repo, cwd, dirname=dirname,
-remove_git_dir=remove_git_dir)
+_, err = clone(repo, cwd, dirname=dirname, remove_git_dir=remove_git_dir)
except dbt.exceptions.CommandResultError as exc:
-err = exc.stderr.decode('utf-8')
+err = exc.stderr.decode("utf-8")
exists = re.match("fatal: destination path '(.+)' already exists", err)
if not exists: # something else is wrong, raise it
raise
@@ -86,25 +83,26 @@ def clone_and_checkout(repo, cwd, dirname=None, remove_git_dir=False,
start_sha = None
if exists:
directory = exists.group(1)
-logger.debug('Updating existing dependency {}.', directory)
+logger.debug("Updating existing dependency {}.", directory)
else:
-matches = re.match("Cloning into '(.+)'", err.decode('utf-8'))
+matches = re.match("Cloning into '(.+)'", err.decode("utf-8"))
if matches is None:
raise dbt.exceptions.RuntimeException(
f'Error cloning {repo} - never saw "Cloning into ..." from git'
)
directory = matches.group(1)
-logger.debug('Pulling new dependency {}.', directory)
+logger.debug("Pulling new dependency {}.", directory)
full_path = os.path.join(cwd, directory)
start_sha = get_current_sha(full_path)
checkout(full_path, repo, branch)
end_sha = get_current_sha(full_path)
if exists:
if start_sha == end_sha:
-logger.debug(' Already at {}, nothing to do.', start_sha[:7])
+logger.debug(" Already at {}, nothing to do.", start_sha[:7])
else:
-logger.debug(' Updated checkout from {} to {}.',
-start_sha[:7], end_sha[:7])
+logger.debug(
+" Updated checkout from {} to {}.", start_sha[:7], end_sha[:7]
+)
else:
-logger.debug(' Checked out at {}.', end_sha[:7])
+logger.debug(" Checked out at {}.", end_sha[:7])
return directory
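The hunks above only reformat the git helpers; their signatures are unchanged. As a rough usage sketch (not part of this commit — the module path `dbt.clients.git`, the repo URL, and the tag are illustrative assumptions):

```python
# Illustrative sketch only: exercises the helpers shown in the diff above.
import os
import tempfile

from dbt.clients import git  # assumed module path

with tempfile.TemporaryDirectory() as cwd:
    # Shallow-clones the repo into `cwd`, then checks out the requested
    # branch or tag (a matching tag wins over a branch of the same name).
    directory = git.clone_and_checkout(
        "https://github.com/dbt-labs/dbt-utils.git", cwd, branch="0.6.4"
    )
    print(directory, git.get_current_sha(os.path.join(cwd, directory)))
```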
@@ -8,8 +8,17 @@ from ast import literal_eval
from contextlib import contextmanager
from itertools import chain, islice
from typing import (
-List, Union, Set, Optional, Dict, Any, Iterator, Type, NoReturn, Tuple,
-Callable
+List,
+Union,
+Set,
+Optional,
+Dict,
+Any,
+Iterator,
+Type,
+NoReturn,
+Tuple,
+Callable,
)

import jinja2
@@ -20,16 +29,22 @@ import jinja2.parser
import jinja2.sandbox

from dbt.utils import (
-get_dbt_macro_name, get_docs_macro_name, get_materialization_macro_name,
-deep_map
+get_dbt_macro_name,
+get_docs_macro_name,
+get_materialization_macro_name,
+deep_map,
)

from dbt.clients._jinja_blocks import BlockIterator, BlockData, BlockTag
from dbt.contracts.graph.compiled import CompiledSchemaTestNode
from dbt.contracts.graph.parsed import ParsedSchemaTestNode
from dbt.exceptions import (
-InternalException, raise_compiler_error, CompilationException,
-invalid_materialization_argument, MacroReturn, JinjaRenderingException
+InternalException,
+raise_compiler_error,
+CompilationException,
+invalid_materialization_argument,
+MacroReturn,
+JinjaRenderingException,
)
from dbt import flags
from dbt.logger import GLOBAL_LOGGER as logger # noqa
@@ -40,26 +55,26 @@ def _linecache_inject(source, write):
# this is the only reliable way to accomplish this. Obviously, it's
# really darn noisy and will fill your temporary directory
tmp_file = tempfile.NamedTemporaryFile(
-prefix='dbt-macro-compiled-',
+prefix="dbt-macro-compiled-",
-suffix='.py',
+suffix=".py",
delete=False,
-mode='w+',
+mode="w+",
-encoding='utf-8',
+encoding="utf-8",
)
tmp_file.write(source)
filename = tmp_file.name
else:
# `codecs.encode` actually takes a `bytes` as the first argument if
# the second argument is 'hex' - mypy does not know this.
-rnd = codecs.encode(os.urandom(12), 'hex') # type: ignore
+rnd = codecs.encode(os.urandom(12), "hex") # type: ignore
-filename = rnd.decode('ascii')
+filename = rnd.decode("ascii")

# put ourselves in the cache
cache_entry = (
len(source),
None,
-[line + '\n' for line in source.splitlines()],
+[line + "\n" for line in source.splitlines()],
-filename
+filename,
)
# linecache does in fact have an attribute `cache`, thanks
linecache.cache[filename] = cache_entry # type: ignore
@@ -73,12 +88,10 @@ class MacroFuzzParser(jinja2.parser.Parser):
# modified to fuzz macros defined in the same file. this way
# dbt can understand the stack of macros being called.
# - @cmcarthur
-node.name = get_dbt_macro_name(
-self.parse_assign_target(name_only=True).name)
+node.name = get_dbt_macro_name(self.parse_assign_target(name_only=True).name)

self.parse_signature(node)
-node.body = self.parse_statements(('name:endmacro',),
-drop_needle=True)
+node.body = self.parse_statements(("name:endmacro",), drop_needle=True)
return node


@@ -94,8 +107,8 @@ class MacroFuzzEnvironment(jinja2.sandbox.SandboxedEnvironment):
If the value is 'write', also write the files to disk.
WARNING: This can write a ton of data if you aren't careful.
"""
-if filename == '<template>' and flags.MACRO_DEBUGGING:
+if filename == "<template>" and flags.MACRO_DEBUGGING:
-write = flags.MACRO_DEBUGGING == 'write'
+write = flags.MACRO_DEBUGGING == "write"
filename = _linecache_inject(source, write)

return super()._compile(source, filename) # type: ignore
@@ -138,7 +151,7 @@ def quoted_native_concat(nodes):
head = list(islice(nodes, 2))

if not head:
-return ''
+return ""

if len(head) == 1:
raw = head[0]
@@ -180,9 +193,7 @@ class NativeSandboxTemplate(jinja2.nativetypes.NativeTemplate): # mypy: ignore
vars = dict(*args, **kwargs)

try:
-return quoted_native_concat(
-self.root_render_func(self.new_context(vars))
-)
+return quoted_native_concat(self.root_render_func(self.new_context(vars)))
except Exception:
return self.environment.handle_exception()

@@ -221,10 +232,10 @@ class BaseMacroGenerator:
self.context: Optional[Dict[str, Any]] = context

def get_template(self):
-raise NotImplementedError('get_template not implemented!')
+raise NotImplementedError("get_template not implemented!")

def get_name(self) -> str:
-raise NotImplementedError('get_name not implemented!')
+raise NotImplementedError("get_name not implemented!")

def get_macro(self):
name = self.get_name()
@@ -247,9 +258,7 @@ class BaseMacroGenerator:
def call_macro(self, *args, **kwargs):
# called from __call__ methods
if self.context is None:
-raise InternalException(
-'Context is still None in call_macro!'
-)
+raise InternalException("Context is still None in call_macro!")
assert self.context is not None

macro = self.get_macro()
@@ -276,7 +285,7 @@ class MacroStack(threading.local):
def pop(self, name):
got = self.call_stack.pop()
if got != name:
-raise InternalException(f'popped {got}, expected {name}')
+raise InternalException(f"popped {got}, expected {name}")


class MacroGenerator(BaseMacroGenerator):
@@ -285,7 +294,7 @@ class MacroGenerator(BaseMacroGenerator):
macro,
context: Optional[Dict[str, Any]] = None,
node: Optional[Any] = None,
-stack: Optional[MacroStack] = None
+stack: Optional[MacroStack] = None,
) -> None:
super().__init__(context)
self.macro = macro
@@ -333,9 +342,7 @@ class MacroGenerator(BaseMacroGenerator):


class QueryStringGenerator(BaseMacroGenerator):
-def __init__(
-self, template_str: str, context: Dict[str, Any]
-) -> None:
+def __init__(self, template_str: str, context: Dict[str, Any]) -> None:
super().__init__(context)
self.template_str: str = template_str
env = get_environment()
@@ -345,7 +352,7 @@ class QueryStringGenerator(BaseMacroGenerator):
)

def get_name(self) -> str:
-return 'query_comment_macro'
+return "query_comment_macro"

def get_template(self):
"""Don't use the template cache, we don't have a node"""
@@ -356,45 +363,41 @@ class QueryStringGenerator(BaseMacroGenerator):


class MaterializationExtension(jinja2.ext.Extension):
-tags = ['materialization']
+tags = ["materialization"]

def parse(self, parser):
node = jinja2.nodes.Macro(lineno=next(parser.stream).lineno)
-materialization_name = \
-parser.parse_assign_target(name_only=True).name
+materialization_name = parser.parse_assign_target(name_only=True).name

-adapter_name = 'default'
+adapter_name = "default"
node.args = []
node.defaults = []

-while parser.stream.skip_if('comma'):
+while parser.stream.skip_if("comma"):
target = parser.parse_assign_target(name_only=True)

-if target.name == 'default':
+if target.name == "default":
pass

-elif target.name == 'adapter':
+elif target.name == "adapter":
-parser.stream.expect('assign')
+parser.stream.expect("assign")
value = parser.parse_expression()
adapter_name = value.value

else:
-invalid_materialization_argument(
-materialization_name, target.name
-)
+invalid_materialization_argument(materialization_name, target.name)

-node.name = get_materialization_macro_name(
-materialization_name, adapter_name
-)
+node.name = get_materialization_macro_name(materialization_name, adapter_name)

-node.body = parser.parse_statements(('name:endmaterialization',),
-drop_needle=True)
+node.body = parser.parse_statements(
+("name:endmaterialization",), drop_needle=True
+)

return node


class DocumentationExtension(jinja2.ext.Extension):
-tags = ['docs']
+tags = ["docs"]

def parse(self, parser):
node = jinja2.nodes.Macro(lineno=next(parser.stream).lineno)
@@ -403,13 +406,12 @@ class DocumentationExtension(jinja2.ext.Extension):
node.args = []
node.defaults = []
node.name = get_docs_macro_name(docs_name)
-node.body = parser.parse_statements(('name:enddocs',),
-drop_needle=True)
+node.body = parser.parse_statements(("name:enddocs",), drop_needle=True)
return node


def _is_dunder_name(name):
-return name.startswith('__') and name.endswith('__')
+return name.startswith("__") and name.endswith("__")


def create_undefined(node=None):
@@ -430,10 +432,11 @@ def create_undefined(node=None):
return self

def __getattr__(self, name):
-if name == 'name' or _is_dunder_name(name):
+if name == "name" or _is_dunder_name(name):
raise AttributeError(
-"'{}' object has no attribute '{}'"
-.format(type(self).__name__, name)
+"'{}' object has no attribute '{}'".format(
+type(self).__name__, name
+)
)

self.name = name
@@ -444,24 +447,24 @@ def create_undefined(node=None):
return self

def __reduce__(self):
-raise_compiler_error(f'{self.name} is undefined', node=node)
+raise_compiler_error(f"{self.name} is undefined", node=node)

return Undefined


NATIVE_FILTERS: Dict[str, Callable[[Any], Any]] = {
-'as_text': TextMarker,
+"as_text": TextMarker,
-'as_bool': BoolMarker,
+"as_bool": BoolMarker,
-'as_native': NativeMarker,
+"as_native": NativeMarker,
-'as_number': NumberMarker,
+"as_number": NumberMarker,
}


TEXT_FILTERS: Dict[str, Callable[[Any], Any]] = {
-'as_text': lambda x: x,
+"as_text": lambda x: x,
-'as_bool': lambda x: x,
+"as_bool": lambda x: x,
-'as_native': lambda x: x,
+"as_native": lambda x: x,
-'as_number': lambda x: x,
+"as_number": lambda x: x,
}


@@ -471,14 +474,14 @@ def get_environment(
native: bool = False,
) -> jinja2.Environment:
args: Dict[str, List[Union[str, Type[jinja2.ext.Extension]]]] = {
-'extensions': ['jinja2.ext.do']
+"extensions": ["jinja2.ext.do"]
}

if capture_macros:
-args['undefined'] = create_undefined(node)
+args["undefined"] = create_undefined(node)

-args['extensions'].append(MaterializationExtension)
+args["extensions"].append(MaterializationExtension)
-args['extensions'].append(DocumentationExtension)
+args["extensions"].append(DocumentationExtension)

env_cls: Type[jinja2.Environment]
text_filter: Type
@@ -541,8 +544,8 @@ def _requote_result(raw_value: str, rendered: str) -> str:
elif single_quoted:
quote_char = "'"
else:
-quote_char = ''
+quote_char = ""
-return f'{quote_char}{rendered}{quote_char}'
+return f"{quote_char}{rendered}{quote_char}"


# performance note: Local benmcharking (so take it with a big grain of salt!)
@@ -550,7 +553,7 @@ def _requote_result(raw_value: str, rendered: str) -> str:
# checking two separate patterns, but the standard deviation is smaller with
# one pattern. The time difference between the two was ~2 std deviations, which
# is small enough that I've just chosen the more readable option.
-_HAS_RENDER_CHARS_PAT = re.compile(r'({[{%#]|[#}%]})')
+_HAS_RENDER_CHARS_PAT = re.compile(r"({[{%#]|[#}%]})")


def get_rendered(
@@ -567,9 +570,9 @@ def get_rendered(
# native=True case by passing the input string to ast.literal_eval, like
# the native renderer does.
if (
-not native and
+not native
-isinstance(string, str) and
+and isinstance(string, str)
-_HAS_RENDER_CHARS_PAT.search(string) is None
+and _HAS_RENDER_CHARS_PAT.search(string) is None
):
return string
template = get_template(
@@ -606,12 +609,11 @@ def extract_toplevel_blocks(
`collect_raw_data` is `True`) `BlockData` objects.
"""
return BlockIterator(data).lex_for_blocks(
-allowed_blocks=allowed_blocks,
-collect_raw_data=collect_raw_data
+allowed_blocks=allowed_blocks, collect_raw_data=collect_raw_data
)


-SCHEMA_TEST_KWARGS_NAME = '_dbt_schema_test_kwargs'
+SCHEMA_TEST_KWARGS_NAME = "_dbt_schema_test_kwargs"


def add_rendered_test_kwargs(
@@ -623,24 +625,21 @@ def add_rendered_test_kwargs(
renderer, then insert that value into the given context as the special test
keyword arguments member.
"""
-looks_like_func = r'^\s*(env_var|ref|var|source|doc)\s*\(.+\)\s*$'
+looks_like_func = r"^\s*(env_var|ref|var|source|doc)\s*\(.+\)\s*$"

-def _convert_function(
-value: Any, keypath: Tuple[Union[str, int], ...]
-) -> Any:
+def _convert_function(value: Any, keypath: Tuple[Union[str, int], ...]) -> Any:
if isinstance(value, str):
-if keypath == ('column_name',):
+if keypath == ("column_name",):
# special case: Don't render column names as native, make them
# be strings
return value

if re.match(looks_like_func, value) is not None:
# curly braces to make rendering happy
-value = f'{{{{ {value} }}}}'
+value = f"{{{{ {value} }}}}"

value = get_rendered(
-value, context, node, capture_macros=capture_macros,
-native=True
+value, context, node, capture_macros=capture_macros, native=True
)

return value
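The comments in the hunks above carry the reasoning: strings with no `{{ }}`, `{% %}`, or `{# #}` markers skip Jinja entirely, and test kwargs are rendered with the native renderer so literal-looking values come back as Python objects. A minimal sketch of `get_rendered` under those assumptions (not part of this commit; the module path `dbt.clients.jinja` is assumed):

```python
# Illustrative sketch only, based on the behavior described in the diff above.
from dbt.clients.jinja import get_rendered  # assumed module path

print(get_rendered("select 1", {}))                  # no render chars: returned unchanged
print(get_rendered("{{ 1 + 1 }}", {}))               # normal rendering -> the string "2"
print(get_rendered("{{ 1 + 1 }}", {}, native=True))  # native rendering -> the int 2
```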
@@ -6,17 +6,17 @@ from dbt.logger import GLOBAL_LOGGER as logger
import os
import time

-if os.getenv('DBT_PACKAGE_HUB_URL'):
+if os.getenv("DBT_PACKAGE_HUB_URL"):
-DEFAULT_REGISTRY_BASE_URL = os.getenv('DBT_PACKAGE_HUB_URL')
+DEFAULT_REGISTRY_BASE_URL = os.getenv("DBT_PACKAGE_HUB_URL")
else:
-DEFAULT_REGISTRY_BASE_URL = 'https://hub.getdbt.com/'
+DEFAULT_REGISTRY_BASE_URL = "https://hub.getdbt.com/"


def _get_url(url, registry_base_url=None):
if registry_base_url is None:
registry_base_url = DEFAULT_REGISTRY_BASE_URL

-return '{}{}'.format(registry_base_url, url)
+return "{}{}".format(registry_base_url, url)


def _wrap_exceptions(fn):
@@ -33,42 +33,40 @@ def _wrap_exceptions(fn):
time.sleep(1)
continue

-raise RegistryException(
-'Unable to connect to registry hub'
-) from exc
+raise RegistryException("Unable to connect to registry hub") from exc
return wrapper


@_wrap_exceptions
def _get(path, registry_base_url=None):
url = _get_url(path, registry_base_url)
-logger.debug('Making package registry request: GET {}'.format(url))
+logger.debug("Making package registry request: GET {}".format(url))
resp = requests.get(url)
-logger.debug('Response from registry: GET {} {}'.format(url,
-resp.status_code))
+logger.debug("Response from registry: GET {} {}".format(url, resp.status_code))
resp.raise_for_status()
return resp.json()


def index(registry_base_url=None):
-return _get('api/v1/index.json', registry_base_url)
+return _get("api/v1/index.json", registry_base_url)


index_cached = memoized(index)


def packages(registry_base_url=None):
-return _get('api/v1/packages.json', registry_base_url)
+return _get("api/v1/packages.json", registry_base_url)


def package(name, registry_base_url=None):
-return _get('api/v1/{}.json'.format(name), registry_base_url)
+return _get("api/v1/{}.json".format(name), registry_base_url)


def package_version(name, version, registry_base_url=None):
-return _get('api/v1/{}/{}.json'.format(name, version), registry_base_url)
+return _get("api/v1/{}/{}.json".format(name, version), registry_base_url)


def get_available_versions(name):
response = package(name)
-return list(response['versions'])
+return list(response["versions"])
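The registry client above resolves paths against `https://hub.getdbt.com/` (or `DBT_PACKAGE_HUB_URL`). A small sketch of how the public helpers compose (not part of this commit; the module path and package name are placeholders, and network access is required):

```python
# Illustrative sketch only: helpers from the diff above.
from dbt.clients import registry  # assumed module path

# GET {hub}/api/v1/<package>.json, then list its published versions.
versions = registry.get_available_versions("fishtown-analytics/dbt_utils")
print(versions[:5])

# GET {hub}/api/v1/<package>/<version>.json for one version's metadata.
meta = registry.package_version("fishtown-analytics/dbt_utils", "0.6.4")
print(sorted(meta.keys()))
```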
@@ -10,16 +10,14 @@ import sys
import tarfile
import requests
import stat
-from typing import (
-Type, NoReturn, List, Optional, Dict, Any, Tuple, Callable, Union
-)
+from typing import Type, NoReturn, List, Optional, Dict, Any, Tuple, Callable, Union

import dbt.exceptions
import dbt.utils

from dbt.logger import GLOBAL_LOGGER as logger

-if sys.platform == 'win32':
+if sys.platform == "win32":
from ctypes import WinDLL, c_bool
else:
WinDLL = None
@@ -51,30 +49,29 @@ def find_matching(
reobj = re.compile(regex, re.IGNORECASE)

for relative_path_to_search in relative_paths_to_search:
-absolute_path_to_search = os.path.join(
-root_path, relative_path_to_search)
+absolute_path_to_search = os.path.join(root_path, relative_path_to_search)
walk_results = os.walk(absolute_path_to_search)

for current_path, subdirectories, local_files in walk_results:
for local_file in local_files:
absolute_path = os.path.join(current_path, local_file)
-relative_path = os.path.relpath(
-absolute_path, absolute_path_to_search
-)
+relative_path = os.path.relpath(absolute_path, absolute_path_to_search)
if reobj.match(local_file):
-matching.append({
-'searched_path': relative_path_to_search,
-'absolute_path': absolute_path,
-'relative_path': relative_path,
-})
+matching.append(
+{
+"searched_path": relative_path_to_search,
+"absolute_path": absolute_path,
+"relative_path": relative_path,
+}
+)

return matching


def load_file_contents(path: str, strip: bool = True) -> str:
path = convert_path(path)
-with open(path, 'rb') as handle:
+with open(path, "rb") as handle:
-to_return = handle.read().decode('utf-8')
+to_return = handle.read().decode("utf-8")

if strip:
to_return = to_return.strip()
@@ -101,14 +98,14 @@ def make_directory(path: str) -> None:
raise e


-def make_file(path: str, contents: str = '', overwrite: bool = False) -> bool:
+def make_file(path: str, contents: str = "", overwrite: bool = False) -> bool:
"""
Make a file at `path` assuming that the directory it resides in already
exists. The file is saved with contents `contents`
"""
if overwrite or not os.path.exists(path):
path = convert_path(path)
-with open(path, 'w') as fh:
+with open(path, "w") as fh:
fh.write(contents)
return True

@@ -120,7 +117,7 @@ def make_symlink(source: str, link_path: str) -> None:
Create a symlink at `link_path` referring to `source`.
"""
if not supports_symlinks():
-dbt.exceptions.system_error('create a symbolic link')
+dbt.exceptions.system_error("create a symbolic link")

os.symlink(source, link_path)

@@ -129,11 +126,11 @@ def supports_symlinks() -> bool:
return getattr(os, "symlink", None) is not None


-def write_file(path: str, contents: str = '') -> bool:
+def write_file(path: str, contents: str = "") -> bool:
path = convert_path(path)
try:
make_directory(os.path.dirname(path))
-with open(path, 'w', encoding='utf-8') as f:
+with open(path, "w", encoding="utf-8") as f:
f.write(str(contents))
except Exception as exc:
# note that you can't just catch FileNotFound, because sometimes
@@ -142,20 +139,20 @@ def write_file(path: str, contents: str = '') -> bool:
# sometimes windows fails to write paths that are less than the length
# limit. So on windows, suppress all errors that happen from writing
# to disk.
-if os.name == 'nt':
+if os.name == "nt":
# sometimes we get a winerror of 3 which means the path was
# definitely too long, but other times we don't and it means the
# path was just probably too long. This is probably based on the
# windows/python version.
-if getattr(exc, 'winerror', 0) == 3:
+if getattr(exc, "winerror", 0) == 3:
-reason = 'Path was too long'
+reason = "Path was too long"
else:
-reason = 'Path was possibly too long'
+reason = "Path was possibly too long"
# all our hard work and the path was still too long. Log and
# continue.
logger.debug(
-f'Could not write to path {path}({len(path)} characters): '
+f"Could not write to path {path}({len(path)} characters): "
-f'{reason}\nexception: {exc}'
+f"{reason}\nexception: {exc}"
)
else:
raise
@@ -189,10 +186,7 @@ def resolve_path_from_base(path_to_resolve: str, base_path: str) -> str:
If path_to_resolve is an absolute path or a user path (~), just
resolve it to an absolute path and return.
"""
-return os.path.abspath(
-os.path.join(
-base_path,
-os.path.expanduser(path_to_resolve)))
+return os.path.abspath(os.path.join(base_path, os.path.expanduser(path_to_resolve)))


def rmdir(path: str) -> None:
@@ -202,7 +196,7 @@ def rmdir(path: str) -> None:
cloned via git) can cause rmtree to throw a PermissionError exception
"""
path = convert_path(path)
-if sys.platform == 'win32':
+if sys.platform == "win32":
onerror = _windows_rmdir_readonly
else:
onerror = None
@@ -221,7 +215,7 @@ def _win_prepare_path(path: str) -> str:
# letter back in.
# Unless it starts with '\\'. In that case, the path is a UNC mount point
# and splitdrive will be fine.
-if not path.startswith('\\\\') and path.startswith('\\'):
+if not path.startswith("\\\\") and path.startswith("\\"):
curdrive = os.path.splitdrive(os.getcwd())[0]
path = curdrive + path

@@ -236,7 +230,7 @@ def _win_prepare_path(path: str) -> str:


def _supports_long_paths() -> bool:
-if sys.platform != 'win32':
+if sys.platform != "win32":
return True
# Eryk Sun says to use `WinDLL('ntdll')` instead of `windll.ntdll` because
# of pointer caching in a comment here:
@@ -244,11 +238,11 @@ def _supports_long_paths() -> bool:
# I don't know exaclty what he means, but I am inclined to believe him as
# he's pretty active on Python windows bugs!
try:
-dll = WinDLL('ntdll')
+dll = WinDLL("ntdll")
except OSError: # I don't think this happens? you need ntdll to run python
return False
# not all windows versions have it at all
-if not hasattr(dll, 'RtlAreLongPathsEnabled'):
+if not hasattr(dll, "RtlAreLongPathsEnabled"):
return False
# tell windows we want to get back a single unsigned byte (a bool).
dll.RtlAreLongPathsEnabled.restype = c_bool
@@ -268,7 +262,7 @@ def convert_path(path: str) -> str:
if _supports_long_paths():
return path

-prefix = '\\\\?\\'
+prefix = "\\\\?\\"
# Nothing to do
if path.startswith(prefix):
return path
@@ -299,39 +293,35 @@ def path_is_symlink(path: str) -> bool:

def open_dir_cmd() -> str:
# https://docs.python.org/2/library/sys.html#sys.platform
-if sys.platform == 'win32':
+if sys.platform == "win32":
-return 'start'
+return "start"

-elif sys.platform == 'darwin':
+elif sys.platform == "darwin":
-return 'open'
+return "open"

else:
-return 'xdg-open'
+return "xdg-open"


-def _handle_posix_cwd_error(
-exc: OSError, cwd: str, cmd: List[str]
-) -> NoReturn:
+def _handle_posix_cwd_error(exc: OSError, cwd: str, cmd: List[str]) -> NoReturn:
if exc.errno == errno.ENOENT:
-message = 'Directory does not exist'
+message = "Directory does not exist"
elif exc.errno == errno.EACCES:
-message = 'Current user cannot access directory, check permissions'
+message = "Current user cannot access directory, check permissions"
elif exc.errno == errno.ENOTDIR:
-message = 'Not a directory'
+message = "Not a directory"
else:
-message = 'Unknown OSError: {} - cwd'.format(str(exc))
+message = "Unknown OSError: {} - cwd".format(str(exc))
raise dbt.exceptions.WorkingDirectoryError(cwd, cmd, message)


-def _handle_posix_cmd_error(
-exc: OSError, cwd: str, cmd: List[str]
-) -> NoReturn:
+def _handle_posix_cmd_error(exc: OSError, cwd: str, cmd: List[str]) -> NoReturn:
if exc.errno == errno.ENOENT:
message = "Could not find command, ensure it is in the user's PATH"
elif exc.errno == errno.EACCES:
-message = 'User does not have permissions for this command'
+message = "User does not have permissions for this command"
else:
-message = 'Unknown OSError: {} - cmd'.format(str(exc))
+message = "Unknown OSError: {} - cmd".format(str(exc))
raise dbt.exceptions.ExecutableError(cwd, cmd, message)


@@ -356,7 +346,7 @@ def _handle_posix_error(exc: OSError, cwd: str, cmd: List[str]) -> NoReturn:
- exc.errno == EACCES
- exc.filename == None(?)
"""
-if getattr(exc, 'filename', None) == cwd:
+if getattr(exc, "filename", None) == cwd:
_handle_posix_cwd_error(exc, cwd, cmd)
else:
_handle_posix_cmd_error(exc, cwd, cmd)
@@ -365,46 +355,48 @@ def _handle_posix_error(exc: OSError, cwd: str, cmd: List[str]) -> NoReturn:
def _handle_windows_error(exc: OSError, cwd: str, cmd: List[str]) -> NoReturn:
cls: Type[dbt.exceptions.Exception] = dbt.exceptions.CommandError
if exc.errno == errno.ENOENT:
-message = ("Could not find command, ensure it is in the user's PATH "
-"and that the user has permissions to run it")
+message = (
+"Could not find command, ensure it is in the user's PATH "
+"and that the user has permissions to run it"
+)
cls = dbt.exceptions.ExecutableError
elif exc.errno == errno.ENOEXEC:
-message = ('Command was not executable, ensure it is valid')
+message = "Command was not executable, ensure it is valid"
cls = dbt.exceptions.ExecutableError
elif exc.errno == errno.ENOTDIR:
-message = ('Unable to cd: path does not exist, user does not have'
-' permissions, or not a directory')
+message = (
+"Unable to cd: path does not exist, user does not have"
+" permissions, or not a directory"
+)
cls = dbt.exceptions.WorkingDirectoryError
else:
message = 'Unknown error: {} (errno={}: "{}")'.format(
-str(exc), exc.errno, errno.errorcode.get(exc.errno, '<Unknown!>')
+str(exc), exc.errno, errno.errorcode.get(exc.errno, "<Unknown!>")
)
raise cls(cwd, cmd, message)


def _interpret_oserror(exc: OSError, cwd: str, cmd: List[str]) -> NoReturn:
-"""Interpret an OSError exc and raise the appropriate dbt exception.
-
-"""
+"""Interpret an OSError exc and raise the appropriate dbt exception."""
if len(cmd) == 0:
raise dbt.exceptions.CommandError(cwd, cmd)

# all of these functions raise unconditionally
-if os.name == 'nt':
+if os.name == "nt":
_handle_windows_error(exc, cwd, cmd)
else:
_handle_posix_error(exc, cwd, cmd)

# this should not be reachable, raise _something_ at least!
raise dbt.exceptions.InternalException(
-'Unhandled exception in _interpret_oserror: {}'.format(exc)
+"Unhandled exception in _interpret_oserror: {}".format(exc)
)


def run_cmd(
cwd: str, cmd: List[str], env: Optional[Dict[str, Any]] = None
) -> Tuple[bytes, bytes]:
-logger.debug('Executing "{}"'.format(' '.join(cmd)))
+logger.debug('Executing "{}"'.format(" ".join(cmd)))
if len(cmd) == 0:
raise dbt.exceptions.CommandError(cwd, cmd)

@@ -417,11 +409,8 @@ def run_cmd(

try:
proc = subprocess.Popen(
-cmd,
-cwd=cwd,
-stdout=subprocess.PIPE,
-stderr=subprocess.PIPE,
-env=full_env)
+cmd, cwd=cwd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=full_env
+)

out, err = proc.communicate()
except OSError as exc:
@@ -431,9 +420,8 @@ def run_cmd(
logger.debug('STDERR: "{!s}"'.format(err))

if proc.returncode != 0:
-logger.debug('command return code={}'.format(proc.returncode))
+logger.debug("command return code={}".format(proc.returncode))
-raise dbt.exceptions.CommandResultError(cwd, cmd, proc.returncode,
-out, err)
+raise dbt.exceptions.CommandResultError(cwd, cmd, proc.returncode, out, err)

return out, err

@@ -442,9 +430,9 @@ def download(
url: str, path: str, timeout: Optional[Union[float, tuple]] = None
) -> None:
path = convert_path(path)
-connection_timeout = timeout or float(os.getenv('DBT_HTTP_TIMEOUT', 10))
+connection_timeout = timeout or float(os.getenv("DBT_HTTP_TIMEOUT", 10))
response = requests.get(url, timeout=connection_timeout)
-with open(path, 'wb') as handle:
+with open(path, "wb") as handle:
for block in response.iter_content(1024 * 64):
handle.write(block)

@@ -468,7 +456,7 @@ def untar_package(
) -> None:
tar_path = convert_path(tar_path)
tar_dir_name = None
-with tarfile.open(tar_path, 'r') as tarball:
+with tarfile.open(tar_path, "r") as tarball:
tarball.extractall(dest_dir)
tar_dir_name = os.path.commonprefix(tarball.getnames())
if rename_to:
@@ -484,7 +472,7 @@ def chmod_and_retry(func, path, exc_info):
We want to retry most operations here, but listdir is one that we know will
be useless.
"""
-if func is os.listdir or os.name != 'nt':
+if func is os.listdir or os.name != "nt":
raise
os.chmod(path, stat.S_IREAD | stat.S_IWRITE)
# on error,this will raise.
@@ -505,7 +493,7 @@ def move(src, dst):
"""
src = convert_path(src)
dst = convert_path(dst)
-if os.name != 'nt':
+if os.name != "nt":
return shutil.move(src, dst)

if os.path.isdir(dst):
@@ -513,7 +501,7 @@ def move(src, dst):
os.rename(src, dst)
return

-dst = os.path.join(dst, os.path.basename(src.rstrip('/\\')))
+dst = os.path.join(dst, os.path.basename(src.rstrip("/\\")))
if os.path.exists(dst):
raise EnvironmentError("Path '{}' already exists".format(dst))

@@ -522,11 +510,10 @@ def move(src, dst):
except OSError:
# probably different drives
if os.path.isdir(src):
-if _absnorm(dst + '\\').startswith(_absnorm(src + '\\')):
+if _absnorm(dst + "\\").startswith(_absnorm(src + "\\")):
# dst is inside src
raise EnvironmentError(
-"Cannot move a directory '{}' into itself '{}'"
-.format(src, dst)
+"Cannot move a directory '{}' into itself '{}'".format(src, dst)
)
shutil.copytree(src, dst, symlinks=True)
rmtree(src)
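`run_cmd` above is typed `(cwd: str, cmd: List[str], env: Optional[Dict[str, Any]]) -> Tuple[bytes, bytes]` and converts `OSError` into dbt's own exceptions via `_interpret_oserror`. A minimal sketch of calling it the same way the git and gcloud clients in this commit do (not part of the commit; the module path is assumed):

```python
# Illustrative sketch only: run_cmd as shown in the diff above.
import dbt.exceptions
from dbt.clients.system import run_cmd  # assumed module path

try:
    out, err = run_cmd(".", ["git", "--version"], env={"LC_ALL": "C"})
    print(out.decode("utf-8").strip())  # stdout and stderr come back as bytes
except dbt.exceptions.ExecutableError as exc:
    # Raised (via _interpret_oserror) when the executable cannot be found.
    print(f"git is not available: {exc}")
```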
@@ -5,15 +5,9 @@ import yaml.scanner

# the C version is faster, but it doesn't always exist
try:
-from yaml import (
-CLoader as Loader,
-CSafeLoader as SafeLoader,
-CDumper as Dumper
-)
+from yaml import CLoader as Loader, CSafeLoader as SafeLoader, CDumper as Dumper
except ImportError:
-from yaml import ( # type: ignore # noqa: F401
-Loader, SafeLoader, Dumper
-)
+from yaml import Loader, SafeLoader, Dumper # type: ignore # noqa: F401


YAML_ERROR_MESSAGE = """
@@ -33,14 +27,14 @@ def line_no(i, line, width=3):


def prefix_with_line_numbers(string, no_start, no_end):
-line_list = string.split('\n')
+line_list = string.split("\n")

numbers = range(no_start, no_end)
relevant_lines = line_list[no_start:no_end]

-return "\n".join([
-line_no(i + 1, line) for (i, line) in zip(numbers, relevant_lines)
-])
+return "\n".join(
+[line_no(i + 1, line) for (i, line) in zip(numbers, relevant_lines)]
+)


def contextualized_yaml_error(raw_contents, error):
@@ -51,9 +45,9 @@ def contextualized_yaml_error(raw_contents, error):

nice_error = prefix_with_line_numbers(raw_contents, min_line, max_line)

-return YAML_ERROR_MESSAGE.format(line_number=mark.line + 1,
-nice_error=nice_error,
-raw_error=error)
+return YAML_ERROR_MESSAGE.format(
+line_number=mark.line + 1, nice_error=nice_error, raw_error=error
+)


def safe_load(contents):
@@ -64,7 +58,7 @@ def load_yaml_text(contents):
try:
return safe_load(contents)
except (yaml.scanner.ScannerError, yaml.YAMLError) as e:
-if hasattr(e, 'problem_mark'):
+if hasattr(e, "problem_mark"):
error = contextualized_yaml_error(contents, e)
else:
error = str(e)
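`load_yaml_text` above wraps `safe_load` and, when the parser reports a `problem_mark`, rebuilds the message with line numbers via `prefix_with_line_numbers`. A rough sketch (not part of this commit; the module path is assumed, and the exact exception type raised on failure is defined outside the hunks shown):

```python
# Illustrative sketch only: load_yaml_text as shown in the diff above.
from dbt.clients.yaml_helper import load_yaml_text  # assumed module path

print(load_yaml_text("models:\n  - name: customers"))  # -> {'models': [{'name': 'customers'}]}

try:
    load_yaml_text("models:\n  - name: [unclosed")
except Exception as exc:  # the concrete dbt exception class is not shown above
    # The message includes the offending lines, numbered by line_no().
    print(exc)
```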
@@ -32,28 +32,28 @@ from dbt.node_types import NodeType
from dbt.utils import pluralize
import dbt.tracking

-graph_file_name = 'graph.gpickle'
+graph_file_name = "graph.gpickle"


def _compiled_type_for(model: ParsedNode):
if type(model) not in COMPILED_TYPES:
raise InternalException(
-f'Asked to compile {type(model)} node, but it has no compiled form'
+f"Asked to compile {type(model)} node, but it has no compiled form"
)
return COMPILED_TYPES[type(model)]


def print_compile_stats(stats):
names = {
-NodeType.Model: 'model',
+NodeType.Model: "model",
-NodeType.Test: 'test',
+NodeType.Test: "test",
-NodeType.Snapshot: 'snapshot',
+NodeType.Snapshot: "snapshot",
-NodeType.Analysis: 'analysis',
+NodeType.Analysis: "analysis",
-NodeType.Macro: 'macro',
+NodeType.Macro: "macro",
-NodeType.Operation: 'operation',
+NodeType.Operation: "operation",
-NodeType.Seed: 'seed file',
+NodeType.Seed: "seed file",
-NodeType.Source: 'source',
+NodeType.Source: "source",
-NodeType.Exposure: 'exposure',
+NodeType.Exposure: "exposure",
}

results = {k: 0 for k in names.keys()}
@@ -64,10 +64,9 @@ def print_compile_stats(stats):
resource_counts = {k.pluralize(): v for k, v in results.items()}
dbt.tracking.track_resource_counts(resource_counts)

-stat_line = ", ".join([
-pluralize(ct, names.get(t)) for t, ct in results.items()
-if t in names
-])
+stat_line = ", ".join(
+[pluralize(ct, names.get(t)) for t, ct in results.items() if t in names]
+)

logger.info("Found {}".format(stat_line))

@@ -166,9 +165,7 @@ class Compiler:
extra_context: Dict[str, Any],
) -> Dict[str, Any]:

-context = generate_runtime_model(
-node, self.config, manifest
-)
+context = generate_runtime_model(node, self.config, manifest)
context.update(extra_context)
if isinstance(node, CompiledSchemaTestNode):
# for test nodes, add a special keyword args value to the context
@@ -183,8 +180,7 @@ class Compiler:

def _get_relation_name(self, node: ParsedNode):
relation_name = None
-if (node.resource_type in NodeType.refable() and
-not node.is_ephemeral_model):
+if node.resource_type in NodeType.refable() and not node.is_ephemeral_model:
adapter = get_adapter(self.config)
relation_cls = adapter.Relation
relation_name = str(relation_cls.create_from(self.config, node))
@@ -227,32 +223,29 @@ class Compiler:

with_stmt = None
for token in parsed.tokens:
-if token.is_keyword and token.normalized == 'WITH':
+if token.is_keyword and token.normalized == "WITH":
with_stmt = token
break

if with_stmt is None:
# no with stmt, add one, and inject CTEs right at the beginning
first_token = parsed.token_first()
-with_stmt = sqlparse.sql.Token(sqlparse.tokens.Keyword, 'with')
+with_stmt = sqlparse.sql.Token(sqlparse.tokens.Keyword, "with")
parsed.insert_before(first_token, with_stmt)
else:
# stmt exists, add a comma (which will come after injected CTEs)
-trailing_comma = sqlparse.sql.Token(
-sqlparse.tokens.Punctuation, ','
-)
+trailing_comma = sqlparse.sql.Token(sqlparse.tokens.Punctuation, ",")
parsed.insert_after(with_stmt, trailing_comma)

token = sqlparse.sql.Token(
-sqlparse.tokens.Keyword,
-", ".join(c.sql for c in ctes)
+sqlparse.tokens.Keyword, ", ".join(c.sql for c in ctes)
)
parsed.insert_after(with_stmt, token)

return str(parsed)

def _get_dbt_test_name(self) -> str:
-return 'dbt__cte__internal_test'
+return "dbt__cte__internal_test"

# This method is called by the 'compile_node' method. Starting
# from the node that it is passed in, it will recursively call
@@ -268,9 +261,7 @@ class Compiler:
) -> Tuple[NonSourceCompiledNode, List[InjectedCTE]]:

if model.compiled_sql is None:
-raise RuntimeException(
-'Cannot inject ctes into an unparsed node', model
-)
+raise RuntimeException("Cannot inject ctes into an unparsed node", model)
if model.extra_ctes_injected:
return (model, model.extra_ctes)

@@ -296,19 +287,18 @@ class Compiler:
else:
if cte.id not in manifest.nodes:
raise InternalException(
-f'During compilation, found a cte reference that '
+f"During compilation, found a cte reference that "
-f'could not be resolved: {cte.id}'
+f"could not be resolved: {cte.id}"
)
cte_model = manifest.nodes[cte.id]

if not cte_model.is_ephemeral_model:
-raise InternalException(f'{cte.id} is not ephemeral')
+raise InternalException(f"{cte.id} is not ephemeral")

# This model has already been compiled, so it's been
# through here before
-if getattr(cte_model, 'compiled', False):
+if getattr(cte_model, "compiled", False):
-assert isinstance(cte_model,
-tuple(COMPILED_TYPES.values()))
+assert isinstance(cte_model, tuple(COMPILED_TYPES.values()))
cte_model = cast(NonSourceCompiledNode, cte_model)
new_prepended_ctes = cte_model.extra_ctes

@@ -316,11 +306,9 @@ class Compiler:
else:
|
else:
|
||||||
# This is an ephemeral parsed model that we can compile.
|
# This is an ephemeral parsed model that we can compile.
|
||||||
# Compile and update the node
|
# Compile and update the node
|
||||||
cte_model = self._compile_node(
|
cte_model = self._compile_node(cte_model, manifest, extra_context)
|
||||||
cte_model, manifest, extra_context)
|
|
||||||
# recursively call this method
|
# recursively call this method
|
||||||
cte_model, new_prepended_ctes = \
|
cte_model, new_prepended_ctes = self._recursively_prepend_ctes(
|
||||||
self._recursively_prepend_ctes(
|
|
||||||
cte_model, manifest, extra_context
|
cte_model, manifest, extra_context
|
||||||
)
|
)
|
||||||
# Save compiled SQL file and sync manifest
|
# Save compiled SQL file and sync manifest
|
||||||
@@ -330,7 +318,7 @@ class Compiler:
|
|||||||
_extend_prepended_ctes(prepended_ctes, new_prepended_ctes)
|
_extend_prepended_ctes(prepended_ctes, new_prepended_ctes)
|
||||||
|
|
||||||
new_cte_name = self.add_ephemeral_prefix(cte_model.name)
|
new_cte_name = self.add_ephemeral_prefix(cte_model.name)
|
||||||
sql = f' {new_cte_name} as (\n{cte_model.compiled_sql}\n)'
|
sql = f" {new_cte_name} as (\n{cte_model.compiled_sql}\n)"
|
||||||
|
|
||||||
_add_prepended_cte(prepended_ctes, InjectedCTE(id=cte.id, sql=sql))
|
_add_prepended_cte(prepended_ctes, InjectedCTE(id=cte.id, sql=sql))
|
||||||
|
|
||||||
@@ -371,11 +359,10 @@ class Compiler:
|
|||||||
# compiled_sql, and do the regular prepend logic from CTEs.
|
# compiled_sql, and do the regular prepend logic from CTEs.
|
||||||
name = self._get_dbt_test_name()
|
name = self._get_dbt_test_name()
|
||||||
cte = InjectedCTE(
|
cte = InjectedCTE(
|
||||||
id=name,
|
id=name, sql=f" {name} as (\n{compiled_node.compiled_sql}\n)"
|
||||||
sql=f' {name} as (\n{compiled_node.compiled_sql}\n)'
|
|
||||||
)
|
)
|
||||||
compiled_node.extra_ctes.append(cte)
|
compiled_node.extra_ctes.append(cte)
|
||||||
compiled_node.compiled_sql = f'\nselect count(*) from {name}'
|
compiled_node.compiled_sql = f"\nselect count(*) from {name}"
|
||||||
|
|
||||||
return compiled_node
|
return compiled_node
|
||||||
|
|
||||||
@@ -395,17 +382,17 @@ class Compiler:
|
|||||||
logger.debug("Compiling {}".format(node.unique_id))
|
logger.debug("Compiling {}".format(node.unique_id))
|
||||||
|
|
||||||
data = node.to_dict(omit_none=True)
|
data = node.to_dict(omit_none=True)
|
||||||
data.update({
|
data.update(
|
||||||
'compiled': False,
|
{
|
||||||
'compiled_sql': None,
|
"compiled": False,
|
||||||
'extra_ctes_injected': False,
|
"compiled_sql": None,
|
||||||
'extra_ctes': [],
|
"extra_ctes_injected": False,
|
||||||
})
|
"extra_ctes": [],
|
||||||
|
}
|
||||||
|
)
|
||||||
compiled_node = _compiled_type_for(node).from_dict(data)
|
compiled_node = _compiled_type_for(node).from_dict(data)
|
||||||
|
|
||||||
context = self._create_node_context(
|
context = self._create_node_context(compiled_node, manifest, extra_context)
|
||||||
compiled_node, manifest, extra_context
|
|
||||||
)
|
|
||||||
|
|
||||||
compiled_node.compiled_sql = jinja.get_rendered(
|
compiled_node.compiled_sql = jinja.get_rendered(
|
||||||
node.raw_sql,
|
node.raw_sql,
|
||||||
@@ -419,9 +406,7 @@ class Compiler:
|
|||||||
|
|
||||||
# add ctes for specific test nodes, and also for
|
# add ctes for specific test nodes, and also for
|
||||||
# possible future use in adapters
|
# possible future use in adapters
|
||||||
compiled_node = self._add_ctes(
|
compiled_node = self._add_ctes(compiled_node, manifest, extra_context)
|
||||||
compiled_node, manifest, extra_context
|
|
||||||
)
|
|
||||||
|
|
||||||
return compiled_node
|
return compiled_node
|
||||||
|
|
||||||
@@ -431,21 +416,17 @@ class Compiler:
|
|||||||
if flags.WRITE_JSON:
|
if flags.WRITE_JSON:
|
||||||
linker.write_graph(graph_path, manifest)
|
linker.write_graph(graph_path, manifest)
|
||||||
|
|
||||||
def link_node(
|
def link_node(self, linker: Linker, node: GraphMemberNode, manifest: Manifest):
|
||||||
self, linker: Linker, node: GraphMemberNode, manifest: Manifest
|
|
||||||
):
|
|
||||||
linker.add_node(node.unique_id)
|
linker.add_node(node.unique_id)
|
||||||
|
|
||||||
for dependency in node.depends_on_nodes:
|
for dependency in node.depends_on_nodes:
|
||||||
if dependency in manifest.nodes:
|
if dependency in manifest.nodes:
|
||||||
linker.dependency(
|
linker.dependency(
|
||||||
node.unique_id,
|
node.unique_id, (manifest.nodes[dependency].unique_id)
|
||||||
(manifest.nodes[dependency].unique_id)
|
|
||||||
)
|
)
|
||||||
elif dependency in manifest.sources:
|
elif dependency in manifest.sources:
|
||||||
linker.dependency(
|
linker.dependency(
|
||||||
node.unique_id,
|
node.unique_id, (manifest.sources[dependency].unique_id)
|
||||||
(manifest.sources[dependency].unique_id)
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
dependency_not_found(node, dependency)
|
dependency_not_found(node, dependency)
|
||||||
@@ -480,16 +461,13 @@ class Compiler:
|
|||||||
|
|
||||||
# writes the "compiled_sql" into the target/compiled directory
|
# writes the "compiled_sql" into the target/compiled directory
|
||||||
def _write_node(self, node: NonSourceCompiledNode) -> ManifestNode:
|
def _write_node(self, node: NonSourceCompiledNode) -> ManifestNode:
|
||||||
if (not node.extra_ctes_injected or
|
if not node.extra_ctes_injected or node.resource_type == NodeType.Snapshot:
|
||||||
node.resource_type == NodeType.Snapshot):
|
|
||||||
return node
|
return node
|
||||||
logger.debug(f'Writing injected SQL for node "{node.unique_id}"')
|
logger.debug(f'Writing injected SQL for node "{node.unique_id}"')
|
||||||
|
|
||||||
if node.compiled_sql:
|
if node.compiled_sql:
|
||||||
node.build_path = node.write_node(
|
node.build_path = node.write_node(
|
||||||
self.config.target_path,
|
self.config.target_path, "compiled", node.compiled_sql
|
||||||
'compiled',
|
|
||||||
node.compiled_sql
|
|
||||||
)
|
)
|
||||||
return node
|
return node
|
||||||
|
|
||||||
@@ -507,9 +485,7 @@ class Compiler:
|
|||||||
) -> NonSourceCompiledNode:
|
) -> NonSourceCompiledNode:
|
||||||
node = self._compile_node(node, manifest, extra_context)
|
node = self._compile_node(node, manifest, extra_context)
|
||||||
|
|
||||||
node, _ = self._recursively_prepend_ctes(
|
node, _ = self._recursively_prepend_ctes(node, manifest, extra_context)
|
||||||
node, manifest, extra_context
|
|
||||||
)
|
|
||||||
if write:
|
if write:
|
||||||
self._write_node(node)
|
self._write_node(node)
|
||||||
return node
|
return node
|
||||||
|
|||||||
@@ -20,10 +20,8 @@ from dbt.utils import coerce_dict_str
 from .renderer import ProfileRenderer
 
 DEFAULT_THREADS = 1
-DEFAULT_PROFILES_DIR = os.path.join(os.path.expanduser('~'), '.dbt')
+DEFAULT_PROFILES_DIR = os.path.join(os.path.expanduser("~"), ".dbt")
-PROFILES_DIR = os.path.expanduser(
-    os.getenv('DBT_PROFILES_DIR', DEFAULT_PROFILES_DIR)
-)
+PROFILES_DIR = os.path.expanduser(os.getenv("DBT_PROFILES_DIR", DEFAULT_PROFILES_DIR))
 
 INVALID_PROFILE_MESSAGE = """
 dbt encountered an error while trying to read your profiles.yml file.
@@ -43,11 +41,13 @@ Here, [profile name] should be replaced with a profile name
 defined in your profiles.yml file. You can find profiles.yml here:
 
 {profiles_file}/profiles.yml
-""".format(profiles_file=PROFILES_DIR)
+""".format(
+    profiles_file=PROFILES_DIR
+)
 
 
 def read_profile(profiles_dir: str) -> Dict[str, Any]:
-    path = os.path.join(profiles_dir, 'profiles.yml')
+    path = os.path.join(profiles_dir, "profiles.yml")
 
     contents = None
     if os.path.isfile(path):
@@ -55,12 +55,8 @@ def read_profile(profiles_dir: str) -> Dict[str, Any]:
             contents = load_file_contents(path, strip=False)
             yaml_content = load_yaml_text(contents)
             if not yaml_content:
-                msg = f'The profiles.yml file at {path} is empty'
-                raise DbtProfileError(
-                    INVALID_PROFILE_MESSAGE.format(
-                        error_string=msg
-                    )
-                )
+                msg = f"The profiles.yml file at {path} is empty"
+                raise DbtProfileError(INVALID_PROFILE_MESSAGE.format(error_string=msg))
             return yaml_content
         except ValidationException as e:
             msg = INVALID_PROFILE_MESSAGE.format(error_string=e)
@@ -73,7 +69,7 @@ def read_user_config(directory: str) -> UserConfig:
     try:
         profile = read_profile(directory)
         if profile:
-            user_cfg = coerce_dict_str(profile.get('config', {}))
+            user_cfg = coerce_dict_str(profile.get("config", {}))
             if user_cfg is not None:
                 UserConfig.validate(user_cfg)
                 return UserConfig.from_dict(user_cfg)
@@ -92,9 +88,7 @@ class Profile(HasCredentials):
     threads: int
     credentials: Credentials
 
-    def to_profile_info(
-        self, serialize_credentials: bool = False
-    ) -> Dict[str, Any]:
+    def to_profile_info(self, serialize_credentials: bool = False) -> Dict[str, Any]:
         """Unlike to_project_config, this dict is not a mirror of any existing
         on-disk data structure. It's used when creating a new profile from an
         existing one.
@@ -104,34 +98,35 @@ class Profile(HasCredentials):
        :returns dict: The serialized profile.
         """
         result = {
-            'profile_name': self.profile_name,
+            "profile_name": self.profile_name,
-            'target_name': self.target_name,
+            "target_name": self.target_name,
-            'config': self.config,
+            "config": self.config,
-            'threads': self.threads,
+            "threads": self.threads,
-            'credentials': self.credentials,
+            "credentials": self.credentials,
         }
         if serialize_credentials:
-            result['config'] = self.config.to_dict(omit_none=True)
+            result["config"] = self.config.to_dict(omit_none=True)
-            result['credentials'] = self.credentials.to_dict(omit_none=True)
+            result["credentials"] = self.credentials.to_dict(omit_none=True)
         return result
 
     def to_target_dict(self) -> Dict[str, Any]:
-        target = dict(
-            self.credentials.connection_info(with_aliases=True)
-        )
-        target.update({
-            'type': self.credentials.type,
-            'threads': self.threads,
-            'name': self.target_name,
-            'target_name': self.target_name,
-            'profile_name': self.profile_name,
-            'config': self.config.to_dict(omit_none=True),
-        })
+        target = dict(self.credentials.connection_info(with_aliases=True))
+        target.update(
+            {
+                "type": self.credentials.type,
+                "threads": self.threads,
+                "name": self.target_name,
+                "target_name": self.target_name,
+                "profile_name": self.profile_name,
+                "config": self.config.to_dict(omit_none=True),
+            }
+        )
         return target
 
     def __eq__(self, other: object) -> bool:
-        if not (isinstance(other, self.__class__) and
-                isinstance(self, other.__class__)):
+        if not (
+            isinstance(other, self.__class__) and isinstance(self, other.__class__)
+        ):
            return NotImplemented
         return self.to_profile_info() == other.to_profile_info()
 
@@ -151,14 +146,17 @@ class Profile(HasCredentials):
     ) -> Credentials:
         # avoid an import cycle
         from dbt.adapters.factory import load_plugin
 
         # credentials carry their 'type' in their actual type, not their
         # attributes. We do want this in order to pick our Credentials class.
-        if 'type' not in profile:
+        if "type" not in profile:
             raise DbtProfileError(
-                'required field "type" not found in profile {} and target {}'
-                .format(profile_name, target_name))
+                'required field "type" not found in profile {} and target {}'.format(
+                    profile_name, target_name
+                )
+            )
 
-        typename = profile.pop('type')
+        typename = profile.pop("type")
         try:
             cls = load_plugin(typename)
             data = cls.translate_aliases(profile)
@@ -167,8 +165,9 @@ class Profile(HasCredentials):
         except (RuntimeException, ValidationError) as e:
             msg = str(e) if isinstance(e, RuntimeException) else e.message
             raise DbtProfileError(
-                'Credentials in profile "{}", target "{}" invalid: {}'
-                .format(profile_name, target_name, msg)
+                'Credentials in profile "{}", target "{}" invalid: {}'.format(
+                    profile_name, target_name, msg
+                )
             ) from e
 
         return credentials
@@ -189,19 +188,21 @@ class Profile(HasCredentials):
     def _get_profile_data(
         profile: Dict[str, Any], profile_name: str, target_name: str
     ) -> Dict[str, Any]:
-        if 'outputs' not in profile:
+        if "outputs" not in profile:
             raise DbtProfileError(
                 "outputs not specified in profile '{}'".format(profile_name)
             )
-        outputs = profile['outputs']
+        outputs = profile["outputs"]
 
         if target_name not in outputs:
-            outputs = '\n'.join(' - {}'.format(output)
-                                for output in outputs)
-            msg = ("The profile '{}' does not have a target named '{}'. The "
-                   "valid target names for this profile are:\n{}"
-                   .format(profile_name, target_name, outputs))
-            raise DbtProfileError(msg, result_type='invalid_target')
+            outputs = "\n".join(" - {}".format(output) for output in outputs)
+            msg = (
+                "The profile '{}' does not have a target named '{}'. The "
+                "valid target names for this profile are:\n{}".format(
+                    profile_name, target_name, outputs
+                )
+            )
+            raise DbtProfileError(msg, result_type="invalid_target")
         profile_data = outputs[target_name]
 
         if not isinstance(profile_data, dict):
@@ -209,7 +210,7 @@ class Profile(HasCredentials):
                 f"output '{target_name}' of profile '{profile_name}' is "
                 f"misconfigured in profiles.yml"
             )
-            raise DbtProfileError(msg, result_type='invalid_target')
+            raise DbtProfileError(msg, result_type="invalid_target")
 
         return profile_data
 
@@ -220,8 +221,8 @@ class Profile(HasCredentials):
         threads: int,
         profile_name: str,
         target_name: str,
-        user_cfg: Optional[Dict[str, Any]] = None
+        user_cfg: Optional[Dict[str, Any]] = None,
-    ) -> 'Profile':
+    ) -> "Profile":
         """Create a profile from an existing set of Credentials and the
         remaining information.
 
@@ -244,7 +245,7 @@ class Profile(HasCredentials):
             target_name=target_name,
             config=config,
             threads=threads,
-            credentials=credentials
+            credentials=credentials,
         )
         profile.validate()
         return profile
@@ -269,19 +270,18 @@ class Profile(HasCredentials):
         # name to extract a profile that we can render.
         if target_override is not None:
             target_name = target_override
-        elif 'target' in raw_profile:
+        elif "target" in raw_profile:
             # render the target if it was parsed from yaml
-            target_name = renderer.render_value(raw_profile['target'])
+            target_name = renderer.render_value(raw_profile["target"])
         else:
-            target_name = 'default'
+            target_name = "default"
             logger.debug(
-                "target not specified in profile '{}', using '{}'"
-                .format(profile_name, target_name)
+                "target not specified in profile '{}', using '{}'".format(
+                    profile_name, target_name
+                )
             )
 
-        raw_profile_data = cls._get_profile_data(
-            raw_profile, profile_name, target_name
-        )
+        raw_profile_data = cls._get_profile_data(raw_profile, profile_name, target_name)
 
         try:
             profile_data = renderer.render_data(raw_profile_data)
@@ -298,7 +298,7 @@ class Profile(HasCredentials):
         user_cfg: Optional[Dict[str, Any]] = None,
         target_override: Optional[str] = None,
         threads_override: Optional[int] = None,
-    ) -> 'Profile':
+    ) -> "Profile":
         """Create a profile from its raw profile information.
 
         (this is an intermediate step, mostly useful for unit testing)
@@ -319,7 +319,7 @@ class Profile(HasCredentials):
         """
         # user_cfg is not rendered.
         if user_cfg is None:
-            user_cfg = raw_profile.get('config')
+            user_cfg = raw_profile.get("config")
         # TODO: should it be, and the values coerced to bool?
         target_name, profile_data = cls.render_profile(
             raw_profile, profile_name, target_override, renderer
@@ -327,7 +327,7 @@ class Profile(HasCredentials):
 
         # valid connections never include the number of threads, but it's
         # stored on a per-connection level in the raw configs
-        threads = profile_data.pop('threads', DEFAULT_THREADS)
+        threads = profile_data.pop("threads", DEFAULT_THREADS)
         if threads_override is not None:
             threads = threads_override
 
@@ -340,7 +340,7 @@ class Profile(HasCredentials):
             profile_name=profile_name,
             target_name=target_name,
             threads=threads,
-            user_cfg=user_cfg
+            user_cfg=user_cfg,
         )
 
     @classmethod
@@ -351,7 +351,7 @@ class Profile(HasCredentials):
         renderer: ProfileRenderer,
         target_override: Optional[str] = None,
         threads_override: Optional[int] = None,
-    ) -> 'Profile':
+    ) -> "Profile":
         """
         :param raw_profiles: The profile data, from disk as yaml.
         :param profile_name: The profile name to use.
@@ -375,15 +375,9 @@ class Profile(HasCredentials):
         # don't render keys, so we can pluck that out
         raw_profile = raw_profiles[profile_name]
         if not raw_profile:
-            msg = (
-                f'Profile {profile_name} in profiles.yml is empty'
-            )
-            raise DbtProfileError(
-                INVALID_PROFILE_MESSAGE.format(
-                    error_string=msg
-                )
-            )
-        user_cfg = raw_profiles.get('config')
+            msg = f"Profile {profile_name} in profiles.yml is empty"
+            raise DbtProfileError(INVALID_PROFILE_MESSAGE.format(error_string=msg))
+        user_cfg = raw_profiles.get("config")
 
         return cls.from_raw_profile_info(
             raw_profile=raw_profile,
@@ -400,7 +394,7 @@ class Profile(HasCredentials):
         args: Any,
         renderer: ProfileRenderer,
         project_profile_name: Optional[str],
-    ) -> 'Profile':
+    ) -> "Profile":
         """Given the raw profiles as read from disk and the name of the desired
         profile if specified, return the profile component of the runtime
         config.
@@ -415,15 +409,16 @@ class Profile(HasCredentials):
             target could not be found.
         :returns Profile: The new Profile object.
         """
-        threads_override = getattr(args, 'threads', None)
+        threads_override = getattr(args, "threads", None)
-        target_override = getattr(args, 'target', None)
+        target_override = getattr(args, "target", None)
         raw_profiles = read_profile(args.profiles_dir)
-        profile_name = cls.pick_profile_name(getattr(args, 'profile', None),
-                                             project_profile_name)
+        profile_name = cls.pick_profile_name(
+            getattr(args, "profile", None), project_profile_name
+        )
         return cls.from_raw_profiles(
             raw_profiles=raw_profiles,
             profile_name=profile_name,
             renderer=renderer,
             target_override=target_override,
-            threads_override=threads_override
+            threads_override=threads_override,
         )
@@ -2,7 +2,13 @@ from copy import deepcopy
 from dataclasses import dataclass, field
 from itertools import chain
 from typing import (
-    List, Dict, Any, Optional, TypeVar, Union, Mapping,
+    List,
+    Dict,
+    Any,
+    Optional,
+    TypeVar,
+    Union,
+    Mapping,
 )
 from typing_extensions import Protocol, runtime_checkable
 
@@ -82,9 +88,7 @@ def _load_yaml(path):
 
 
 def package_data_from_root(project_root):
-    package_filepath = resolve_path_from_base(
-        'packages.yml', project_root
-    )
+    package_filepath = resolve_path_from_base("packages.yml", project_root)
 
     if path_exists(package_filepath):
         packages_dict = _load_yaml(package_filepath)
@@ -95,7 +99,7 @@ def package_data_from_root(project_root):
 
 def package_config_from_data(packages_data: Dict[str, Any]):
     if not packages_data:
-        packages_data = {'packages': []}
+        packages_data = {"packages": []}
 
     try:
         PackageConfig.validate(packages_data)
@@ -118,7 +122,7 @@ def _parse_versions(versions: Union[List[str], str]) -> List[VersionSpecifier]:
     Regardless, this will return a list of VersionSpecifiers
     """
     if isinstance(versions, str):
-        versions = versions.split(',')
+        versions = versions.split(",")
     return [VersionSpecifier.from_version_string(v) for v in versions]
 
 
@@ -129,11 +133,12 @@ def _all_source_paths(
     analysis_paths: List[str],
     macro_paths: List[str],
 ) -> List[str]:
-    return list(chain(source_paths, data_paths, snapshot_paths, analysis_paths,
-                      macro_paths))
+    return list(
+        chain(source_paths, data_paths, snapshot_paths, analysis_paths, macro_paths)
+    )
 
 
-T = TypeVar('T')
+T = TypeVar("T")
 
 
 def value_or(value: Optional[T], default: T) -> T:
@@ -146,21 +151,18 @@ def value_or(value: Optional[T], default: T) -> T:
 def _raw_project_from(project_root: str) -> Dict[str, Any]:
 
     project_root = os.path.normpath(project_root)
-    project_yaml_filepath = os.path.join(project_root, 'dbt_project.yml')
+    project_yaml_filepath = os.path.join(project_root, "dbt_project.yml")
 
     # get the project.yml contents
     if not path_exists(project_yaml_filepath):
         raise DbtProjectError(
-            'no dbt_project.yml found at expected path {}'
-            .format(project_yaml_filepath)
+            "no dbt_project.yml found at expected path {}".format(project_yaml_filepath)
         )
 
     project_dict = _load_yaml(project_yaml_filepath)
 
     if not isinstance(project_dict, dict):
-        raise DbtProjectError(
-            'dbt_project.yml does not parse to a dictionary'
-        )
+        raise DbtProjectError("dbt_project.yml does not parse to a dictionary")
 
     return project_dict
 
@@ -169,7 +171,7 @@ def _query_comment_from_cfg(
     cfg_query_comment: Union[QueryComment, NoValue, str, None]
 ) -> QueryComment:
     if not cfg_query_comment:
-        return QueryComment(comment='')
+        return QueryComment(comment="")
 
     if isinstance(cfg_query_comment, str):
         return QueryComment(comment=cfg_query_comment)
@@ -186,9 +188,7 @@ def validate_version(dbt_version: List[VersionSpecifier], project_name: str):
     if not versions_compatible(*dbt_version):
         msg = IMPOSSIBLE_VERSION_ERROR.format(
             package=project_name,
-            version_spec=[
-                x.to_version_string() for x in dbt_version
-            ]
+            version_spec=[x.to_version_string() for x in dbt_version],
         )
         raise DbtProjectError(msg)
 
@@ -196,9 +196,7 @@ def validate_version(dbt_version: List[VersionSpecifier], project_name: str):
         msg = INVALID_VERSION_ERROR.format(
             package=project_name,
             installed=installed.to_version_string(),
-            version_spec=[
-                x.to_version_string() for x in dbt_version
-            ]
+            version_spec=[x.to_version_string() for x in dbt_version],
         )
         raise DbtProjectError(msg)
 
@@ -207,8 +205,8 @@ def _get_required_version(
     project_dict: Dict[str, Any],
     verify_version: bool,
 ) -> List[VersionSpecifier]:
-    dbt_raw_version: Union[List[str], str] = '>=0.0.0'
+    dbt_raw_version: Union[List[str], str] = ">=0.0.0"
-    required = project_dict.get('require-dbt-version')
+    required = project_dict.get("require-dbt-version")
     if required is not None:
         dbt_raw_version = required
 
@@ -219,11 +217,11 @@ def _get_required_version(
 
     if verify_version:
         # no name is also an error that we want to raise
-        if 'name' not in project_dict:
+        if "name" not in project_dict:
             raise DbtProjectError(
                 'Required "name" field not present in project',
             )
-        validate_version(dbt_version, project_dict['name'])
+        validate_version(dbt_version, project_dict["name"])
 
     return dbt_version
 
@@ -231,34 +229,36 @@ def _get_required_version(
 @dataclass
 class RenderComponents:
     project_dict: Dict[str, Any] = field(
-        metadata=dict(description='The project dictionary')
+        metadata=dict(description="The project dictionary")
     )
     packages_dict: Dict[str, Any] = field(
-        metadata=dict(description='The packages dictionary')
+        metadata=dict(description="The packages dictionary")
     )
     selectors_dict: Dict[str, Any] = field(
-        metadata=dict(description='The selectors dictionary')
+        metadata=dict(description="The selectors dictionary")
     )
 
 
 @dataclass
 class PartialProject(RenderComponents):
-    profile_name: Optional[str] = field(metadata=dict(
-        description='The unrendered profile name in the project, if set'
-    ))
+    profile_name: Optional[str] = field(
+        metadata=dict(description="The unrendered profile name in the project, if set")
+    )
-    project_name: Optional[str] = field(metadata=dict(
-        description=(
-            'The name of the project. This should always be set and will not '
-            'be rendered'
-        )
-    ))
+    project_name: Optional[str] = field(
+        metadata=dict(
+            description=(
+                "The name of the project. This should always be set and will not "
+                "be rendered"
+            )
+        )
+    )
     project_root: str = field(
-        metadata=dict(description='The root directory of the project'),
+        metadata=dict(description="The root directory of the project"),
     )
     verify_version: bool = field(
-        metadata=dict(description=(
-            'If True, verify the dbt version matches the required version'
-        ))
+        metadata=dict(
+            description=("If True, verify the dbt version matches the required version")
+        )
     )
 
     def render_profile_name(self, renderer) -> Optional[str]:
@@ -271,9 +271,7 @@ class PartialProject(RenderComponents):
         renderer: DbtProjectYamlRenderer,
     ) -> RenderComponents:
 
-        rendered_project = renderer.render_project(
-            self.project_dict, self.project_root
-        )
+        rendered_project = renderer.render_project(self.project_dict, self.project_root)
         rendered_packages = renderer.render_packages(self.packages_dict)
         rendered_selectors = renderer.render_selectors(self.selectors_dict)
 
@@ -283,16 +281,16 @@ class PartialProject(RenderComponents):
             selectors_dict=rendered_selectors,
         )
 
-    def render(self, renderer: DbtProjectYamlRenderer) -> 'Project':
+    def render(self, renderer: DbtProjectYamlRenderer) -> "Project":
         try:
             rendered = self.get_rendered(renderer)
             return self.create_project(rendered)
         except DbtProjectError as exc:
             if exc.path is None:
-                exc.path = os.path.join(self.project_root, 'dbt_project.yml')
+                exc.path = os.path.join(self.project_root, "dbt_project.yml")
             raise
 
-    def create_project(self, rendered: RenderComponents) -> 'Project':
+    def create_project(self, rendered: RenderComponents) -> "Project":
         unrendered = RenderComponents(
             project_dict=self.project_dict,
             packages_dict=self.packages_dict,
@@ -305,9 +303,7 @@ class PartialProject(RenderComponents):
 
         try:
             ProjectContract.validate(rendered.project_dict)
-            cfg = ProjectContract.from_dict(
-                rendered.project_dict
-            )
+            cfg = ProjectContract.from_dict(rendered.project_dict)
         except ValidationError as e:
             raise DbtProjectError(validator_error_message(e)) from e
         # name/version are required in the Project definition, so we can assume
@@ -317,31 +313,30 @@ class PartialProject(RenderComponents):
         # this is added at project_dict parse time and should always be here
         # once we see it.
         if cfg.project_root is None:
-            raise DbtProjectError('cfg must have a project root!')
+            raise DbtProjectError("cfg must have a project root!")
         else:
             project_root = cfg.project_root
         # this is only optional in the sense that if it's not present, it needs
         # to have been a cli argument.
         profile_name = cfg.profile
         # these are all the defaults
-        source_paths: List[str] = value_or(cfg.source_paths, ['models'])
+        source_paths: List[str] = value_or(cfg.source_paths, ["models"])
-        macro_paths: List[str] = value_or(cfg.macro_paths, ['macros'])
+        macro_paths: List[str] = value_or(cfg.macro_paths, ["macros"])
-        data_paths: List[str] = value_or(cfg.data_paths, ['data'])
+        data_paths: List[str] = value_or(cfg.data_paths, ["data"])
-        test_paths: List[str] = value_or(cfg.test_paths, ['test'])
+        test_paths: List[str] = value_or(cfg.test_paths, ["test"])
         analysis_paths: List[str] = value_or(cfg.analysis_paths, [])
-        snapshot_paths: List[str] = value_or(cfg.snapshot_paths, ['snapshots'])
+        snapshot_paths: List[str] = value_or(cfg.snapshot_paths, ["snapshots"])
 
         all_source_paths: List[str] = _all_source_paths(
-            source_paths, data_paths, snapshot_paths, analysis_paths,
-            macro_paths
+            source_paths, data_paths, snapshot_paths, analysis_paths, macro_paths
         )
 
         docs_paths: List[str] = value_or(cfg.docs_paths, all_source_paths)
         asset_paths: List[str] = value_or(cfg.asset_paths, [])
-        target_path: str = value_or(cfg.target_path, 'target')
+        target_path: str = value_or(cfg.target_path, "target")
         clean_targets: List[str] = value_or(cfg.clean_targets, [target_path])
-        log_path: str = value_or(cfg.log_path, 'logs')
+        log_path: str = value_or(cfg.log_path, "logs")
-        modules_path: str = value_or(cfg.modules_path, 'dbt_modules')
+        modules_path: str = value_or(cfg.modules_path, "dbt_modules")
         # in the default case we'll populate this once we know the adapter type
         # It would be nice to just pass along a Quoting here, but that would
         # break many things
@@ -373,11 +368,12 @@ class PartialProject(RenderComponents):
         packages = package_config_from_data(rendered.packages_dict)
         selectors = selector_config_from_data(rendered.selectors_dict)
         manifest_selectors: Dict[str, Any] = {}
-        if rendered.selectors_dict and rendered.selectors_dict['selectors']:
+        if rendered.selectors_dict and rendered.selectors_dict["selectors"]:
             # this is a dict with a single key 'selectors' pointing to a list
             # of dicts.
             manifest_selectors = SelectorDict.parse_from_selectors_list(
-                rendered.selectors_dict['selectors'])
+                rendered.selectors_dict["selectors"]
+            )
 
         project = Project(
             project_name=name,
@@ -426,10 +422,9 @@ class PartialProject(RenderComponents):
         *,
         verify_version: bool = False,
     ):
-        """Construct a partial project from its constituent dicts.
-        """
-        project_name = project_dict.get('name')
-        profile_name = project_dict.get('profile')
+        """Construct a partial project from its constituent dicts."""
+        project_name = project_dict.get("name")
+        profile_name = project_dict.get("profile")
 
         return cls(
             profile_name=profile_name,
@@ -444,14 +439,14 @@ class PartialProject(RenderComponents):
     @classmethod
     def from_project_root(
         cls, project_root: str, *, verify_version: bool = False
-    ) -> 'PartialProject':
+    ) -> "PartialProject":
         project_root = os.path.normpath(project_root)
         project_dict = _raw_project_from(project_root)
-        config_version = project_dict.get('config-version', 1)
+        config_version = project_dict.get("config-version", 1)
         if config_version != 2:
             raise DbtProjectError(
-                f'Invalid config version: {config_version}, expected 2',
-                path=os.path.join(project_root, 'dbt_project.yml')
+                f"Invalid config version: {config_version}, expected 2",
+                path=os.path.join(project_root, "dbt_project.yml"),
             )
 
         packages_dict = package_data_from_root(project_root)
@@ -468,15 +463,10 @@ class PartialProject(RenderComponents):
 class VarProvider:
     """Var providers are tied to a particular Project."""
 
-    def __init__(
-        self,
-        vars: Dict[str, Dict[str, Any]]
-    ) -> None:
+    def __init__(self, vars: Dict[str, Dict[str, Any]]) -> None:
         self.vars = vars
 
-    def vars_for(
-        self, node: IsFQNResource, adapter_type: str
-    ) -> Mapping[str, Any]:
+    def vars_for(self, node: IsFQNResource, adapter_type: str) -> Mapping[str, Any]:
         # in v2, vars are only either project or globally scoped
         merged = MultiDict([self.vars])
         merged.add(self.vars.get(node.package_name, {}))
@@ -525,8 +515,11 @@ class Project:
     @property
     def all_source_paths(self) -> List[str]:
         return _all_source_paths(
-            self.source_paths, self.data_paths, self.snapshot_paths,
-            self.analysis_paths, self.macro_paths
+            self.source_paths,
+            self.data_paths,
+            self.snapshot_paths,
+            self.analysis_paths,
+            self.macro_paths,
         )
 
     def __str__(self):
@@ -534,11 +527,13 @@ class Project:
         return str(cfg)
 
     def __eq__(self, other):
-        if not (isinstance(other, self.__class__) and
-                isinstance(self, other.__class__)):
+        if not (
+            isinstance(other, self.__class__) and isinstance(self, other.__class__)
+        ):
             return False
-        return self.to_project_config(with_packages=True) == \
-            other.to_project_config(with_packages=True)
+        return self.to_project_config(with_packages=True) == other.to_project_config(
+            with_packages=True
+        )
 
     def to_project_config(self, with_packages=False):
         """Return a dict representation of the config that could be written to
@@ -548,38 +543,39 @@ class Project:
         file in the root.
         :returns dict: The serialized profile.
         """
-        result = deepcopy({
-            'name': self.project_name,
-            'version': self.version,
-            'project-root': self.project_root,
-            'profile': self.profile_name,
-            'source-paths': self.source_paths,
-            'macro-paths': self.macro_paths,
-            'data-paths': self.data_paths,
-            'test-paths': self.test_paths,
-            'analysis-paths': self.analysis_paths,
-            'docs-paths': self.docs_paths,
-            'asset-paths': self.asset_paths,
-            'target-path': self.target_path,
-            'snapshot-paths': self.snapshot_paths,
-            'clean-targets': self.clean_targets,
-            'log-path': self.log_path,
-            'quoting': self.quoting,
-            'models': self.models,
-            'on-run-start': self.on_run_start,
-            'on-run-end': self.on_run_end,
-            'seeds': self.seeds,
-            'snapshots': self.snapshots,
-            'sources': self.sources,
-            'vars': self.vars.to_dict(),
-            'require-dbt-version': [
+        result = deepcopy(
+            {
+                "name": self.project_name,
+                "version": self.version,
+                "project-root": self.project_root,
+                "profile": self.profile_name,
+                "source-paths": self.source_paths,
+                "macro-paths": self.macro_paths,
+                "data-paths": self.data_paths,
+                "test-paths": self.test_paths,
+                "analysis-paths": self.analysis_paths,
+                "docs-paths": self.docs_paths,
+                "asset-paths": self.asset_paths,
+                "target-path": self.target_path,
+                "snapshot-paths": self.snapshot_paths,
+                "clean-targets": self.clean_targets,
+                "log-path": self.log_path,
+                "quoting": self.quoting,
+                "models": self.models,
+                "on-run-start": self.on_run_start,
+                "on-run-end": self.on_run_end,
+                "seeds": self.seeds,
+                "snapshots": self.snapshots,
+                "sources": self.sources,
+                "vars": self.vars.to_dict(),
+                "require-dbt-version": [
                     v.to_version_string() for v in self.dbt_version
                 ],
-            'config-version': self.config_version,
-        })
+                "config-version": self.config_version,
+            }
+        )
         if self.query_comment:
-            result['query-comment'] = \
-                self.query_comment.to_dict(omit_none=True)
+            result["query-comment"] = self.query_comment.to_dict(omit_none=True)
 
         if with_packages:
             result.update(self.packages.to_dict(omit_none=True))
@@ -610,8 +606,8 @@ class Project:
         selectors_dict: Dict[str, Any],
         renderer: DbtProjectYamlRenderer,
         *,
-        verify_version: bool = False
+        verify_version: bool = False,
-    ) -> 'Project':
+    ) -> "Project":
         partial = PartialProject.from_dicts(
             project_root=project_root,
             project_dict=project_dict,
@@ -628,17 +624,17 @@ class Project:
         renderer: DbtProjectYamlRenderer,
         *,
         verify_version: bool = False,
-    ) -> 'Project':
+    ) -> "Project":
         partial = cls.partial_load(project_root, verify_version=verify_version)
         return partial.render(renderer)
 
     def hashed_name(self):
-        return hashlib.md5(self.project_name.encode('utf-8')).hexdigest()
+        return hashlib.md5(self.project_name.encode("utf-8")).hexdigest()
 
     def get_selector(self, name: str) -> SelectionSpec:
         if name not in self.selectors:
             raise RuntimeException(
-                f'Could not find selector named {name}, expected one of '
+                f"Could not find selector named {name}, expected one of "
-                f'{list(self.selectors)}'
+                f"{list(self.selectors)}"
             )
         return self.selectors[name]
@@ -2,9 +2,7 @@ from typing import Dict, Any, Tuple, Optional, Union, Callable
|
|||||||
|
|
||||||
from dbt.clients.jinja import get_rendered, catch_jinja
|
from dbt.clients.jinja import get_rendered, catch_jinja
|
||||||
|
|
||||||
from dbt.exceptions import (
|
from dbt.exceptions import DbtProjectError, CompilationException, RecursionException
|
||||||
DbtProjectError, CompilationException, RecursionException
|
|
||||||
)
|
|
||||||
from dbt.node_types import NodeType
|
from dbt.node_types import NodeType
|
||||||
from dbt.utils import deep_map
|
from dbt.utils import deep_map
|
||||||
|
|
||||||
@@ -18,7 +16,7 @@ class BaseRenderer:
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def name(self):
|
def name(self):
|
||||||
return 'Rendering'
|
return "Rendering"
|
||||||
|
|
||||||
def should_render_keypath(self, keypath: Keypath) -> bool:
|
def should_render_keypath(self, keypath: Keypath) -> bool:
|
||||||
return True
|
return True
|
||||||
@@ -29,9 +27,7 @@ class BaseRenderer:
|
|||||||
|
|
||||||
return self.render_value(value, keypath)
|
return self.render_value(value, keypath)
|
||||||
|
|
||||||
def render_value(
|
def render_value(self, value: Any, keypath: Optional[Keypath] = None) -> Any:
|
||||||
self, value: Any, keypath: Optional[Keypath] = None
|
|
||||||
) -> Any:
|
|
||||||
# keypath is ignored.
|
# keypath is ignored.
|
||||||
# if it wasn't read as a string, ignore it
|
# if it wasn't read as a string, ignore it
|
||||||
if not isinstance(value, str):
|
if not isinstance(value, str):
|
||||||
@@ -40,18 +36,16 @@ class BaseRenderer:
        with catch_jinja():
            return get_rendered(value, self.context, native=True)
        except CompilationException as exc:
-            msg = f'Could not render {value}: {exc.msg}'
+            msg = f"Could not render {value}: {exc.msg}"
            raise CompilationException(msg) from exc

-    def render_data(
-        self, data: Dict[str, Any]
-    ) -> Dict[str, Any]:
+    def render_data(self, data: Dict[str, Any]) -> Dict[str, Any]:
        try:
            return deep_map(self.render_entry, data)
        except RecursionException:
            raise DbtProjectError(
-                f'Cycle detected: {self.name} input has a reference to itself',
-                project=data
+                f"Cycle detected: {self.name} input has a reference to itself",
+                project=data,
            )

@@ -78,15 +72,15 @@ class ProjectPostprocessor(Dict[Keypath, Callable[[Any], Any]]):
    def __init__(self):
        super().__init__()

-        self[('on-run-start',)] = _list_if_none_or_string
-        self[('on-run-end',)] = _list_if_none_or_string
+        self[("on-run-start",)] = _list_if_none_or_string
+        self[("on-run-end",)] = _list_if_none_or_string

-        for k in ('models', 'seeds', 'snapshots'):
+        for k in ("models", "seeds", "snapshots"):
            self[(k,)] = _dict_if_none
-            self[(k, 'vars')] = _dict_if_none
-            self[(k, 'pre-hook')] = _list_if_none_or_string
-            self[(k, 'post-hook')] = _list_if_none_or_string
-        self[('seeds', 'column_types')] = _dict_if_none
+            self[(k, "vars")] = _dict_if_none
+            self[(k, "pre-hook")] = _list_if_none_or_string
+            self[(k, "post-hook")] = _list_if_none_or_string
+        self[("seeds", "column_types")] = _dict_if_none

    def postprocess(self, value: Any, key: Keypath) -> Any:
        if key in self:
@@ -101,7 +95,7 @@ class DbtProjectYamlRenderer(BaseRenderer):

    @property
    def name(self):
-        'Project config'
+        "Project config"

    def get_package_renderer(self) -> BaseRenderer:
        return PackageRenderer(self.context)
@@ -116,7 +110,7 @@ class DbtProjectYamlRenderer(BaseRenderer):
    ) -> Dict[str, Any]:
        """Render the project and insert the project root after rendering."""
        rendered_project = self.render_data(project)
-        rendered_project['project-root'] = project_root
+        rendered_project["project-root"] = project_root
        return rendered_project

    def render_packages(self, packages: Dict[str, Any]):
@@ -138,20 +132,19 @@ class DbtProjectYamlRenderer(BaseRenderer):

        first = keypath[0]
        # run hooks are not rendered
-        if first in {'on-run-start', 'on-run-end', 'query-comment'}:
+        if first in {"on-run-start", "on-run-end", "query-comment"}:
            return False

        # don't render vars blocks until runtime
-        if first == 'vars':
+        if first == "vars":
            return False

-        if first in {'seeds', 'models', 'snapshots', 'seeds'}:
+        if first in {"seeds", "models", "snapshots", "seeds"}:
            keypath_parts = {
-                (k.lstrip('+') if isinstance(k, str) else k)
-                for k in keypath
+                (k.lstrip("+") if isinstance(k, str) else k) for k in keypath
            }
            # model-level hooks
-            if 'pre-hook' in keypath_parts or 'post-hook' in keypath_parts:
+            if "pre-hook" in keypath_parts or "post-hook" in keypath_parts:
                return False

        return True
@@ -160,17 +153,15 @@ class DbtProjectYamlRenderer(BaseRenderer):
class ProfileRenderer(BaseRenderer):
    @property
    def name(self):
-        'Profile'
+        "Profile"


class SchemaYamlRenderer(BaseRenderer):
-    DOCUMENTABLE_NODES = frozenset(
-        n.pluralize() for n in NodeType.documentable()
-    )
+    DOCUMENTABLE_NODES = frozenset(n.pluralize() for n in NodeType.documentable())

    @property
    def name(self):
-        return 'Rendering yaml'
+        return "Rendering yaml"

    def _is_norender_key(self, keypath: Keypath) -> bool:
        """
@@ -185,13 +176,13 @@ class SchemaYamlRenderer(BaseRenderer):

        Return True if it's tests or description - those aren't rendered
        """
-        if len(keypath) >= 2 and keypath[1] in ('tests', 'description'):
+        if len(keypath) >= 2 and keypath[1] in ("tests", "description"):
            return True

        if (
-            len(keypath) >= 4 and
-            keypath[1] == 'columns' and
-            keypath[3] in ('tests', 'description')
+            len(keypath) >= 4
+            and keypath[1] == "columns"
+            and keypath[3] in ("tests", "description")
        ):
            return True

@@ -209,13 +200,13 @@ class SchemaYamlRenderer(BaseRenderer):
            return True

        if keypath[0] == NodeType.Source.pluralize():
-            if keypath[2] == 'description':
+            if keypath[2] == "description":
                return False
-            if keypath[2] == 'tables':
+            if keypath[2] == "tables":
                if self._is_norender_key(keypath[3:]):
                    return False
        elif keypath[0] == NodeType.Macro.pluralize():
-            if keypath[2] == 'arguments':
+            if keypath[2] == "arguments":
                if self._is_norender_key(keypath[3:]):
                    return False
        elif self._is_norender_key(keypath[1:]):
@@ -229,10 +220,10 @@ class SchemaYamlRenderer(BaseRenderer):
class PackageRenderer(BaseRenderer):
    @property
    def name(self):
-        return 'Packages config'
+        return "Packages config"


class SelectorRenderer(BaseRenderer):
    @property
    def name(self):
-        return 'Selector config'
+        return "Selector config"
@@ -4,8 +4,16 @@ from copy import deepcopy
from dataclasses import dataclass, fields
from pathlib import Path
from typing import (
-    Dict, Any, Optional, Mapping, Iterator, Iterable, Tuple, List, MutableSet,
-    Type
+    Dict,
+    Any,
+    Optional,
+    Mapping,
+    Iterator,
+    Iterable,
+    Tuple,
+    List,
+    MutableSet,
+    Type,
)

from .profile import Profile
@@ -15,7 +23,7 @@ from .utils import parse_cli_vars
from dbt import tracking
from dbt.adapters.factory import get_relation_class_by_name, get_include_paths
from dbt.helper_types import FQNPath, PathSet
-from dbt.context.base import generate_base_context
+from dbt.context import generate_base_context
from dbt.context.target import generate_target_context
from dbt.contracts.connection import AdapterRequiredConfig, Credentials
from dbt.contracts.graph.manifest import ManifestMetadata
@@ -30,15 +38,13 @@ from dbt.exceptions import (
    DbtProjectError,
    validator_error_message,
    warn_or_error,
-    raise_compiler_error
+    raise_compiler_error,
)

from dbt.dataclass_schema import ValidationError


-def _project_quoting_dict(
-    proj: Project, profile: Profile
-) -> Dict[ComponentName, bool]:
+def _project_quoting_dict(proj: Project, profile: Profile) -> Dict[ComponentName, bool]:
    src: Dict[str, Any] = profile.credentials.translate_aliases(proj.quoting)
    result: Dict[ComponentName, bool] = {}
    for key in ComponentName:
@@ -54,7 +60,7 @@ class RuntimeConfig(Project, Profile, AdapterRequiredConfig):
    args: Any
    profile_name: str
    cli_vars: Dict[str, Any]
-    dependencies: Optional[Mapping[str, 'RuntimeConfig']] = None
+    dependencies: Optional[Mapping[str, "RuntimeConfig"]] = None

    def __post_init__(self):
        self.validate()
@@ -65,8 +71,8 @@ class RuntimeConfig(Project, Profile, AdapterRequiredConfig):
        project: Project,
        profile: Profile,
        args: Any,
-        dependencies: Optional[Mapping[str, 'RuntimeConfig']] = None,
-    ) -> 'RuntimeConfig':
+        dependencies: Optional[Mapping[str, "RuntimeConfig"]] = None,
+    ) -> "RuntimeConfig":
        """Instantiate a RuntimeConfig from its components.

        :param profile: A parsed dbt Profile.
@@ -80,7 +86,7 @@ class RuntimeConfig(Project, Profile, AdapterRequiredConfig):
            .replace_dict(_project_quoting_dict(project, profile))
        ).to_dict(omit_none=True)

-        cli_vars: Dict[str, Any] = parse_cli_vars(getattr(args, 'vars', '{}'))
+        cli_vars: Dict[str, Any] = parse_cli_vars(getattr(args, "vars", "{}"))

        return cls(
            project_name=project.project_name,
@@ -123,7 +129,7 @@ class RuntimeConfig(Project, Profile, AdapterRequiredConfig):
            dependencies=dependencies,
        )

-    def new_project(self, project_root: str) -> 'RuntimeConfig':
+    def new_project(self, project_root: str) -> "RuntimeConfig":
        """Given a new project root, read in its project dictionary, supply the
        existing project's profile info, and create a new project file.

@@ -142,7 +148,7 @@ class RuntimeConfig(Project, Profile, AdapterRequiredConfig):
        project = Project.from_project_root(
            project_root,
            renderer,
-            verify_version=getattr(self.args, 'version_check', False),
+            verify_version=getattr(self.args, "version_check", False),
        )

        cfg = self.from_parts(
@@ -165,7 +171,7 @@ class RuntimeConfig(Project, Profile, AdapterRequiredConfig):
        """
        result = self.to_project_config(with_packages=True)
        result.update(self.to_profile_info(serialize_credentials=True))
-        result['cli_vars'] = deepcopy(self.cli_vars)
+        result["cli_vars"] = deepcopy(self.cli_vars)
        return result

    def validate(self):
@@ -185,30 +191,21 @@ class RuntimeConfig(Project, Profile, AdapterRequiredConfig):
        profile_renderer: ProfileRenderer,
        profile_name: Optional[str],
    ) -> Profile:
-        return Profile.render_from_args(
-            args, profile_renderer, profile_name
-        )
+        return Profile.render_from_args(args, profile_renderer, profile_name)

    @classmethod
-    def collect_parts(
-        cls: Type['RuntimeConfig'], args: Any
-    ) -> Tuple[Project, Profile]:
+    def collect_parts(cls: Type["RuntimeConfig"], args: Any) -> Tuple[Project, Profile]:
        # profile_name from the project
        project_root = args.project_dir if args.project_dir else os.getcwd()
-        version_check = getattr(args, 'version_check', False)
-        partial = Project.partial_load(
-            project_root,
-            verify_version=version_check
-        )
+        version_check = getattr(args, "version_check", False)
+        partial = Project.partial_load(project_root, verify_version=version_check)

        # build the profile using the base renderer and the one fact we know
-        cli_vars: Dict[str, Any] = parse_cli_vars(getattr(args, 'vars', '{}'))
+        cli_vars: Dict[str, Any] = parse_cli_vars(getattr(args, "vars", "{}"))
        profile_renderer = ProfileRenderer(generate_base_context(cli_vars))
        profile_name = partial.render_profile_name(profile_renderer)

-        profile = cls._get_rendered_profile(
-            args, profile_renderer, profile_name
-        )
+        profile = cls._get_rendered_profile(args, profile_renderer, profile_name)

        # get a new renderer using our target information and render the
        # project
@@ -218,7 +215,7 @@ class RuntimeConfig(Project, Profile, AdapterRequiredConfig):
        return (project, profile)

    @classmethod
-    def from_args(cls, args: Any) -> 'RuntimeConfig':
+    def from_args(cls, args: Any) -> "RuntimeConfig":
        """Given arguments, read in dbt_project.yml from the current directory,
        read in packages.yml if it exists, and use them to find the profile to
        load.
@@ -238,8 +235,7 @@ class RuntimeConfig(Project, Profile, AdapterRequiredConfig):

    def get_metadata(self) -> ManifestMetadata:
        return ManifestMetadata(
-            project_id=self.hashed_name(),
-            adapter_type=self.credentials.type
+            project_id=self.hashed_name(), adapter_type=self.credentials.type
        )

    def _get_v2_config_paths(
@@ -249,7 +245,7 @@ class RuntimeConfig(Project, Profile, AdapterRequiredConfig):
        paths: MutableSet[FQNPath],
    ) -> PathSet:
        for key, value in config.items():
-            if isinstance(value, dict) and not key.startswith('+'):
+            if isinstance(value, dict) and not key.startswith("+"):
                self._get_v2_config_paths(value, path + (key,), paths)
            else:
                paths.add(path)
@@ -265,7 +261,7 @@ class RuntimeConfig(Project, Profile, AdapterRequiredConfig):
        paths = set()

        for key, value in config.items():
-            if isinstance(value, dict) and not key.startswith('+'):
+            if isinstance(value, dict) and not key.startswith("+"):
                self._get_v2_config_paths(value, path + (key,), paths)
            else:
                paths.add(path)
@@ -277,10 +273,10 @@ class RuntimeConfig(Project, Profile, AdapterRequiredConfig):
        a configured path in the resource.
        """
        return {
-            'models': self._get_config_paths(self.models),
-            'seeds': self._get_config_paths(self.seeds),
-            'snapshots': self._get_config_paths(self.snapshots),
-            'sources': self._get_config_paths(self.sources),
+            "models": self._get_config_paths(self.models),
+            "seeds": self._get_config_paths(self.seeds),
+            "snapshots": self._get_config_paths(self.snapshots),
+            "sources": self._get_config_paths(self.sources),
        }

    def get_unused_resource_config_paths(
@@ -301,9 +297,7 @@ class RuntimeConfig(Project, Profile, AdapterRequiredConfig):

        for config_path in config_paths:
            if not _is_config_used(config_path, fqns):
-                unused_resource_config_paths.append(
-                    (resource_type,) + config_path
-                )
+                unused_resource_config_paths.append((resource_type,) + config_path)
        return unused_resource_config_paths

    def warn_for_unused_resource_config_paths(
@@ -316,27 +310,25 @@ class RuntimeConfig(Project, Profile, AdapterRequiredConfig):
            return

        msg = UNUSED_RESOURCE_CONFIGURATION_PATH_MESSAGE.format(
-            len(unused),
-            '\n'.join('- {}'.format('.'.join(u)) for u in unused)
+            len(unused), "\n".join("- {}".format(".".join(u)) for u in unused)
        )

-        warn_or_error(msg, log_fmt=warning_tag('{}'))
+        warn_or_error(msg, log_fmt=warning_tag("{}"))

-    def load_dependencies(self) -> Mapping[str, 'RuntimeConfig']:
+    def load_dependencies(self) -> Mapping[str, "RuntimeConfig"]:
        if self.dependencies is None:
            all_projects = {self.project_name: self}
            internal_packages = get_include_paths(self.credentials.type)
            project_paths = itertools.chain(
-                internal_packages,
-                self._get_project_directories()
+                internal_packages, self._get_project_directories()
            )
            for project_name, project in self.load_projects(project_paths):
                if project_name in all_projects:
                    raise_compiler_error(
-                        f'dbt found more than one package with the name '
+                        f"dbt found more than one package with the name "
                        f'"{project_name}" included in this project. Package '
-                        f'names must be unique in a project. Please rename '
-                        f'one of these packages.'
+                        f"names must be unique in a project. Please rename "
+                        f"one of these packages."
                    )
                all_projects[project_name] = project
            self.dependencies = all_projects
@@ -347,14 +339,14 @@ class RuntimeConfig(Project, Profile, AdapterRequiredConfig):

    def load_projects(
        self, paths: Iterable[Path]
-    ) -> Iterator[Tuple[str, 'RuntimeConfig']]:
+    ) -> Iterator[Tuple[str, "RuntimeConfig"]]:
        for path in paths:
            try:
                project = self.new_project(str(path))
            except DbtProjectError as e:
                raise DbtProjectError(
-                    f'Failed to read package: {e}',
-                    result_type='invalid_project',
+                    f"Failed to read package: {e}",
+                    result_type="invalid_project",
                    path=path,
                ) from e
            else:
@@ -365,13 +357,13 @@ class RuntimeConfig(Project, Profile, AdapterRequiredConfig):

        if root.exists():
            for path in root.iterdir():
-                if path.is_dir() and not path.name.startswith('__'):
+                if path.is_dir() and not path.name.startswith("__"):
                    yield path


class UnsetCredentials(Credentials):
    def __init__(self):
-        super().__init__('', '')
+        super().__init__("", "")

    @property
    def type(self):
@@ -387,9 +379,7 @@ class UnsetCredentials(Credentials):
class UnsetConfig(UserConfig):
    def __getattribute__(self, name):
        if name in {f.name for f in fields(UserConfig)}:
-            raise AttributeError(
-                f"'UnsetConfig' object has no attribute {name}"
-            )
+            raise AttributeError(f"'UnsetConfig' object has no attribute {name}")

    def __post_serialize__(self, dct):
        return {}
@@ -399,15 +389,15 @@ class UnsetProfile(Profile):
    def __init__(self):
        self.credentials = UnsetCredentials()
        self.config = UnsetConfig()
-        self.profile_name = ''
-        self.target_name = ''
+        self.profile_name = ""
+        self.target_name = ""
        self.threads = -1

    def to_target_dict(self):
        return {}

    def __getattribute__(self, name):
-        if name in {'profile_name', 'target_name', 'threads'}:
+        if name in {"profile_name", "target_name", "threads"}:
            raise RuntimeException(
                f'Error: disallowed attribute "{name}" - no profile!'
            )
@@ -431,7 +421,7 @@ class UnsetProfileConfig(RuntimeConfig):

    def __getattribute__(self, name):
        # Override __getattribute__ to check that the attribute isn't 'banned'.
-        if name in {'profile_name', 'target_name'}:
+        if name in {"profile_name", "target_name"}:
            raise RuntimeException(
                f'Error: disallowed attribute "{name}" - no profile!'
            )
@@ -449,8 +439,8 @@ class UnsetProfileConfig(RuntimeConfig):
        project: Project,
        profile: Profile,
        args: Any,
-        dependencies: Optional[Mapping[str, 'RuntimeConfig']] = None,
-    ) -> 'RuntimeConfig':
+        dependencies: Optional[Mapping[str, "RuntimeConfig"]] = None,
+    ) -> "RuntimeConfig":
        """Instantiate a RuntimeConfig from its components.

        :param profile: Ignored.
@@ -458,7 +448,7 @@ class UnsetProfileConfig(RuntimeConfig):
        :param args: The parsed command-line arguments.
        :returns RuntimeConfig: The new configuration.
        """
-        cli_vars: Dict[str, Any] = parse_cli_vars(getattr(args, 'vars', '{}'))
+        cli_vars: Dict[str, Any] = parse_cli_vars(getattr(args, "vars", "{}"))

        return cls(
            project_name=project.project_name,
@@ -491,10 +481,10 @@ class UnsetProfileConfig(RuntimeConfig):
            vars=project.vars,
            config_version=project.config_version,
            unrendered=project.unrendered,
-            profile_name='',
-            target_name='',
+            profile_name="",
+            target_name="",
            config=UnsetConfig(),
-            threads=getattr(args, 'threads', 1),
+            threads=getattr(args, "threads", 1),
            credentials=UnsetCredentials(),
            args=args,
            cli_vars=cli_vars,
@@ -509,16 +499,11 @@ class UnsetProfileConfig(RuntimeConfig):
        profile_name: Optional[str],
    ) -> Profile:
        try:
-            profile = Profile.render_from_args(
-                args, profile_renderer, profile_name
-            )
+            profile = Profile.render_from_args(args, profile_renderer, profile_name)
        except (DbtProjectError, DbtProfileError) as exc:
-            logger.debug(
-                'Profile not loaded due to error: {}', exc, exc_info=True
-            )
+            logger.debug("Profile not loaded due to error: {}", exc, exc_info=True)
            logger.info(
-                'No profile "{}" found, continuing with no target',
-                profile_name
+                'No profile "{}" found, continuing with no target', profile_name
            )
            # return the poisoned form
            profile = UnsetProfile()
@@ -527,7 +512,7 @@ class UnsetProfileConfig(RuntimeConfig):
        return profile

    @classmethod
-    def from_args(cls: Type[RuntimeConfig], args: Any) -> 'RuntimeConfig':
+    def from_args(cls: Type[RuntimeConfig], args: Any) -> "RuntimeConfig":
        """Given arguments, read in dbt_project.yml from the current directory,
        read in packages.yml if it exists, and use them to find the profile to
        load.
@@ -542,11 +527,7 @@ class UnsetProfileConfig(RuntimeConfig):
        # if it's a real profile, return a real config
        cls = RuntimeConfig

-        return cls.from_parts(
-            project=project,
-            profile=profile,
-            args=args
-        )
+        return cls.from_parts(project=project, profile=profile, args=args)


UNUSED_RESOURCE_CONFIGURATION_PATH_MESSAGE = """\
@@ -1,8 +1,6 @@
from pathlib import Path
from typing import Dict, Any
-from dbt.clients.yaml_helper import ( # noqa: F401
-    yaml, Loader, Dumper, load_yaml_text
-)
+from dbt.clients.yaml_helper import yaml, Loader, Dumper, load_yaml_text # noqa: F401
from dbt.dataclass_schema import ValidationError

from .renderer import SelectorRenderer
@@ -30,9 +28,8 @@ Validator Error:


class SelectorConfig(Dict[str, SelectionSpec]):

    @classmethod
-    def selectors_from_dict(cls, data: Dict[str, Any]) -> 'SelectorConfig':
+    def selectors_from_dict(cls, data: Dict[str, Any]) -> "SelectorConfig":
        try:
            SelectorFile.validate(data)
            selector_file = SelectorFile.from_dict(data)
@@ -45,12 +42,12 @@ class SelectorConfig(Dict[str, SelectionSpec]):
                f"union, intersection, string, dictionary. No lists. "
                f"\nhttps://docs.getdbt.com/reference/node-selection/"
                f"yaml-selectors",
-                result_type='invalid_selector'
+                result_type="invalid_selector",
            ) from exc
        except RuntimeException as exc:
            raise DbtSelectorsError(
-                f'Could not read selector file data: {exc}',
-                result_type='invalid_selector',
+                f"Could not read selector file data: {exc}",
+                result_type="invalid_selector",
            ) from exc

        return cls(selectors)
@@ -60,26 +57,28 @@ class SelectorConfig(Dict[str, SelectionSpec]):
        cls,
        data: Dict[str, Any],
        renderer: SelectorRenderer,
-    ) -> 'SelectorConfig':
+    ) -> "SelectorConfig":
        try:
            rendered = renderer.render_data(data)
        except (ValidationError, RuntimeException) as exc:
            raise DbtSelectorsError(
-                f'Could not render selector data: {exc}',
-                result_type='invalid_selector',
+                f"Could not render selector data: {exc}",
+                result_type="invalid_selector",
            ) from exc
        return cls.selectors_from_dict(rendered)

    @classmethod
    def from_path(
-        cls, path: Path, renderer: SelectorRenderer,
-    ) -> 'SelectorConfig':
+        cls,
+        path: Path,
+        renderer: SelectorRenderer,
+    ) -> "SelectorConfig":
        try:
            data = load_yaml_text(load_file_contents(str(path)))
        except (ValidationError, RuntimeException) as exc:
            raise DbtSelectorsError(
-                f'Could not read selector file: {exc}',
-                result_type='invalid_selector',
+                f"Could not read selector file: {exc}",
+                result_type="invalid_selector",
                path=path,
            ) from exc

@@ -91,9 +90,7 @@ class SelectorConfig(Dict[str, SelectionSpec]):


def selector_data_from_root(project_root: str) -> Dict[str, Any]:
-    selector_filepath = resolve_path_from_base(
-        'selectors.yml', project_root
-    )
+    selector_filepath = resolve_path_from_base("selectors.yml", project_root)

    if path_exists(selector_filepath):
        selectors_dict = load_yaml_text(load_file_contents(selector_filepath))
@@ -102,18 +99,16 @@ def selector_data_from_root(project_root: str) -> Dict[str, Any]:
    return selectors_dict


-def selector_config_from_data(
-    selectors_data: Dict[str, Any]
-) -> SelectorConfig:
+def selector_config_from_data(selectors_data: Dict[str, Any]) -> SelectorConfig:
    if not selectors_data:
-        selectors_data = {'selectors': []}
+        selectors_data = {"selectors": []}

    try:
        selectors = SelectorConfig.selectors_from_dict(selectors_data)
    except ValidationError as e:
        raise DbtSelectorsError(
            MALFORMED_SELECTOR_ERROR.format(error=str(e.message)),
-            result_type='invalid_selector',
+            result_type="invalid_selector",
        ) from e
    return selectors

@@ -125,7 +120,6 @@ def selector_config_from_data(
# be necessary to make changes here. Ideally it would be
# good to combine the two flows into one at some point.
class SelectorDict:

    @classmethod
    def parse_dict_definition(cls, definition):
        key = list(definition)[0]
@@ -136,10 +130,10 @@ class SelectorDict:
            new_value = cls.parse_from_definition(sel_def)
            new_values.append(new_value)
            value = new_values
-        if key == 'exclude':
+        if key == "exclude":
            definition = {key: value}
        elif len(definition) == 1:
-            definition = {'method': key, 'value': value}
+            definition = {"method": key, "value": value}
        return definition

    @classmethod
@@ -161,10 +155,10 @@ class SelectorDict:
    def parse_from_definition(cls, definition):
        if isinstance(definition, str):
            definition = SelectionCriteria.dict_from_single_spec(definition)
-        elif 'union' in definition:
-            definition = cls.parse_a_definition('union', definition)
-        elif 'intersection' in definition:
-            definition = cls.parse_a_definition('intersection', definition)
+        elif "union" in definition:
+            definition = cls.parse_a_definition("union", definition)
+        elif "intersection" in definition:
+            definition = cls.parse_a_definition("intersection", definition)
        elif isinstance(definition, dict):
            definition = cls.parse_dict_definition(definition)
        return definition
@@ -175,8 +169,8 @@ class SelectorDict:
    def parse_from_selectors_list(cls, selectors):
        selector_dict = {}
        for selector in selectors:
-            sel_name = selector['name']
+            sel_name = selector["name"]
            selector_dict[sel_name] = selector
-            definition = cls.parse_from_definition(selector['definition'])
-            selector_dict[sel_name]['definition'] = definition
+            definition = cls.parse_from_definition(selector["definition"])
+            selector_dict[sel_name]["definition"] = definition
        return selector_dict
@@ -15,9 +15,8 @@ def parse_cli_vars(var_string: str) -> Dict[str, Any]:
            type_name = var_type.__name__
            raise_compiler_error(
                "The --vars argument must be a YAML dictionary, but was "
-                "of type '{}'".format(type_name))
-    except ValidationException:
-        logger.error(
-            "The YAML provided in the --vars argument is not valid.\n"
+                "of type '{}'".format(type_name)
            )
+    except ValidationException:
+        logger.error("The YAML provided in the --vars argument is not valid.\n")
        raise
@@ -1,14 +1,16 @@
import json
import os
-from typing import (
-    Any, Dict, NoReturn, Optional, Mapping
-)
+from typing import Any, Dict, NoReturn, Optional, Mapping

from dbt import flags
from dbt import tracking
from dbt.clients.jinja import undefined_error, get_rendered
from dbt.clients.yaml_helper import ( # noqa: F401
-    yaml, safe_load, SafeLoader, Loader, Dumper
+    yaml,
+    safe_load,
+    SafeLoader,
+    Loader,
+    Dumper,
)
from dbt.contracts.graph.compiled import CompiledResource
from dbt.exceptions import raise_compiler_error, MacroReturn
@@ -25,38 +27,26 @@ import re
def get_pytz_module_context() -> Dict[str, Any]:
    context_exports = pytz.__all__ # type: ignore

-    return {
-        name: getattr(pytz, name) for name in context_exports
-    }
+    return {name: getattr(pytz, name) for name in context_exports}


def get_datetime_module_context() -> Dict[str, Any]:
-    context_exports = [
-        'date',
-        'datetime',
-        'time',
-        'timedelta',
-        'tzinfo'
-    ]
+    context_exports = ["date", "datetime", "time", "timedelta", "tzinfo"]

-    return {
-        name: getattr(datetime, name) for name in context_exports
-    }
+    return {name: getattr(datetime, name) for name in context_exports}


def get_re_module_context() -> Dict[str, Any]:
    context_exports = re.__all__

-    return {
-        name: getattr(re, name) for name in context_exports
-    }
+    return {name: getattr(re, name) for name in context_exports}


def get_context_modules() -> Dict[str, Dict[str, Any]]:
    return {
-        'pytz': get_pytz_module_context(),
-        'datetime': get_datetime_module_context(),
-        're': get_re_module_context(),
+        "pytz": get_pytz_module_context(),
+        "datetime": get_datetime_module_context(),
+        "re": get_re_module_context(),
    }

@@ -90,8 +80,8 @@ class ContextMeta(type):
        new_dct = {}

        for base in bases:
-            context_members.update(getattr(base, '_context_members_', {}))
-            context_attrs.update(getattr(base, '_context_attrs_', {}))
+            context_members.update(getattr(base, "_context_members_", {}))
+            context_attrs.update(getattr(base, "_context_attrs_", {}))

        for key, value in dct.items():
            if isinstance(value, ContextMember):
@@ -100,21 +90,22 @@ class ContextMeta(type):
                context_attrs[context_key] = key
                value = value.inner
            new_dct[key] = value
-        new_dct['_context_members_'] = context_members
-        new_dct['_context_attrs_'] = context_attrs
+        new_dct["_context_members_"] = context_members
+        new_dct["_context_attrs_"] = context_attrs
        return type.__new__(mcls, name, bases, new_dct)


class Var:
-    UndefinedVarError = "Required var '{}' not found in config:\nVars "\
-        "supplied to {} = {}"
+    UndefinedVarError = (
+        "Required var '{}' not found in config:\nVars " "supplied to {} = {}"
+    )
    _VAR_NOTSET = object()

    def __init__(
        self,
        context: Mapping[str, Any],
        cli_vars: Mapping[str, Any],
-        node: Optional[CompiledResource] = None
+        node: Optional[CompiledResource] = None,
    ) -> None:
        self._context: Mapping[str, Any] = context
        self._cli_vars: Mapping[str, Any] = cli_vars
@@ -129,14 +120,12 @@ class Var:
        if self._node is not None:
            return self._node.name
        else:
-            return '<Configuration>'
+            return "<Configuration>"

    def get_missing_var(self, var_name):
        dct = {k: self._merged[k] for k in self._merged}
        pretty_vars = json.dumps(dct, sort_keys=True, indent=4)
-        msg = self.UndefinedVarError.format(
-            var_name, self.node_name, pretty_vars
-        )
+        msg = self.UndefinedVarError.format(var_name, self.node_name, pretty_vars)
        raise_compiler_error(msg, self._node)

    def has_var(self, var_name: str):
@@ -167,7 +156,7 @@ class BaseContext(metaclass=ContextMeta):
    def generate_builtins(self):
        builtins: Dict[str, Any] = {}
        for key, value in self._context_members_.items():
-            if hasattr(value, '__get__'):
+            if hasattr(value, "__get__"):
                # handle properties, bound methods, etc
                value = value.__get__(self)
            builtins[key] = value
@@ -175,9 +164,9 @@ class BaseContext(metaclass=ContextMeta):

    # no dbtClassMixin so this is not an actual override
    def to_dict(self):
-        self._ctx['context'] = self._ctx
+        self._ctx["context"] = self._ctx
        builtins = self.generate_builtins()
-        self._ctx['builtins'] = builtins
+        self._ctx["builtins"] = builtins
        self._ctx.update(builtins)
        return self._ctx

@@ -286,18 +275,20 @@ class BaseContext(metaclass=ContextMeta):
            msg = f"Env var required but not provided: '{var}'"
            undefined_error(msg)

-    if os.environ.get('DBT_MACRO_DEBUGGING'):
+    if os.environ.get("DBT_MACRO_DEBUGGING"):

        @contextmember
        @staticmethod
        def debug():
            """Enter a debugger at this line in the compiled jinja code."""
            import sys
            import ipdb # type: ignore

            frame = sys._getframe(3)
            ipdb.set_trace(frame)
-            return ''
+            return ""

-    @contextmember('return')
+    @contextmember("return")
    @staticmethod
    def _return(data: Any) -> NoReturn:
        """The `return` function can be used in macros to return data to the
@@ -348,9 +339,7 @@ class BaseContext(metaclass=ContextMeta):

    @contextmember
    @staticmethod
-    def tojson(
-        value: Any, default: Any = None, sort_keys: bool = False
-    ) -> Any:
+    def tojson(value: Any, default: Any = None, sort_keys: bool = False) -> Any:
        """The `tojson` context method can be used to serialize a Python
        object primitive, eg. a `dict` or `list` to a json string.

@@ -446,7 +435,7 @@ class BaseContext(metaclass=ContextMeta):
            logger.info(msg)
        else:
            logger.debug(msg)
-        return ''
+        return ""

    @contextproperty
    def run_started_at(self) -> Optional[datetime.datetime]:
@@ -4,16 +4,14 @@ from dbt.contracts.connection import AdapterRequiredConfig
from dbt.node_types import NodeType
from dbt.utils import MultiDict

-from dbt.context.base import contextproperty, Var
+from dbt.context import contextproperty, Var
from dbt.context.target import TargetContext


class ConfiguredContext(TargetContext):
    config: AdapterRequiredConfig

-    def __init__(
-        self, config: AdapterRequiredConfig
-    ) -> None:
+    def __init__(self, config: AdapterRequiredConfig) -> None:
        super().__init__(config, config.cli_vars)

    @contextproperty
@@ -70,9 +68,7 @@ class SchemaYamlContext(ConfiguredContext):

    @contextproperty
    def var(self) -> ConfiguredVar:
-        return ConfiguredVar(
-            self._ctx, self.config, self._project_name
-        )
+        return ConfiguredVar(self._ctx, self.config, self._project_name)


def generate_schema_yml(
@@ -17,8 +17,8 @@ class ModelParts(IsFQNResource):
    package_name: str


-T = TypeVar('T') # any old type
-C = TypeVar('C', bound=BaseConfig)
+T = TypeVar("T") # any old type
+C = TypeVar("C", bound=BaseConfig)


class ConfigSource:
@@ -36,13 +36,13 @@ class UnrenderedConfig(ConfigSource):
    def get_config_dict(self, resource_type: NodeType) -> Dict[str, Any]:
        unrendered = self.project.unrendered.project_dict
        if resource_type == NodeType.Seed:
-            model_configs = unrendered.get('seeds')
+            model_configs = unrendered.get("seeds")
        elif resource_type == NodeType.Snapshot:
-            model_configs = unrendered.get('snapshots')
+            model_configs = unrendered.get("snapshots")
        elif resource_type == NodeType.Source:
-            model_configs = unrendered.get('sources')
+            model_configs = unrendered.get("sources")
        else:
-            model_configs = unrendered.get('models')
+            model_configs = unrendered.get("models")

        if model_configs is None:
            return {}
@@ -79,8 +79,8 @@ class BaseContextConfigGenerator(Generic[T]):
        dependencies = self._active_project.load_dependencies()
        if project_name not in dependencies:
            raise InternalException(
-                f'Project name {project_name} not found in dependencies '
-                f'(found {list(dependencies)})'
+                f"Project name {project_name} not found in dependencies "
+                f"(found {list(dependencies)})"
            )
        return dependencies[project_name]

@@ -92,7 +92,7 @@ class BaseContextConfigGenerator(Generic[T]):
        for level_config in fqn_search(model_configs, fqn):
            result = {}
            for key, value in level_config.items():
-                if key.startswith('+'):
+                if key.startswith("+"):
                    result[key[1:]] = deepcopy(value)
                elif not isinstance(value, dict):
                    result[key] = deepcopy(value)
@@ -171,13 +171,9 @@ class ContextConfigGenerator(BaseContextConfigGenerator[C]):
    def _update_from_config(
        self, result: C, partial: Dict[str, Any], validate: bool = False
    ) -> C:
-        translated = self._active_project.credentials.translate_aliases(
-            partial
-        )
+        translated = self._active_project.credentials.translate_aliases(partial)
        return result.update_from(
-            translated,
-            self._active_project.credentials.type,
-            validate=validate
+            translated, self._active_project.credentials.type, validate=validate
        )

    def calculate_node_config_dict(
@@ -219,11 +215,7 @@ class UnrenderedConfigGenerator(BaseContextConfigGenerator[Dict[str, Any]]):
            base=base,
        )

-    def initial_result(
-        self,
-        resource_type: NodeType,
-        base: bool
-    ) -> Dict[str, Any]:
+    def initial_result(self, resource_type: NodeType, base: bool) -> Dict[str, Any]:
        return {}

    def _update_from_config(
@@ -232,9 +224,7 @@ class UnrenderedConfigGenerator(BaseContextConfigGenerator[Dict[str, Any]]):
        partial: Dict[str, Any],
        validate: bool = False,
    ) -> Dict[str, Any]:
-        translated = self._active_project.credentials.translate_aliases(
-            partial
-        )
+        translated = self._active_project.credentials.translate_aliases(partial)
        result.update(translated)
        return result

@@ -1,6 +1,4 @@
-from typing import (
-    Any, Dict, Union
-)
+from typing import Any, Dict, Union

from dbt.exceptions import (
    doc_invalid_args,
@@ -11,7 +9,7 @@ from dbt.contracts.graph.compiled import CompileResultNode
from dbt.contracts.graph.manifest import Manifest
from dbt.contracts.graph.parsed import ParsedMacro

-from dbt.context.base import contextmember
+from dbt.context import contextmember
from dbt.context.configured import SchemaYamlContext

@@ -1,6 +1,4 @@
-from typing import (
-    Dict, MutableMapping, Optional
-)
+from typing import Dict, MutableMapping, Optional
from dbt.contracts.graph.parsed import ParsedMacro
from dbt.exceptions import raise_duplicate_macro_name, raise_compiler_error
from dbt.include.global_project import PROJECT_NAME as GLOBAL_PROJECT_NAME
@@ -45,8 +43,7 @@ class MacroResolver:
        for pkg in reversed(self.internal_package_names):
            if pkg in self.internal_packages:
                # Turn the internal packages into a flat namespace
-                self.internal_packages_namespace.update(
-                    self.internal_packages[pkg])
+                self.internal_packages_namespace.update(self.internal_packages[pkg])

    def _build_macros_by_name(self):
        macros_by_name = {}
@@ -74,9 +71,7 @@ class MacroResolver:
            package_namespaces[macro.package_name] = namespace

        if macro.name in namespace:
-            raise_duplicate_macro_name(
-                macro, macro, macro.package_name
-            )
+            raise_duplicate_macro_name(macro, macro, macro.package_name)
        package_namespaces[macro.package_name][macro.name] = macro

    def add_macro(self, macro: ParsedMacro):
@@ -99,8 +94,10 @@ class MacroResolver:

    def get_macro_id(self, local_package, macro_name):
        local_package_macros = {}
-        if (local_package not in self.internal_package_names and
-                local_package in self.packages):
+        if (
+            local_package not in self.internal_package_names
+            and local_package in self.packages
+        ):
            local_package_macros = self.packages[local_package]
        # First: search the local packages for this macro
        if macro_name in local_package_macros:
@@ -117,9 +114,7 @@ class MacroResolver:
# is that you can limit the number of macros provided to the
# context dictionary in the 'to_dict' manifest method.
class TestMacroNamespace:
-    def __init__(
-        self, macro_resolver, ctx, node, thread_ctx, depends_on_macros
-    ):
+    def __init__(self, macro_resolver, ctx, node, thread_ctx, depends_on_macros):
        self.macro_resolver = macro_resolver
        self.ctx = ctx
        self.node = node
@@ -129,7 +124,10 @@ class TestMacroNamespace:
        for macro_unique_id in depends_on_macros:
            macro = self.manifest.macros[macro_unique_id]
            local_namespace[macro.name] = MacroGenerator(
-                macro, self.ctx, self.node, self.thread_ctx,
+                macro,
+                self.ctx,
+                self.node,
+                self.thread_ctx,
            )
        self.local_namespace = local_namespace

@@ -144,10 +142,6 @@ class TestMacroNamespace:
        elif package_name in self.resolver.packages:
            macro = self.macro_resolver.packages[package_name].get(name)
        else:
-            raise_compiler_error(
-                f"Could not find package '{package_name}'"
-            )
-        macro_func = MacroGenerator(
-            macro, self.ctx, self.node, self.thread_ctx
-        )
+            raise_compiler_error(f"Could not find package '{package_name}'")
+        macro_func = MacroGenerator(macro, self.ctx, self.node, self.thread_ctx)
        return macro_func
@@ -1,13 +1,9 @@
-from typing import (
-    Any, Dict, Iterable, Union, Optional, List, Iterator, Mapping, Set
-)
+from typing import Any, Dict, Iterable, Union, Optional, List, Iterator, Mapping, Set

from dbt.clients.jinja import MacroGenerator, MacroStack
from dbt.contracts.graph.parsed import ParsedMacro
from dbt.include.global_project import PROJECT_NAME as GLOBAL_PROJECT_NAME
-from dbt.exceptions import (
-    raise_duplicate_macro_name, raise_compiler_error
-)
+from dbt.exceptions import raise_duplicate_macro_name, raise_compiler_error


FlatNamespace = Dict[str, MacroGenerator]
@@ -75,9 +71,7 @@ class MacroNamespace(Mapping):
        elif package_name in self.packages:
            return self.packages[package_name].get(name)
        else:
-            raise_compiler_error(
-                f"Could not find package '{package_name}'"
-            )
+            raise_compiler_error(f"Could not find package '{package_name}'")


# This class builds the MacroNamespace by adding macros to
@@ -122,9 +116,7 @@ class MacroNamespaceBuilder:
            hierarchy[macro.package_name] = namespace

        if macro.name in namespace:
-            raise_duplicate_macro_name(
-                macro_func.macro, macro, macro.package_name
-            )
+            raise_duplicate_macro_name(macro_func.macro, macro, macro.package_name)
        hierarchy[macro.package_name][macro.name] = macro_func

    def add_macro(self, macro: ParsedMacro, ctx: Dict[str, Any]):
@@ -17,6 +17,7 @@ class ManifestContext(ConfiguredContext):
The given macros can override any previous context values, which will be
available as if they were accessed relative to the package name.
"""

def __init__(
self,
config: AdapterRequiredConfig,
@@ -37,13 +38,12 @@ class ManifestContext(ConfiguredContext):
# this takes all the macros in the manifest and adds them
# to the MacroNamespaceBuilder stored in self.namespace
builder = self._get_namespace_builder()
-return builder.build_namespace(
-self.manifest.macros.values(), self._ctx
-)
+return builder.build_namespace(self.manifest.macros.values(), self._ctx)

def _get_namespace_builder(self) -> MacroNamespaceBuilder:
# avoid an import loop
from dbt.adapters.factory import get_adapter_package_names

internal_packages: List[str] = get_adapter_package_names(
self.config.credentials.type
)
@@ -68,14 +68,10 @@ class ManifestContext(ConfiguredContext):


class QueryHeaderContext(ManifestContext):
-def __init__(
-self, config: AdapterRequiredConfig, manifest: Manifest
-) -> None:
+def __init__(self, config: AdapterRequiredConfig, manifest: Manifest) -> None:
super().__init__(config, manifest, config.project_name)


-def generate_query_header_context(
-config: AdapterRequiredConfig, manifest: Manifest
-):
+def generate_query_header_context(config: AdapterRequiredConfig, manifest: Manifest):
ctx = QueryHeaderContext(config, manifest)
return ctx.to_dict()
@@ -1,7 +1,15 @@
import abc
import os
from typing import (
-Callable, Any, Dict, Optional, Union, List, TypeVar, Type, Iterable,
+Callable,
+Any,
+Dict,
+Optional,
+Union,
+List,
+TypeVar,
+Type,
+Iterable,
Mapping,
)
from typing_extensions import Protocol
@@ -12,16 +20,14 @@ from dbt.adapters.factory import get_adapter, get_adapter_package_names
from dbt.clients import agate_helper
from dbt.clients.jinja import get_rendered, MacroGenerator, MacroStack
from dbt.config import RuntimeConfig, Project
-from .base import contextmember, contextproperty, Var
+from dbt.context import contextmember, contextproperty, Var
from .configured import FQNLookup
from .context_config import ContextConfig
from dbt.context.macro_resolver import MacroResolver, TestMacroNamespace
from .macros import MacroNamespaceBuilder, MacroNamespace
from .manifest import ManifestContext
from dbt.contracts.connection import AdapterResponse
-from dbt.contracts.graph.manifest import (
-Manifest, AnyManifest, Disabled, MacroManifest
-)
+from dbt.contracts.graph.manifest import Manifest, AnyManifest, Disabled, MacroManifest
from dbt.contracts.graph.compiled import (
CompiledResource,
CompiledSeedNode,
@@ -50,9 +56,7 @@ from dbt.config import IsFQNResource
from dbt.logger import GLOBAL_LOGGER as logger # noqa
from dbt.node_types import NodeType

-from dbt.utils import (
-merge, AttrDict, MultiDict
-)
+from dbt.utils import merge, AttrDict, MultiDict

import agate

@@ -75,9 +79,8 @@ class RelationProxy:
return self._relation_type.create_from_source(*args, **kwargs)

def create(self, *args, **kwargs):
-kwargs['quote_policy'] = merge(
-self._quoting_config,
-kwargs.pop('quote_policy', {})
+kwargs["quote_policy"] = merge(
+self._quoting_config, kwargs.pop("quote_policy", {})
)
return self._relation_type.create(*args, **kwargs)

@@ -94,7 +97,7 @@ class BaseDatabaseWrapper:
self._namespace = namespace

def __getattr__(self, name):
-raise NotImplementedError('subclasses need to implement this')
+raise NotImplementedError("subclasses need to implement this")

@property
def config(self):
@@ -110,7 +113,7 @@ class BaseDatabaseWrapper:
# a future version of this could have plugins automatically call fall
# back to their dependencies' dependencies by using
# `get_adapter_type_names` instead of `[self.config.credentials.type]`
-search_prefixes = [self._adapter.type(), 'default']
+search_prefixes = [self._adapter.type(), "default"]
return search_prefixes

def dispatch(
@@ -118,8 +121,8 @@ class BaseDatabaseWrapper:
) -> MacroGenerator:
search_packages: List[Optional[str]]

-if '.' in macro_name:
-suggest_package, suggest_macro_name = macro_name.split('.', 1)
+if "." in macro_name:
+suggest_package, suggest_macro_name = macro_name.split(".", 1)
msg = (
f'In adapter.dispatch, got a macro name of "{macro_name}", '
f'but "." is not a valid macro name component. Did you mean '
@@ -132,7 +135,7 @@ class BaseDatabaseWrapper:
search_packages = [None]
elif isinstance(packages, str):
raise CompilationException(
-f'In adapter.dispatch, got a string packages argument '
+f"In adapter.dispatch, got a string packages argument "
f'("{packages}"), but packages should be None or a list.'
)
else:
@@ -142,26 +145,24 @@ class BaseDatabaseWrapper:

for package_name in search_packages:
for prefix in self._get_adapter_macro_prefixes():
-search_name = f'{prefix}__{macro_name}'
+search_name = f"{prefix}__{macro_name}"
try:
# this uses the namespace from the context
-macro = self._namespace.get_from_package(
-package_name, search_name
-)
+macro = self._namespace.get_from_package(package_name, search_name)
except CompilationException as exc:
raise CompilationException(
-f'In dispatch: {exc.msg}',
+f"In dispatch: {exc.msg}",
) from exc

if package_name is None:
attempts.append(search_name)
else:
-attempts.append(f'{package_name}.{search_name}')
+attempts.append(f"{package_name}.{search_name}")

if macro is not None:
return macro

-searched = ', '.join(repr(a) for a in attempts)
+searched = ", ".join(repr(a) for a in attempts)
msg = (
f"In dispatch: No macro named '{macro_name}' found\n"
f" Searched for: {searched}"
@@ -191,14 +192,10 @@ class BaseResolver(metaclass=abc.ABCMeta):

class BaseRefResolver(BaseResolver):
@abc.abstractmethod
-def resolve(
-self, name: str, package: Optional[str] = None
-) -> RelationProxy:
+def resolve(self, name: str, package: Optional[str] = None) -> RelationProxy:
...

-def _repack_args(
-self, name: str, package: Optional[str]
-) -> List[str]:
+def _repack_args(self, name: str, package: Optional[str]) -> List[str]:
if package is None:
return [name]
else:
@@ -207,14 +204,13 @@ class BaseRefResolver(BaseResolver):
def validate_args(self, name: str, package: Optional[str]):
if not isinstance(name, str):
raise CompilationException(
-f'The name argument to ref() must be a string, got '
-f'{type(name)}'
+f"The name argument to ref() must be a string, got " f"{type(name)}"
)

if package is not None and not isinstance(package, str):
raise CompilationException(
-f'The package argument to ref() must be a string or None, got '
-f'{type(package)}'
+f"The package argument to ref() must be a string or None, got "
+f"{type(package)}"
)

def __call__(self, *args: str) -> RelationProxy:
@@ -239,20 +235,19 @@ class BaseSourceResolver(BaseResolver):
def validate_args(self, source_name: str, table_name: str):
if not isinstance(source_name, str):
raise CompilationException(
-f'The source name (first) argument to source() must be a '
-f'string, got {type(source_name)}'
+f"The source name (first) argument to source() must be a "
+f"string, got {type(source_name)}"
)
if not isinstance(table_name, str):
raise CompilationException(
-f'The table name (second) argument to source() must be a '
-f'string, got {type(table_name)}'
+f"The table name (second) argument to source() must be a "
+f"string, got {type(table_name)}"
)

def __call__(self, *args: str) -> RelationProxy:
if len(args) != 2:
raise_compiler_error(
-f"source() takes exactly two arguments ({len(args)} given)",
-self.model
+f"source() takes exactly two arguments ({len(args)} given)", self.model
)
self.validate_args(args[0], args[1])
return self.resolve(args[0], args[1])
@@ -270,14 +265,15 @@ class ParseConfigObject(Config):
self.context_config = context_config

def _transform_config(self, config):
-for oldkey in ('pre_hook', 'post_hook'):
+for oldkey in ("pre_hook", "post_hook"):
if oldkey in config:
-newkey = oldkey.replace('_', '-')
+newkey = oldkey.replace("_", "-")
if newkey in config:
raise_compiler_error(
-'Invalid config, has conflicting keys "{}" and "{}"'
-.format(oldkey, newkey),
-self.model
+'Invalid config, has conflicting keys "{}" and "{}"'.format(
+oldkey, newkey
+),
+self.model,
)
config[newkey] = config.pop(oldkey)
return config
@@ -288,29 +284,25 @@ class ParseConfigObject(Config):
elif len(args) == 0 and len(kwargs) > 0:
opts = kwargs
else:
-raise_compiler_error(
-"Invalid inline model config",
-self.model)
+raise_compiler_error("Invalid inline model config", self.model)

opts = self._transform_config(opts)

# it's ok to have a parse context with no context config, but you must
# not call it!
if self.context_config is None:
-raise RuntimeException(
-'At parse time, did not receive a context config'
-)
+raise RuntimeException("At parse time, did not receive a context config")
self.context_config.update_in_model_config(opts)
-return ''
+return ""

def set(self, name, value):
return self.__call__({name: value})

def require(self, name, validator=None):
-return ''
+return ""

def get(self, name, validator=None, default=None):
-return ''
+return ""

def persist_relation_docs(self) -> bool:
return False
@@ -320,14 +312,12 @@ class ParseConfigObject(Config):


class RuntimeConfigObject(Config):
-def __init__(
-self, model, context_config: Optional[ContextConfig] = None
-):
+def __init__(self, model, context_config: Optional[ContextConfig] = None):
self.model = model
# we never use or get a config, only the parser cares

def __call__(self, *args, **kwargs):
-return ''
+return ""

def set(self, name, value):
return self.__call__({name: value})
@@ -337,7 +327,7 @@ class RuntimeConfigObject(Config):

def _lookup(self, name, default=_MISSING):
# if this is a macro, there might be no `model.config`.
-if not hasattr(self.model, 'config'):
+if not hasattr(self.model, "config"):
result = default
else:
result = self.model.config.get(name, default)
@@ -362,22 +352,24 @@ class RuntimeConfigObject(Config):
return to_return

def persist_relation_docs(self) -> bool:
-persist_docs = self.get('persist_docs', default={})
+persist_docs = self.get("persist_docs", default={})
if not isinstance(persist_docs, dict):
raise_compiler_error(
f"Invalid value provided for 'persist_docs'. Expected dict "
-f"but received {type(persist_docs)}")
+f"but received {type(persist_docs)}"
+)

-return persist_docs.get('relation', False)
+return persist_docs.get("relation", False)

def persist_column_docs(self) -> bool:
-persist_docs = self.get('persist_docs', default={})
+persist_docs = self.get("persist_docs", default={})
if not isinstance(persist_docs, dict):
raise_compiler_error(
f"Invalid value provided for 'persist_docs'. Expected dict "
-f"but received {type(persist_docs)}")
+f"but received {type(persist_docs)}"
+)

-return persist_docs.get('columns', False)
+return persist_docs.get("columns", False)


# `adapter` implementations
@@ -387,8 +379,10 @@ class ParseDatabaseWrapper(BaseDatabaseWrapper):
"""

def __getattr__(self, name):
-override = (name in self._adapter._available_ and
-name in self._adapter._parse_replacements_)
+override = (
+name in self._adapter._available_
+and name in self._adapter._parse_replacements_
+)

if override:
return self._adapter._parse_replacements_[name]
@@ -420,9 +414,7 @@ class RuntimeDatabaseWrapper(BaseDatabaseWrapper):

# `ref` implementations
class ParseRefResolver(BaseRefResolver):
-def resolve(
-self, name: str, package: Optional[str] = None
-) -> RelationProxy:
+def resolve(self, name: str, package: Optional[str] = None) -> RelationProxy:
self.model.refs.append(self._repack_args(name, package))

return self.Relation.create_from(self.config, self.model)
@@ -452,22 +444,15 @@ class RuntimeRefResolver(BaseRefResolver):
self.validate(target_model, target_name, target_package)
return self.create_relation(target_model, target_name)

-def create_relation(
-self, target_model: ManifestNode, name: str
-) -> RelationProxy:
+def create_relation(self, target_model: ManifestNode, name: str) -> RelationProxy:
if target_model.is_ephemeral_model:
self.model.set_cte(target_model.unique_id, None)
-return self.Relation.create_ephemeral_from_node(
-self.config, target_model
-)
+return self.Relation.create_ephemeral_from_node(self.config, target_model)
else:
return self.Relation.create_from(self.config, target_model)

def validate(
-self,
-resolved: ManifestNode,
-target_name: str,
-target_package: Optional[str]
+self, resolved: ManifestNode, target_name: str, target_package: Optional[str]
) -> None:
if resolved.unique_id not in self.model.depends_on.nodes:
args = self._repack_args(target_name, target_package)
@@ -483,16 +468,15 @@ class OperationRefResolver(RuntimeRefResolver):
) -> None:
pass

-def create_relation(
-self, target_model: ManifestNode, name: str
-) -> RelationProxy:
+def create_relation(self, target_model: ManifestNode, name: str) -> RelationProxy:
if target_model.is_ephemeral_model:
# In operations, we can't ref() ephemeral nodes, because
# ParsedMacros do not support set_cte
raise_compiler_error(
-'Operations can not ref() ephemeral nodes, but {} is ephemeral'
-.format(target_model.name),
-self.model
+"Operations can not ref() ephemeral nodes, but {} is ephemeral".format(
+target_model.name
+),
+self.model,
)
else:
return super().create_relation(target_model, name)
@@ -544,8 +528,7 @@ class ModelConfiguredVar(Var):
if package_name not in dependencies:
# I don't think this is actually reachable
raise_compiler_error(
-f'Node package named {package_name} not found!',
-self._node
+f"Node package named {package_name} not found!", self._node
)
yield dependencies[package_name]
yield self._config
@@ -617,7 +600,7 @@ class OperationProvider(RuntimeProvider):
ref = OperationRefResolver


-T = TypeVar('T')
+T = TypeVar("T")


# Base context collection, used for parsing configs.
@@ -631,9 +614,7 @@ class ProviderContext(ManifestContext):
context_config: Optional[ContextConfig],
) -> None:
if provider is None:
-raise InternalException(
-f"Invalid provider given to context: {provider}"
-)
+raise InternalException(f"Invalid provider given to context: {provider}")
# mypy appeasement - we know it'll be a RuntimeConfig
self.config: RuntimeConfig
self.model: Union[ParsedMacro, ManifestNode] = model
@@ -643,16 +624,12 @@ class ProviderContext(ManifestContext):
self.provider: Provider = provider
self.adapter = get_adapter(self.config)
# The macro namespace is used in creating the DatabaseWrapper
-self.db_wrapper = self.provider.DatabaseWrapper(
-self.adapter, self.namespace
-)
+self.db_wrapper = self.provider.DatabaseWrapper(self.adapter, self.namespace)

# This overrides the method in ManifestContext, and provides
# a model, which the ManifestContext builder does not
def _get_namespace_builder(self):
-internal_packages = get_adapter_package_names(
-self.config.credentials.type
-)
+internal_packages = get_adapter_package_names(self.config.credentials.type)
return MacroNamespaceBuilder(
self.config.project_name,
self.search_package,
@@ -671,19 +648,19 @@ class ProviderContext(ManifestContext):

@contextmember
def store_result(
-self, name: str,
-response: Any,
-agate_table: Optional[agate.Table] = None
+self, name: str, response: Any, agate_table: Optional[agate.Table] = None
) -> str:
if agate_table is None:
agate_table = agate_helper.empty_table()

-self.sql_results[name] = AttrDict({
-'response': response,
-'data': agate_helper.as_matrix(agate_table),
-'table': agate_table
-})
-return ''
+self.sql_results[name] = AttrDict(
+{
+"response": response,
+"data": agate_helper.as_matrix(agate_table),
+"table": agate_table,
+}
+)
+return ""

@contextmember
def store_raw_result(
@@ -692,10 +669,11 @@ class ProviderContext(ManifestContext):
message=Optional[str],
code=Optional[str],
rows_affected=Optional[str],
-agate_table: Optional[agate.Table] = None
+agate_table: Optional[agate.Table] = None,
) -> str:
response = AdapterResponse(
-_message=message, code=code, rows_affected=rows_affected)
+_message=message, code=code, rows_affected=rows_affected
+)
return self.store_result(name, response, agate_table)

@contextproperty
@@ -708,25 +686,28 @@ class ProviderContext(ManifestContext):
elif value == arg:
return
raise ValidationException(
-'Expected value "{}" to be one of {}'
-.format(value, ','.join(map(str, args))))
+'Expected value "{}" to be one of {}'.format(
+value, ",".join(map(str, args))
+)
+)

return inner

-return AttrDict({
-'any': validate_any,
-})
+return AttrDict(
+{
+"any": validate_any,
+}
+)

@contextmember
def write(self, payload: str) -> str:
# macros/source defs aren't 'writeable'.
if isinstance(self.model, (ParsedMacro, ParsedSourceDefinition)):
-raise_compiler_error(
-'cannot "write" macros or sources'
-)
+raise_compiler_error('cannot "write" macros or sources')
self.model.build_path = self.model.write_node(
-self.config.target_path, 'run', payload
+self.config.target_path, "run", payload
)
-return ''
+return ""

@contextmember
def render(self, string: str) -> str:
@@ -739,20 +720,17 @@ class ProviderContext(ManifestContext):
try:
return func(*args, **kwargs)
except Exception:
-raise_compiler_error(
-message_if_exception, self.model
-)
+raise_compiler_error(message_if_exception, self.model)

@contextmember
def load_agate_table(self) -> agate.Table:
if not isinstance(self.model, (ParsedSeedNode, CompiledSeedNode)):
raise_compiler_error(
-'can only load_agate_table for seeds (got a {})'
-.format(self.model.resource_type)
+"can only load_agate_table for seeds (got a {})".format(
+self.model.resource_type
)
-path = os.path.join(
-self.model.root_path, self.model.original_file_path
)
+path = os.path.join(self.model.root_path, self.model.original_file_path)
column_types = self.model.config.column_types
try:
table = agate_helper.from_csv(path, text_columns=column_types)
@@ -810,7 +788,7 @@ class ProviderContext(ManifestContext):
self.db_wrapper, self.model, self.config, self.manifest
)

-@contextproperty('config')
+@contextproperty("config")
def ctx_config(self) -> Config:
"""The `config` variable exists to handle end-user configuration for
custom materializations. Configs like `unique_key` can be implemented
@@ -982,7 +960,7 @@ class ProviderContext(ManifestContext):
node=self.model,
)

-@contextproperty('adapter')
+@contextproperty("adapter")
def ctx_adapter(self) -> BaseDatabaseWrapper:
"""`adapter` is a wrapper around the internal database adapter used by
dbt. It allows users to make calls to the database in their dbt models.
@@ -994,8 +972,8 @@ class ProviderContext(ManifestContext):
@contextproperty
def api(self) -> Dict[str, Any]:
return {
-'Relation': self.db_wrapper.Relation,
-'Column': self.adapter.Column,
+"Relation": self.db_wrapper.Relation,
+"Column": self.adapter.Column,
}

@contextproperty
@@ -1113,7 +1091,7 @@ class ProviderContext(ManifestContext):
""" # noqa
return self.manifest.flat_graph

-@contextproperty('model')
+@contextproperty("model")
def ctx_model(self) -> Dict[str, Any]:
return self.model.to_dict(omit_none=True)

@@ -1177,22 +1155,20 @@ class ProviderContext(ManifestContext):
...
{%- endmacro %}
"""
-deprecations.warn('adapter-macro', macro_name=name)
+deprecations.warn("adapter-macro", macro_name=name)
original_name = name
package_names: Optional[List[str]] = None
-if '.' in name:
-package_name, name = name.split('.', 1)
+if "." in name:
+package_name, name = name.split(".", 1)
package_names = [package_name]

try:
-macro = self.db_wrapper.dispatch(
-macro_name=name, packages=package_names
-)
+macro = self.db_wrapper.dispatch(macro_name=name, packages=package_names)
except CompilationException as exc:
raise CompilationException(
-f'In adapter_macro: {exc.msg}\n'
+f"In adapter_macro: {exc.msg}\n"
f" Original name: '{original_name}'",
-node=self.model
+node=self.model,
) from exc
return macro(*args, **kwargs)

@@ -1230,35 +1206,27 @@ class ModelContext(ProviderContext):
def pre_hooks(self) -> List[Dict[str, Any]]:
if isinstance(self.model, ParsedSourceDefinition):
return []
-return [
-h.to_dict(omit_none=True) for h in self.model.config.pre_hook
-]
+return [h.to_dict(omit_none=True) for h in self.model.config.pre_hook]

@contextproperty
def post_hooks(self) -> List[Dict[str, Any]]:
if isinstance(self.model, ParsedSourceDefinition):
return []
-return [
-h.to_dict(omit_none=True) for h in self.model.config.post_hook
-]
+return [h.to_dict(omit_none=True) for h in self.model.config.post_hook]

@contextproperty
def sql(self) -> Optional[str]:
-if getattr(self.model, 'extra_ctes_injected', None):
+if getattr(self.model, "extra_ctes_injected", None):
return self.model.compiled_sql
return None

@contextproperty
def database(self) -> str:
-return getattr(
-self.model, 'database', self.config.credentials.database
-)
+return getattr(self.model, "database", self.config.credentials.database)

@contextproperty
def schema(self) -> str:
-return getattr(
-self.model, 'schema', self.config.credentials.schema
-)
+return getattr(self.model, "schema", self.config.credentials.schema)

@contextproperty
def this(self) -> Optional[RelationProxy]:
@@ -1306,9 +1274,7 @@ def generate_parser_model(
# The __init__ method of ModelContext also initializes
# a ManifestContext object which creates a MacroNamespaceBuilder
# which adds every macro in the Manifest.
-ctx = ModelContext(
-model, config, manifest, ParseProvider(), context_config
-)
+ctx = ModelContext(model, config, manifest, ParseProvider(), context_config)
# The 'to_dict' method in ManifestContext moves all of the macro names
# in the macro 'namespace' up to top level keys
return ctx.to_dict()
@@ -1319,9 +1285,7 @@ def generate_generate_component_name_macro(
config: RuntimeConfig,
manifest: MacroManifest,
) -> Dict[str, Any]:
-ctx = MacroContext(
-macro, config, manifest, GenerateNameProvider(), None
-)
+ctx = MacroContext(macro, config, manifest, GenerateNameProvider(), None)
return ctx.to_dict()


@@ -1330,9 +1294,7 @@ def generate_runtime_model(
config: RuntimeConfig,
manifest: Manifest,
) -> Dict[str, Any]:
-ctx = ModelContext(
-model, config, manifest, RuntimeProvider(), None
-)
+ctx = ModelContext(model, config, manifest, RuntimeProvider(), None)
return ctx.to_dict()


@@ -1342,9 +1304,7 @@ def generate_runtime_macro(
manifest: Manifest,
package_name: Optional[str],
) -> Dict[str, Any]:
-ctx = MacroContext(
-macro, config, manifest, OperationProvider(), package_name
-)
+ctx = MacroContext(macro, config, manifest, OperationProvider(), package_name)
return ctx.to_dict()


@@ -1353,18 +1313,17 @@ class ExposureRefResolver(BaseResolver):
if len(args) not in (1, 2):
ref_invalid_args(self.model, args)
self.model.refs.append(list(args))
-return ''
+return ""


class ExposureSourceResolver(BaseResolver):
def __call__(self, *args) -> str:
if len(args) != 2:
raise_compiler_error(
-f"source() takes exactly two arguments ({len(args)} given)",
-self.model
+f"source() takes exactly two arguments ({len(args)} given)", self.model
)
self.model.sources.append(list(args))
-return ''
+return ""


def generate_parse_exposure(
@@ -1375,18 +1334,18 @@ def generate_parse_exposure(
) -> Dict[str, Any]:
project = config.load_dependencies()[package_name]
return {
-'ref': ExposureRefResolver(
+"ref": ExposureRefResolver(
None,
exposure,
project,
manifest,
),
-'source': ExposureSourceResolver(
+"source": ExposureSourceResolver(
None,
exposure,
project,
manifest,
-)
+),
}


@@ -1422,8 +1381,7 @@ class TestContext(ProviderContext):
if self.model.depends_on and self.model.depends_on.macros:
depends_on_macros = self.model.depends_on.macros
macro_namespace = TestMacroNamespace(
-self.macro_resolver, self.ctx, self.node, self.thread_ctx,
-depends_on_macros
+self.macro_resolver, self.ctx, self.node, self.thread_ctx, depends_on_macros
)
self._namespace = macro_namespace

@@ -1433,11 +1391,10 @@ def generate_test_context(
config: RuntimeConfig,
manifest: Manifest,
context_config: ContextConfig,
-macro_resolver: MacroResolver
+macro_resolver: MacroResolver,
) -> Dict[str, Any]:
ctx = TestContext(
-model, config, manifest, ParseProvider(), context_config,
-macro_resolver
+model, config, manifest, ParseProvider(), context_config, macro_resolver
)
# The 'to_dict' method in ManifestContext moves all of the macro names
# in the macro 'namespace' up to top level keys
@@ -2,9 +2,7 @@ from typing import Any, Dict

from dbt.contracts.connection import HasCredentials

-from dbt.context.base import (
-BaseContext, contextproperty
-)
+from dbt.context import BaseContext, contextproperty


class TargetContext(BaseContext):
@@ -2,25 +2,35 @@ import abc
import itertools
from dataclasses import dataclass, field
from typing import (
-Any, ClassVar, Dict, Tuple, Iterable, Optional, List, Callable,
+Any,
+ClassVar,
+Dict,
+Tuple,
+Iterable,
+Optional,
+List,
+Callable,
)
from dbt.exceptions import InternalException
from dbt.utils import translate_aliases
from dbt.logger import GLOBAL_LOGGER as logger
from typing_extensions import Protocol
from dbt.dataclass_schema import (
-dbtClassMixin, StrEnum, ExtensibleDbtClassMixin,
-ValidatedStringMixin, register_pattern
+dbtClassMixin,
+StrEnum,
+ExtensibleDbtClassMixin,
+ValidatedStringMixin,
+register_pattern,
)
from dbt.contracts.util import Replaceable


class Identifier(ValidatedStringMixin):
-ValidationRegex = r'^[A-Za-z_][A-Za-z0-9_]+$'
+ValidationRegex = r"^[A-Za-z_][A-Za-z0-9_]+$"


# we need register_pattern for jsonschema validation
-register_pattern(Identifier, r'^[A-Za-z_][A-Za-z0-9_]+$')
+register_pattern(Identifier, r"^[A-Za-z_][A-Za-z0-9_]+$")


@dataclass
@@ -34,10 +44,10 @@ class AdapterResponse(dbtClassMixin):


class ConnectionState(StrEnum):
-INIT = 'init'
-OPEN = 'open'
-CLOSED = 'closed'
-FAIL = 'fail'
+INIT = "init"
+OPEN = "open"
+CLOSED = "closed"
+FAIL = "fail"


@dataclass(init=False)
@@ -81,8 +91,7 @@ class Connection(ExtensibleDbtClassMixin, Replaceable):
self._handle.resolve(self)
except RecursionError as exc:
raise InternalException(
-"A connection's open() method attempted to read the "
-"handle value"
+"A connection's open() method attempted to read the " "handle value"
) from exc
return self._handle

@@ -101,8 +110,7 @@ class LazyHandle:

def resolve(self, connection: Connection) -> Connection:
logger.debug(
-'Opening a new connection, currently in state {}'
-.format(connection.state)
+"Opening a new connection, currently in state {}".format(connection.state)
)
return self.opener(connection)

@@ -112,33 +120,24 @@ class LazyHandle:
# for why we have type: ignore. Maybe someday dataclasses + abstract classes
# will work.
@dataclass # type: ignore
-class Credentials(
-ExtensibleDbtClassMixin,
-Replaceable,
-metaclass=abc.ABCMeta
-):
+class Credentials(ExtensibleDbtClassMixin, Replaceable, metaclass=abc.ABCMeta):
database: str
schema: str
_ALIASES: ClassVar[Dict[str, str]] = field(default={}, init=False)

@abc.abstractproperty
def type(self) -> str:
-raise NotImplementedError(
-'type not implemented for base credentials class'
-)
+raise NotImplementedError("type not implemented for base credentials class")

def connection_info(
self, *, with_aliases: bool = False
) -> Iterable[Tuple[str, Any]]:
-"""Return an ordered iterator of key/value pairs for pretty-printing.
-"""
+"""Return an ordered iterator of key/value pairs for pretty-printing."""
as_dict = self.to_dict(omit_none=False)
connection_keys = set(self._connection_keys())
aliases: List[str] = []
if with_aliases:
-aliases = [
-k for k, v in self._ALIASES.items() if v in connection_keys
-]
+aliases = [k for k, v in self._ALIASES.items() if v in connection_keys]
for key in itertools.chain(self._connection_keys(), aliases):
if key in as_dict:
yield key, as_dict[key]
@@ -162,11 +161,13 @@ class Credentials(
def __post_serialize__(self, dct):
# no super() -- do we need it?
if self._ALIASES:
-dct.update({
+dct.update(
+{
new_name: dct[canonical_name]
for new_name, canonical_name in self._ALIASES.items()
if canonical_name in dct
-})
+}
+)
return dct


@@ -188,10 +189,10 @@ class HasCredentials(Protocol):
threads: int

def to_target_dict(self):
-raise NotImplementedError('to_target_dict not implemented')
+raise NotImplementedError("to_target_dict not implemented")


-DEFAULT_QUERY_COMMENT = '''
+DEFAULT_QUERY_COMMENT = """
{%- set comment_dict = {} -%}
{%- do comment_dict.update(
app='dbt',
@@ -208,7 +209,7 @@ DEFAULT_QUERY_COMMENT = '''
{%- do comment_dict.update(connection_name=connection_name) -%}
{%- endif -%}
{{ return(tojson(comment_dict)) }}
-'''
+"""


@dataclass
@@ -11,7 +11,7 @@ from .util import MacroKey, SourceKey


MAXIMUM_SEED_SIZE = 1 * 1024 * 1024
-MAXIMUM_SEED_SIZE_NAME = '1MB'
+MAXIMUM_SEED_SIZE_NAME = "1MB"


@dataclass
@@ -28,9 +28,7 @@ class FilePath(dbtClassMixin):
@property
def full_path(self) -> str:
# useful for symlink preservation
-return os.path.join(
-self.project_root, self.searched_path, self.relative_path
-)
+return os.path.join(self.project_root, self.searched_path, self.relative_path)

@property
def absolute_path(self) -> str:
@@ -40,13 +38,10 @@ class FilePath(dbtClassMixin):
def original_file_path(self) -> str:
# this is mostly used for reporting errors. It doesn't show the project
# name, should it?
-return os.path.join(
-self.searched_path, self.relative_path
-)
+return os.path.join(self.searched_path, self.relative_path)

def seed_too_large(self) -> bool:
-"""Return whether the file this represents is over the seed size limit
-"""
+"""Return whether the file this represents is over the seed size limit"""
return os.stat(self.full_path).st_size > MAXIMUM_SEED_SIZE


@@ -57,35 +52,35 @@ class FileHash(dbtClassMixin):

@classmethod
def empty(cls):
-return FileHash(name='none', checksum='')
+return FileHash(name="none", checksum="")

@classmethod
def path(cls, path: str):
-return FileHash(name='path', checksum=path)
+return FileHash(name="path", checksum=path)

def __eq__(self, other):
if not isinstance(other, FileHash):
return NotImplemented

-if self.name == 'none' or self.name != other.name:
+if self.name == "none" or self.name != other.name:
return False

return self.checksum == other.checksum

def compare(self, contents: str) -> bool:
"""Compare the file contents with the given hash"""
-if self.name == 'none':
+if self.name == "none":
return False

return self.from_contents(contents, name=self.name) == self.checksum

@classmethod
-def from_contents(cls, contents: str, name='sha256') -> 'FileHash':
+def from_contents(cls, contents: str, name="sha256") -> "FileHash":
"""Create a file hash from the given file contents. The hash is always
the utf-8 encoding of the contents given, because dbt only reads files
as utf-8.
"""
-data = contents.encode('utf-8')
+data = contents.encode("utf-8")
checksum = hashlib.new(name, data).hexdigest()
return cls(name=name, checksum=checksum)

@@ -94,24 +89,25 @@ class FileHash(dbtClassMixin):
class RemoteFile(dbtClassMixin):
@property
def searched_path(self) -> str:
-return 'from remote system'
+return "from remote system"

@property
def relative_path(self) -> str:
-return 'from remote system'
+return "from remote system"

@property
def absolute_path(self) -> str:
-return 'from remote system'
+return "from remote system"

@property
def original_file_path(self):
-return 'from remote system'
+return "from remote system"


@dataclass
class SourceFile(dbtClassMixin):
"""Define a source file in dbt"""

path: Union[FilePath, RemoteFile] # the path information
checksum: FileHash
# we don't want to serialize this
@@ -133,14 +129,14 @@ class SourceFile(dbtClassMixin):
def search_key(self) -> Optional[str]:
if isinstance(self.path, RemoteFile):
return None
-if self.checksum.name == 'none':
+if self.checksum.name == "none":
return None
return self.path.search_key

@property
def contents(self) -> str:
if self._contents is None:
-raise InternalException('SourceFile has no contents!')
+raise InternalException("SourceFile has no contents!")
return self._contents

@contents.setter
@@ -148,20 +144,20 @@ class SourceFile(dbtClassMixin):
self._contents = value

@classmethod
-def empty(cls, path: FilePath) -> 'SourceFile':
+def empty(cls, path: FilePath) -> "SourceFile":
self = cls(path=path, checksum=FileHash.empty())
-self.contents = ''
+self.contents = ""
return self

@classmethod
-def big_seed(cls, path: FilePath) -> 'SourceFile':
+def big_seed(cls, path: FilePath) -> "SourceFile":
"""Parse seeds over the size limit with just the path"""
self = cls(path=path, checksum=FileHash.path(path.original_file_path))
-self.contents = ''
+self.contents = ""
return self

@classmethod
-def remote(cls, contents: str) -> 'SourceFile':
+def remote(cls, contents: str) -> "SourceFile":
self = cls(path=RemoteFile(), checksum=FileHash.empty())
self.contents = contents
return self
@@ -58,31 +58,29 @@ class CompiledNode(ParsedNode, CompiledNodeMixin):
@@ -96,26 +94,25 @@ class CompiledSeedNode(CompiledNode):
@@ -125,11 +122,7 @@ class CompiledSchemaTestNode(CompiledNode, HasTestMetadata):
@@ -175,8 +168,7 @@ def parsed_instance_for(compiled: CompiledNode) -> ParsedResource:
[hunk bodies: formatting-only changes — the resource_type field metadata strings in the Compiled* node classes re-quoted to double quotes, and the multi-line returns in CompiledSchemaTestNode.same_config and the ValueError in parsed_instance_for collapsed onto single lines]
@@ -4,25 +4,51 @@ from dataclasses import dataclass, field
@@ -40,12 +66,12 @@ RefName = str
@@ -95,12 +121,10 @@ class DocCache(PackageAwareCache[DocName, ParsedDocumentation]):
@@ -117,12 +141,10 @@ class SourceCache(PackageAwareCache[SourceKey, ParsedSourceDefinition]):
@@ -131,7 +153,7 @@ class RefableCache(PackageAwareCache[RefName, ManifestNode]):
@@ -145,12 +167,10 @@ class RefableCache(PackageAwareCache[RefName, ManifestNode]):
@@ -171,30 +191,31 @@ def _search_packages(
@@ -205,9 +226,7 @@ class ManifestMetadata(BaseArtifactMetadata):
@@ -281,7 +300,7 @@ class MaterializationCandidate(MacroCandidate):
@@ -292,15 +311,14 @@ class MaterializationCandidate(MacroCandidate):
@@ -319,7 +337,7 @@ class MaterializationCandidate(MacroCandidate):
@@ -347,10 +365,10 @@ class Searchable(Protocol):
@@ -382,7 +400,7 @@ class NameSearcher(Generic[N]):
@@ -393,19 +411,18 @@ class Disabled(Generic[D]):
@@ -416,14 +433,13 @@ def _update_into(dest: MutableMapping[str, T], new_item: T):
@@ -447,6 +463,7 @@ class MacroMethods:
@@ -469,11 +486,12 @@ class MacroMethods:
@@ -484,12 +502,12 @@ class MacroMethods:
@@ -507,8 +525,8 @@ class MacroMethods:
@@ -541,7 +559,7 @@ class Manifest(MacroMethods):
@@ -563,39 +581,30 @@ class Manifest(MacroMethods):
@@ -618,13 +627,16 @@ class Manifest(MacroMethods):
@@ -648,9 +660,7 @@ class Manifest(MacroMethods):
@@ -662,12 +672,10 @@ class Manifest(MacroMethods):
@@ -684,15 +692,15 @@ class Manifest(MacroMethods):
@@ -701,22 +709,25 @@ class Manifest(MacroMethods):
@@ -733,11 +744,13 @@ class Manifest(MacroMethods):
@@ -771,7 +784,7 @@ class Manifest(MacroMethods):
@@ -820,9 +833,7 @@ class Manifest(MacroMethods):
@@ -833,7 +844,7 @@ class Manifest(MacroMethods):
@@ -866,9 +877,7 @@ class Manifest(MacroMethods):
@@ -879,7 +888,7 @@ class Manifest(MacroMethods):
@@ -892,10 +901,10 @@ class Manifest(MacroMethods):
@@ -904,9 +913,7 @@ class Manifest(MacroMethods):
@@ -948,47 +955,53 @@ AnyManifest = Union[Manifest, MacroManifest]
[hunk bodies: formatting-only changes — the typing, compiled, parsed, util, and exceptions import lists expanded one name per line, single-quoted strings re-quoted to double quotes, and wrapped signatures and field definitions (perform_lookup, find_disabled_by_name, patch_macros, patch_nodes, the WritableManifest fields) re-flowed]
@@ -2,11 +2,20 @@ from dataclasses import field, Field, dataclass
@@ -15,7 +24,7 @@ from dbt import hooks
@@ -30,9 +39,7 @@ def _get_meta_value(cls: Type[M], fld: Field, key: str, default: Any) -> M:
@@ -54,19 +61,17 @@ class Metadata(Enum):
@@ -75,12 +80,12 @@ class MergeBehavior(Metadata):
@@ -88,12 +93,12 @@ class ShowBehavior(Metadata):
@@ -105,12 +110,12 @@ class CompareBehavior(Metadata):
@@ -142,32 +147,30 @@ def _merge_field_value(
@@ -177,13 +180,11 @@ class Hook(dbtClassMixin, Replaceable):
@@ -204,8 +205,7 @@ class BaseConfig(
@@ -245,9 +245,7 @@ class BaseConfig(
@@ -265,9 +263,7 @@ class BaseConfig(
@@ -307,6 +303,7 @@ class BaseConfig(
@@ -348,7 +345,7 @@ class SourceConfig(BaseConfig):
@@ -389,16 +386,16 @@ class NodeConfig(BaseConfig):
@@ -416,7 +413,7 @@ class NodeConfig(BaseConfig):
@@ -425,24 +422,24 @@ class NodeConfig(BaseConfig):
@@ -457,25 +454,28 @@ class SnapshotConfig(EmptySnapshotConfig):
@@ -497,9 +497,7 @@ RESOURCE_TYPES: Dict[NodeType, Type[BaseConfig]] = {
[hunk bodies: formatting-only changes — the typing and dataclass_schema import lists expanded one name per line, strings re-quoted to double quotes throughout Metadata, MergeBehavior, ShowBehavior, CompareBehavior, NodeConfig, SeedConfig, TestConfig, EmptySnapshotConfig, and SnapshotConfig.validate, and long error strings, signatures, and field definitions re-wrapped]
@@ -13,18 +13,27 @@ from typing import (
     TypeVar,
 )

-from dbt.dataclass_schema import (
-    dbtClassMixin, ExtensibleDbtClassMixin
-)
+from dbt.dataclass_schema import dbtClassMixin, ExtensibleDbtClassMixin

 from dbt.clients.system import write_file
 from dbt.contracts.files import FileHash, MAXIMUM_SEED_SIZE_NAME
 from dbt.contracts.graph.unparsed import (
-    UnparsedNode, UnparsedDocumentation, Quoting, Docs,
-    UnparsedBaseNode, FreshnessThreshold, ExternalTable,
-    HasYamlMetadata, MacroArgument, UnparsedSourceDefinition,
-    UnparsedSourceTableDefinition, UnparsedColumn, TestDef,
-    ExposureOwner, ExposureType, MaturityType
+    UnparsedNode,
+    UnparsedDocumentation,
+    Quoting,
+    Docs,
+    UnparsedBaseNode,
+    FreshnessThreshold,
+    ExternalTable,
+    HasYamlMetadata,
+    MacroArgument,
+    UnparsedSourceDefinition,
+    UnparsedSourceTableDefinition,
+    UnparsedColumn,
+    TestDef,
+    ExposureOwner,
+    ExposureType,
+    MaturityType,
 )
 from dbt.contracts.util import Replaceable, AdditionalPropertiesMixin
 from dbt.exceptions import warn_or_error
@@ -44,13 +53,9 @@ from .model_config import (


 @dataclass
-class ColumnInfo(
-    AdditionalPropertiesMixin,
-    ExtensibleDbtClassMixin,
-    Replaceable
-):
+class ColumnInfo(AdditionalPropertiesMixin, ExtensibleDbtClassMixin, Replaceable):
     name: str
-    description: str = ''
+    description: str = ""
     meta: Dict[str, Any] = field(default_factory=dict)
     data_type: Optional[str] = None
     quote: Optional[bool] = None
@@ -62,7 +67,7 @@ class ColumnInfo(
|
|||||||
class HasFqn(dbtClassMixin, Replaceable):
|
class HasFqn(dbtClassMixin, Replaceable):
|
||||||
fqn: List[str]
|
fqn: List[str]
|
||||||
|
|
||||||
def same_fqn(self, other: 'HasFqn') -> bool:
|
def same_fqn(self, other: "HasFqn") -> bool:
|
||||||
return self.fqn == other.fqn
|
return self.fqn == other.fqn
|
||||||
|
|
||||||
|
|
||||||
@@ -101,8 +106,8 @@ class HasRelationMetadata(dbtClassMixin, Replaceable):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def __pre_deserialize__(cls, data):
|
def __pre_deserialize__(cls, data):
|
||||||
data = super().__pre_deserialize__(data)
|
data = super().__pre_deserialize__(data)
|
||||||
if 'database' not in data:
|
if "database" not in data:
|
||||||
data['database'] = None
|
data["database"] = None
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
|
||||||
@@ -117,7 +122,7 @@ class ParsedNodeMixins(dbtClassMixin):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def is_ephemeral(self):
|
def is_ephemeral(self):
|
||||||
return self.config.materialized == 'ephemeral'
|
return self.config.materialized == "ephemeral"
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_ephemeral_model(self):
|
def is_ephemeral_model(self):
|
||||||
@@ -127,7 +132,7 @@ class ParsedNodeMixins(dbtClassMixin):
|
|||||||
def depends_on_nodes(self):
|
def depends_on_nodes(self):
|
||||||
return self.depends_on.nodes
|
return self.depends_on.nodes
|
||||||
|
|
||||||
def patch(self, patch: 'ParsedNodePatch'):
|
def patch(self, patch: "ParsedNodePatch"):
|
||||||
"""Given a ParsedNodePatch, add the new information to the node."""
|
"""Given a ParsedNodePatch, add the new information to the node."""
|
||||||
# explicitly pick out the parts to update so we don't inadvertently
|
# explicitly pick out the parts to update so we don't inadvertently
|
||||||
# step on the model name or anything
|
# step on the model name or anything
|
||||||
@@ -153,11 +158,7 @@ class ParsedNodeMixins(dbtClassMixin):

 @dataclass
 class ParsedNodeMandatory(
-    UnparsedNode,
-    HasUniqueID,
-    HasFqn,
-    HasRelationMetadata,
-    Replaceable
+    UnparsedNode, HasUniqueID, HasFqn, HasRelationMetadata, Replaceable
 ):
     alias: str
     checksum: FileHash
@@ -174,7 +175,7 @@ class ParsedNodeDefaults(ParsedNodeMandatory):
|
|||||||
refs: List[List[str]] = field(default_factory=list)
|
refs: List[List[str]] = field(default_factory=list)
|
||||||
sources: List[List[Any]] = field(default_factory=list)
|
sources: List[List[Any]] = field(default_factory=list)
|
||||||
depends_on: DependsOn = field(default_factory=DependsOn)
|
depends_on: DependsOn = field(default_factory=DependsOn)
|
||||||
description: str = field(default='')
|
description: str = field(default="")
|
||||||
columns: Dict[str, ColumnInfo] = field(default_factory=dict)
|
columns: Dict[str, ColumnInfo] = field(default_factory=dict)
|
||||||
meta: Dict[str, Any] = field(default_factory=dict)
|
meta: Dict[str, Any] = field(default_factory=dict)
|
||||||
docs: Docs = field(default_factory=Docs)
|
docs: Docs = field(default_factory=Docs)
|
||||||
@@ -184,31 +185,28 @@ class ParsedNodeDefaults(ParsedNodeMandatory):
     unrendered_config: Dict[str, Any] = field(default_factory=dict)

     def write_node(self, target_path: str, subdirectory: str, payload: str):
-        if (os.path.basename(self.path) ==
-                os.path.basename(self.original_file_path)):
+        if os.path.basename(self.path) == os.path.basename(self.original_file_path):
             # One-to-one relationship of nodes to files.
             path = self.original_file_path
         else:
             # Many-to-one relationship of nodes to files.
             path = os.path.join(self.original_file_path, self.path)
-        full_path = os.path.join(
-            target_path, subdirectory, self.package_name, path
-        )
+        full_path = os.path.join(target_path, subdirectory, self.package_name, path)

         write_file(full_path, payload)
         return full_path


-T = TypeVar('T', bound='ParsedNode')
+T = TypeVar("T", bound="ParsedNode")


 @dataclass
 class ParsedNode(ParsedNodeDefaults, ParsedNodeMixins):
     def _persist_column_docs(self) -> bool:
-        return bool(self.config.persist_docs.get('columns'))
+        return bool(self.config.persist_docs.get("columns"))

     def _persist_relation_docs(self) -> bool:
-        return bool(self.config.persist_docs.get('relation'))
+        return bool(self.config.persist_docs.get("relation"))

     def same_body(self: T, other: T) -> bool:
         return self.raw_sql == other.raw_sql
@@ -223,9 +221,7 @@ class ParsedNode(ParsedNodeDefaults, ParsedNodeMixins):
|
|||||||
|
|
||||||
if self._persist_column_docs():
|
if self._persist_column_docs():
|
||||||
# assert other._persist_column_docs()
|
# assert other._persist_column_docs()
|
||||||
column_descriptions = {
|
column_descriptions = {k: v.description for k, v in self.columns.items()}
|
||||||
k: v.description for k, v in self.columns.items()
|
|
||||||
}
|
|
||||||
other_column_descriptions = {
|
other_column_descriptions = {
|
||||||
k: v.description for k, v in other.columns.items()
|
k: v.description for k, v in other.columns.items()
|
||||||
}
|
}
|
||||||
@@ -239,7 +235,7 @@ class ParsedNode(ParsedNodeDefaults, ParsedNodeMixins):
|
|||||||
# compares the configured value, rather than the ultimate value (so
|
# compares the configured value, rather than the ultimate value (so
|
||||||
# generate_*_name and unset values derived from the target are
|
# generate_*_name and unset values derived from the target are
|
||||||
# ignored)
|
# ignored)
|
||||||
keys = ('database', 'schema', 'alias')
|
keys = ("database", "schema", "alias")
|
||||||
for key in keys:
|
for key in keys:
|
||||||
mine = self.unrendered_config.get(key)
|
mine = self.unrendered_config.get(key)
|
||||||
others = other.unrendered_config.get(key)
|
others = other.unrendered_config.get(key)
|
||||||
@@ -258,36 +254,34 @@ class ParsedNode(ParsedNodeDefaults, ParsedNodeMixins):
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
return (
|
return (
|
||||||
self.same_body(old) and
|
self.same_body(old)
|
||||||
self.same_config(old) and
|
and self.same_config(old)
|
||||||
self.same_persisted_description(old) and
|
and self.same_persisted_description(old)
|
||||||
self.same_fqn(old) and
|
and self.same_fqn(old)
|
||||||
self.same_database_representation(old) and
|
and self.same_database_representation(old)
|
||||||
True
|
and True
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ParsedAnalysisNode(ParsedNode):
|
class ParsedAnalysisNode(ParsedNode):
|
||||||
resource_type: NodeType = field(metadata={'restrict': [NodeType.Analysis]})
|
resource_type: NodeType = field(metadata={"restrict": [NodeType.Analysis]})
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ParsedHookNode(ParsedNode):
|
class ParsedHookNode(ParsedNode):
|
||||||
resource_type: NodeType = field(
|
resource_type: NodeType = field(metadata={"restrict": [NodeType.Operation]})
|
||||||
metadata={'restrict': [NodeType.Operation]}
|
|
||||||
)
|
|
||||||
index: Optional[int] = None
|
index: Optional[int] = None
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ParsedModelNode(ParsedNode):
|
class ParsedModelNode(ParsedNode):
|
||||||
resource_type: NodeType = field(metadata={'restrict': [NodeType.Model]})
|
resource_type: NodeType = field(metadata={"restrict": [NodeType.Model]})
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ParsedRPCNode(ParsedNode):
|
class ParsedRPCNode(ParsedNode):
|
||||||
resource_type: NodeType = field(metadata={'restrict': [NodeType.RPCCall]})
|
resource_type: NodeType = field(metadata={"restrict": [NodeType.RPCCall]})
|
||||||
|
|
||||||
|
|
||||||
def same_seeds(first: ParsedNode, second: ParsedNode) -> bool:
|
def same_seeds(first: ParsedNode, second: ParsedNode) -> bool:
|
||||||
@@ -297,31 +291,31 @@ def same_seeds(first: ParsedNode, second: ParsedNode) -> bool:
|
|||||||
# if the current checksum is a path, we want to log a warning.
|
# if the current checksum is a path, we want to log a warning.
|
||||||
result = first.checksum == second.checksum
|
result = first.checksum == second.checksum
|
||||||
|
|
||||||
if first.checksum.name == 'path':
|
if first.checksum.name == "path":
|
||||||
msg: str
|
msg: str
|
||||||
if second.checksum.name != 'path':
|
if second.checksum.name != "path":
|
||||||
msg = (
|
msg = (
|
||||||
f'Found a seed ({first.package_name}.{first.name}) '
|
f"Found a seed ({first.package_name}.{first.name}) "
|
||||||
f'>{MAXIMUM_SEED_SIZE_NAME} in size. The previous file was '
|
f">{MAXIMUM_SEED_SIZE_NAME} in size. The previous file was "
|
||||||
f'<={MAXIMUM_SEED_SIZE_NAME}, so it has changed'
|
f"<={MAXIMUM_SEED_SIZE_NAME}, so it has changed"
|
||||||
)
|
)
|
||||||
elif result:
|
elif result:
|
||||||
msg = (
|
msg = (
|
||||||
f'Found a seed ({first.package_name}.{first.name}) '
|
f"Found a seed ({first.package_name}.{first.name}) "
|
||||||
f'>{MAXIMUM_SEED_SIZE_NAME} in size at the same path, dbt '
|
f">{MAXIMUM_SEED_SIZE_NAME} in size at the same path, dbt "
|
||||||
f'cannot tell if it has changed: assuming they are the same'
|
f"cannot tell if it has changed: assuming they are the same"
|
||||||
)
|
)
|
||||||
elif not result:
|
elif not result:
|
||||||
msg = (
|
msg = (
|
||||||
f'Found a seed ({first.package_name}.{first.name}) '
|
f"Found a seed ({first.package_name}.{first.name}) "
|
||||||
f'>{MAXIMUM_SEED_SIZE_NAME} in size. The previous file was in '
|
f">{MAXIMUM_SEED_SIZE_NAME} in size. The previous file was in "
|
||||||
f'a different location, assuming it has changed'
|
f"a different location, assuming it has changed"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
msg = (
|
msg = (
|
||||||
f'Found a seed ({first.package_name}.{first.name}) '
|
f"Found a seed ({first.package_name}.{first.name}) "
|
||||||
f'>{MAXIMUM_SEED_SIZE_NAME} in size. The previous file had a '
|
f">{MAXIMUM_SEED_SIZE_NAME} in size. The previous file had a "
|
||||||
f'checksum type of {second.checksum.name}, so it has changed'
|
f"checksum type of {second.checksum.name}, so it has changed"
|
||||||
)
|
)
|
||||||
warn_or_error(msg, node=first)
|
warn_or_error(msg, node=first)
|
||||||
|
|
||||||
@@ -331,7 +325,7 @@ def same_seeds(first: ParsedNode, second: ParsedNode) -> bool:
|
|||||||
@dataclass
|
@dataclass
|
||||||
class ParsedSeedNode(ParsedNode):
|
class ParsedSeedNode(ParsedNode):
|
||||||
# keep this in sync with CompiledSeedNode!
|
# keep this in sync with CompiledSeedNode!
|
||||||
resource_type: NodeType = field(metadata={'restrict': [NodeType.Seed]})
|
resource_type: NodeType = field(metadata={"restrict": [NodeType.Seed]})
|
||||||
config: SeedConfig = field(default_factory=SeedConfig)
|
config: SeedConfig = field(default_factory=SeedConfig)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@@ -357,21 +351,20 @@ class HasTestMetadata(dbtClassMixin):

 @dataclass
 class ParsedDataTestNode(ParsedNode):
-    resource_type: NodeType = field(metadata={'restrict': [NodeType.Test]})
+    resource_type: NodeType = field(metadata={"restrict": [NodeType.Test]})
     config: TestConfig = field(default_factory=TestConfig)


 @dataclass
 class ParsedSchemaTestNode(ParsedNode, HasTestMetadata):
     # keep this in sync with CompiledSchemaTestNode!
-    resource_type: NodeType = field(metadata={'restrict': [NodeType.Test]})
+    resource_type: NodeType = field(metadata={"restrict": [NodeType.Test]})
     column_name: Optional[str] = None
     config: TestConfig = field(default_factory=TestConfig)

     def same_config(self, other) -> bool:
-        return (
-            self.unrendered_config.get('severity') ==
-            other.unrendered_config.get('severity')
-        )
+        return self.unrendered_config.get("severity") == other.unrendered_config.get(
+            "severity"
+        )

     def same_column_name(self, other) -> bool:
@@ -381,11 +374,7 @@ class ParsedSchemaTestNode(ParsedNode, HasTestMetadata):
         if other is None:
             return False

-        return (
-            self.same_config(other) and
-            self.same_fqn(other) and
-            True
-        )
+        return self.same_config(other) and self.same_fqn(other) and True


 @dataclass
@@ -396,13 +385,13 @@ class IntermediateSnapshotNode(ParsedNode):
|
|||||||
# defined in config blocks. To fix that, we have an intermediate type that
|
# defined in config blocks. To fix that, we have an intermediate type that
|
||||||
# uses a regular node config, which the snapshot parser will then convert
|
# uses a regular node config, which the snapshot parser will then convert
|
||||||
# into a full ParsedSnapshotNode after rendering.
|
# into a full ParsedSnapshotNode after rendering.
|
||||||
resource_type: NodeType = field(metadata={'restrict': [NodeType.Snapshot]})
|
resource_type: NodeType = field(metadata={"restrict": [NodeType.Snapshot]})
|
||||||
config: EmptySnapshotConfig = field(default_factory=EmptySnapshotConfig)
|
config: EmptySnapshotConfig = field(default_factory=EmptySnapshotConfig)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ParsedSnapshotNode(ParsedNode):
|
class ParsedSnapshotNode(ParsedNode):
|
||||||
resource_type: NodeType = field(metadata={'restrict': [NodeType.Snapshot]})
|
resource_type: NodeType = field(metadata={"restrict": [NodeType.Snapshot]})
|
||||||
config: SnapshotConfig
|
config: SnapshotConfig
|
||||||
|
|
||||||
|
|
||||||
@@ -431,12 +420,12 @@ class ParsedMacroPatch(ParsedPatch):
|
|||||||
class ParsedMacro(UnparsedBaseNode, HasUniqueID):
|
class ParsedMacro(UnparsedBaseNode, HasUniqueID):
|
||||||
name: str
|
name: str
|
||||||
macro_sql: str
|
macro_sql: str
|
||||||
resource_type: NodeType = field(metadata={'restrict': [NodeType.Macro]})
|
resource_type: NodeType = field(metadata={"restrict": [NodeType.Macro]})
|
||||||
# TODO: can macros even have tags?
|
# TODO: can macros even have tags?
|
||||||
tags: List[str] = field(default_factory=list)
|
tags: List[str] = field(default_factory=list)
|
||||||
# TODO: is this ever populated?
|
# TODO: is this ever populated?
|
||||||
depends_on: MacroDependsOn = field(default_factory=MacroDependsOn)
|
depends_on: MacroDependsOn = field(default_factory=MacroDependsOn)
|
||||||
description: str = ''
|
description: str = ""
|
||||||
meta: Dict[str, Any] = field(default_factory=dict)
|
meta: Dict[str, Any] = field(default_factory=dict)
|
||||||
docs: Docs = field(default_factory=Docs)
|
docs: Docs = field(default_factory=Docs)
|
||||||
patch_path: Optional[str] = None
|
patch_path: Optional[str] = None
|
||||||
@@ -457,7 +446,7 @@ class ParsedMacro(UnparsedBaseNode, HasUniqueID):
|
|||||||
dct = self.to_dict(omit_none=False)
|
dct = self.to_dict(omit_none=False)
|
||||||
self.validate(dct)
|
self.validate(dct)
|
||||||
|
|
||||||
def same_contents(self, other: Optional['ParsedMacro']) -> bool:
|
def same_contents(self, other: Optional["ParsedMacro"]) -> bool:
|
||||||
if other is None:
|
if other is None:
|
||||||
return False
|
return False
|
||||||
# the only thing that makes one macro different from another with the
|
# the only thing that makes one macro different from another with the
|
||||||
@@ -474,7 +463,7 @@ class ParsedDocumentation(UnparsedDocumentation, HasUniqueID):
|
|||||||
def search_name(self):
|
def search_name(self):
|
||||||
return self.name
|
return self.name
|
||||||
|
|
||||||
def same_contents(self, other: Optional['ParsedDocumentation']) -> bool:
|
def same_contents(self, other: Optional["ParsedDocumentation"]) -> bool:
|
||||||
if other is None:
|
if other is None:
|
||||||
return False
|
return False
|
||||||
# the only thing that makes one doc different from another with the
|
# the only thing that makes one doc different from another with the
|
||||||
@@ -493,11 +482,11 @@ def normalize_test(testdef: TestDef) -> Dict[str, Any]:
|
|||||||
class UnpatchedSourceDefinition(UnparsedBaseNode, HasUniqueID, HasFqn):
|
class UnpatchedSourceDefinition(UnparsedBaseNode, HasUniqueID, HasFqn):
|
||||||
source: UnparsedSourceDefinition
|
source: UnparsedSourceDefinition
|
||||||
table: UnparsedSourceTableDefinition
|
table: UnparsedSourceTableDefinition
|
||||||
resource_type: NodeType = field(metadata={'restrict': [NodeType.Source]})
|
resource_type: NodeType = field(metadata={"restrict": [NodeType.Source]})
|
||||||
patch_path: Optional[Path] = None
|
patch_path: Optional[Path] = None
|
||||||
|
|
||||||
def get_full_source_name(self):
|
def get_full_source_name(self):
|
||||||
return f'{self.source.name}_{self.table.name}'
|
return f"{self.source.name}_{self.table.name}"
|
||||||
|
|
||||||
def get_source_representation(self):
|
def get_source_representation(self):
|
||||||
return f'source("{self.source.name}", "{self.table.name}")'
|
return f'source("{self.source.name}", "{self.table.name}")'
|
||||||
@@ -522,9 +511,7 @@ class UnpatchedSourceDefinition(UnparsedBaseNode, HasUniqueID, HasFqn):
         else:
             return self.table.columns

-    def get_tests(
-        self
-    ) -> Iterator[Tuple[Dict[str, Any], Optional[UnparsedColumn]]]:
+    def get_tests(self) -> Iterator[Tuple[Dict[str, Any], Optional[UnparsedColumn]]]:
         for test in self.tests:
             yield normalize_test(test), None

@@ -543,22 +530,19 @@ class UnpatchedSourceDefinition(UnparsedBaseNode, HasUniqueID, HasFqn):

 @dataclass
 class ParsedSourceDefinition(
-    UnparsedBaseNode,
-    HasUniqueID,
-    HasRelationMetadata,
-    HasFqn
+    UnparsedBaseNode, HasUniqueID, HasRelationMetadata, HasFqn
 ):
     name: str
     source_name: str
     source_description: str
     loader: str
     identifier: str
-    resource_type: NodeType = field(metadata={'restrict': [NodeType.Source]})
+    resource_type: NodeType = field(metadata={"restrict": [NodeType.Source]})
     quoting: Quoting = field(default_factory=Quoting)
     loaded_at_field: Optional[str] = None
     freshness: Optional[FreshnessThreshold] = None
     external: Optional[ExternalTable] = None
-    description: str = ''
+    description: str = ""
     columns: Dict[str, ColumnInfo] = field(default_factory=dict)
     meta: Dict[str, Any] = field(default_factory=dict)
     source_meta: Dict[str, Any] = field(default_factory=dict)
@@ -568,36 +552,34 @@ class ParsedSourceDefinition(
|
|||||||
unrendered_config: Dict[str, Any] = field(default_factory=dict)
|
unrendered_config: Dict[str, Any] = field(default_factory=dict)
|
||||||
relation_name: Optional[str] = None
|
relation_name: Optional[str] = None
|
||||||
|
|
||||||
def same_database_representation(
|
def same_database_representation(self, other: "ParsedSourceDefinition") -> bool:
|
||||||
self, other: 'ParsedSourceDefinition'
|
|
||||||
) -> bool:
|
|
||||||
return (
|
return (
|
||||||
self.database == other.database and
|
self.database == other.database
|
||||||
self.schema == other.schema and
|
and self.schema == other.schema
|
||||||
self.identifier == other.identifier and
|
and self.identifier == other.identifier
|
||||||
True
|
and True
|
||||||
)
|
)
|
||||||
|
|
||||||
def same_quoting(self, other: 'ParsedSourceDefinition') -> bool:
|
def same_quoting(self, other: "ParsedSourceDefinition") -> bool:
|
||||||
return self.quoting == other.quoting
|
return self.quoting == other.quoting
|
||||||
|
|
||||||
def same_freshness(self, other: 'ParsedSourceDefinition') -> bool:
|
def same_freshness(self, other: "ParsedSourceDefinition") -> bool:
|
||||||
return (
|
return (
|
||||||
self.freshness == other.freshness and
|
self.freshness == other.freshness
|
||||||
self.loaded_at_field == other.loaded_at_field and
|
and self.loaded_at_field == other.loaded_at_field
|
||||||
True
|
and True
|
||||||
)
|
)
|
||||||
|
|
||||||
def same_external(self, other: 'ParsedSourceDefinition') -> bool:
|
def same_external(self, other: "ParsedSourceDefinition") -> bool:
|
||||||
return self.external == other.external
|
return self.external == other.external
|
||||||
|
|
||||||
def same_config(self, old: 'ParsedSourceDefinition') -> bool:
|
def same_config(self, old: "ParsedSourceDefinition") -> bool:
|
||||||
return self.config.same_contents(
|
return self.config.same_contents(
|
||||||
self.unrendered_config,
|
self.unrendered_config,
|
||||||
old.unrendered_config,
|
old.unrendered_config,
|
||||||
)
|
)
|
||||||
|
|
||||||
def same_contents(self, old: Optional['ParsedSourceDefinition']) -> bool:
|
def same_contents(self, old: Optional["ParsedSourceDefinition"]) -> bool:
|
||||||
# existing when it didn't before is a change!
|
# existing when it didn't before is a change!
|
||||||
if old is None:
|
if old is None:
|
||||||
return True
|
return True
|
||||||
@@ -611,17 +593,17 @@ class ParsedSourceDefinition(
|
|||||||
# metadata/tags changes are not "changes"
|
# metadata/tags changes are not "changes"
|
||||||
# patching/description changes are not "changes"
|
# patching/description changes are not "changes"
|
||||||
return (
|
return (
|
||||||
self.same_database_representation(old) and
|
self.same_database_representation(old)
|
||||||
self.same_fqn(old) and
|
and self.same_fqn(old)
|
||||||
self.same_config(old) and
|
and self.same_config(old)
|
||||||
self.same_quoting(old) and
|
and self.same_quoting(old)
|
||||||
self.same_freshness(old) and
|
and self.same_freshness(old)
|
||||||
self.same_external(old) and
|
and self.same_external(old)
|
||||||
True
|
and True
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_full_source_name(self):
|
def get_full_source_name(self):
|
||||||
return f'{self.source_name}_{self.name}'
|
return f"{self.source_name}_{self.name}"
|
||||||
|
|
||||||
def get_source_representation(self):
|
def get_source_representation(self):
|
||||||
return f'source("{self.source.name}", "{self.table.name}")'
|
return f'source("{self.source.name}", "{self.table.name}")'
|
||||||
@@ -656,7 +638,7 @@ class ParsedSourceDefinition(
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def search_name(self):
|
def search_name(self):
|
||||||
return f'{self.source_name}.{self.name}'
|
return f"{self.source_name}.{self.name}"
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -665,7 +647,7 @@ class ParsedExposure(UnparsedBaseNode, HasUniqueID, HasFqn):
|
|||||||
type: ExposureType
|
type: ExposureType
|
||||||
owner: ExposureOwner
|
owner: ExposureOwner
|
||||||
resource_type: NodeType = NodeType.Exposure
|
resource_type: NodeType = NodeType.Exposure
|
||||||
description: str = ''
|
description: str = ""
|
||||||
maturity: Optional[MaturityType] = None
|
maturity: Optional[MaturityType] = None
|
||||||
url: Optional[str] = None
|
url: Optional[str] = None
|
||||||
depends_on: DependsOn = field(default_factory=DependsOn)
|
depends_on: DependsOn = field(default_factory=DependsOn)
|
||||||
@@ -685,38 +667,38 @@ class ParsedExposure(UnparsedBaseNode, HasUniqueID, HasFqn):
|
|||||||
def tags(self):
|
def tags(self):
|
||||||
return []
|
return []
|
||||||
|
|
||||||
def same_depends_on(self, old: 'ParsedExposure') -> bool:
|
def same_depends_on(self, old: "ParsedExposure") -> bool:
|
||||||
return set(self.depends_on.nodes) == set(old.depends_on.nodes)
|
return set(self.depends_on.nodes) == set(old.depends_on.nodes)
|
||||||
|
|
||||||
def same_description(self, old: 'ParsedExposure') -> bool:
|
def same_description(self, old: "ParsedExposure") -> bool:
|
||||||
return self.description == old.description
|
return self.description == old.description
|
||||||
|
|
||||||
def same_maturity(self, old: 'ParsedExposure') -> bool:
|
def same_maturity(self, old: "ParsedExposure") -> bool:
|
||||||
return self.maturity == old.maturity
|
return self.maturity == old.maturity
|
||||||
|
|
||||||
def same_owner(self, old: 'ParsedExposure') -> bool:
|
def same_owner(self, old: "ParsedExposure") -> bool:
|
||||||
return self.owner == old.owner
|
return self.owner == old.owner
|
||||||
|
|
||||||
def same_exposure_type(self, old: 'ParsedExposure') -> bool:
|
def same_exposure_type(self, old: "ParsedExposure") -> bool:
|
||||||
return self.type == old.type
|
return self.type == old.type
|
||||||
|
|
||||||
def same_url(self, old: 'ParsedExposure') -> bool:
|
def same_url(self, old: "ParsedExposure") -> bool:
|
||||||
return self.url == old.url
|
return self.url == old.url
|
||||||
|
|
||||||
def same_contents(self, old: Optional['ParsedExposure']) -> bool:
|
def same_contents(self, old: Optional["ParsedExposure"]) -> bool:
|
||||||
# existing when it didn't before is a change!
|
# existing when it didn't before is a change!
|
||||||
if old is None:
|
if old is None:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
return (
|
return (
|
||||||
self.same_fqn(old) and
|
self.same_fqn(old)
|
||||||
self.same_exposure_type(old) and
|
and self.same_exposure_type(old)
|
||||||
self.same_owner(old) and
|
and self.same_owner(old)
|
||||||
self.same_maturity(old) and
|
and self.same_maturity(old)
|
||||||
self.same_url(old) and
|
and self.same_url(old)
|
||||||
self.same_description(old) and
|
and self.same_description(old)
|
||||||
self.same_depends_on(old) and
|
and self.same_depends_on(old)
|
||||||
True
|
and True
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -4,13 +4,12 @@ from dbt.contracts.util (
     Mergeable,
     Replaceable,
 )

 # trigger the PathEncoder
 import dbt.helper_types # noqa:F401
 from dbt.exceptions import CompilationException

-from dbt.dataclass_schema import (
-    dbtClassMixin, StrEnum, ExtensibleDbtClassMixin
-)
+from dbt.dataclass_schema import dbtClassMixin, StrEnum, ExtensibleDbtClassMixin

 from dataclasses import dataclass, field
 from datetime import timedelta
@@ -37,13 +36,15 @@ class HasSQL:
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class UnparsedMacro(UnparsedBaseNode, HasSQL):
|
class UnparsedMacro(UnparsedBaseNode, HasSQL):
|
||||||
resource_type: NodeType = field(metadata={'restrict': [NodeType.Macro]})
|
resource_type: NodeType = field(metadata={"restrict": [NodeType.Macro]})
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class UnparsedNode(UnparsedBaseNode, HasSQL):
|
class UnparsedNode(UnparsedBaseNode, HasSQL):
|
||||||
name: str
|
name: str
|
||||||
resource_type: NodeType = field(metadata={'restrict': [
|
resource_type: NodeType = field(
|
||||||
|
metadata={
|
||||||
|
"restrict": [
|
||||||
NodeType.Model,
|
NodeType.Model,
|
||||||
NodeType.Analysis,
|
NodeType.Analysis,
|
||||||
NodeType.Test,
|
NodeType.Test,
|
||||||
@@ -51,7 +52,9 @@ class UnparsedNode(UnparsedBaseNode, HasSQL):
|
|||||||
NodeType.Operation,
|
NodeType.Operation,
|
||||||
NodeType.Seed,
|
NodeType.Seed,
|
||||||
NodeType.RPCCall,
|
NodeType.RPCCall,
|
||||||
]})
|
]
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def search_name(self):
|
def search_name(self):
|
||||||
@@ -60,9 +63,7 @@ class UnparsedNode(UnparsedBaseNode, HasSQL):
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class UnparsedRunHook(UnparsedNode):
|
class UnparsedRunHook(UnparsedNode):
|
||||||
resource_type: NodeType = field(
|
resource_type: NodeType = field(metadata={"restrict": [NodeType.Operation]})
|
||||||
metadata={'restrict': [NodeType.Operation]}
|
|
||||||
)
|
|
||||||
index: Optional[int] = None
|
index: Optional[int] = None
|
||||||
|
|
||||||
|
|
||||||
@@ -72,10 +73,9 @@ class Docs(dbtClassMixin, Replaceable):
|
|||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class HasDocs(AdditionalPropertiesMixin, ExtensibleDbtClassMixin,
|
class HasDocs(AdditionalPropertiesMixin, ExtensibleDbtClassMixin, Replaceable):
|
||||||
Replaceable):
|
|
||||||
name: str
|
name: str
|
||||||
description: str = ''
|
description: str = ""
|
||||||
meta: Dict[str, Any] = field(default_factory=dict)
|
meta: Dict[str, Any] = field(default_factory=dict)
|
||||||
data_type: Optional[str] = None
|
data_type: Optional[str] = None
|
||||||
docs: Docs = field(default_factory=Docs)
|
docs: Docs = field(default_factory=Docs)
|
||||||
@@ -131,7 +131,7 @@ class UnparsedNodeUpdate(HasColumnTests, HasTests, HasYamlMetadata):
|
|||||||
class MacroArgument(dbtClassMixin):
|
class MacroArgument(dbtClassMixin):
|
||||||
name: str
|
name: str
|
||||||
type: Optional[str] = None
|
type: Optional[str] = None
|
||||||
description: str = ''
|
description: str = ""
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -140,12 +140,12 @@ class UnparsedMacroUpdate(HasDocs, HasYamlMetadata):
|
|||||||
|
|
||||||
|
|
||||||
class TimePeriod(StrEnum):
|
class TimePeriod(StrEnum):
|
||||||
minute = 'minute'
|
minute = "minute"
|
||||||
hour = 'hour'
|
hour = "hour"
|
||||||
day = 'day'
|
day = "day"
|
||||||
|
|
||||||
def plural(self) -> str:
|
def plural(self) -> str:
|
||||||
return str(self) + 's'
|
return str(self) + "s"
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -167,6 +167,7 @@ class FreshnessThreshold(dbtClassMixin, Mergeable):
|
|||||||
|
|
||||||
def status(self, age: float) -> "dbt.contracts.results.FreshnessStatus":
|
def status(self, age: float) -> "dbt.contracts.results.FreshnessStatus":
|
||||||
from dbt.contracts.results import FreshnessStatus
|
from dbt.contracts.results import FreshnessStatus
|
||||||
|
|
||||||
if self.error_after and self.error_after.exceeded(age):
|
if self.error_after and self.error_after.exceeded(age):
|
||||||
return FreshnessStatus.Error
|
return FreshnessStatus.Error
|
||||||
elif self.warn_after and self.warn_after.exceeded(age):
|
elif self.warn_after and self.warn_after.exceeded(age):
|
||||||
@@ -179,24 +180,21 @@ class FreshnessThreshold(dbtClassMixin, Mergeable):


 @dataclass
-class AdditionalPropertiesAllowed(
-    AdditionalPropertiesMixin,
-    ExtensibleDbtClassMixin
-):
+class AdditionalPropertiesAllowed(AdditionalPropertiesMixin, ExtensibleDbtClassMixin):
     _extra: Dict[str, Any] = field(default_factory=dict)


 @dataclass
 class ExternalPartition(AdditionalPropertiesAllowed, Replaceable):
-    name: str = ''
-    description: str = ''
-    data_type: str = ''
+    name: str = ""
+    description: str = ""
+    data_type: str = ""
     meta: Dict[str, Any] = field(default_factory=dict)

     def __post_init__(self):
-        if self.name == '' or self.data_type == '':
+        if self.name == "" or self.data_type == "":
             raise CompilationException(
-                'External partition columns must have names and data types'
+                "External partition columns must have names and data types"
             )

@@ -225,43 +223,39 @@ class UnparsedSourceTableDefinition(HasColumnTests, HasTests):
|
|||||||
loaded_at_field: Optional[str] = None
|
loaded_at_field: Optional[str] = None
|
||||||
identifier: Optional[str] = None
|
identifier: Optional[str] = None
|
||||||
quoting: Quoting = field(default_factory=Quoting)
|
quoting: Quoting = field(default_factory=Quoting)
|
||||||
freshness: Optional[FreshnessThreshold] = field(
|
freshness: Optional[FreshnessThreshold] = field(default_factory=FreshnessThreshold)
|
||||||
default_factory=FreshnessThreshold
|
|
||||||
)
|
|
||||||
external: Optional[ExternalTable] = None
|
external: Optional[ExternalTable] = None
|
||||||
tags: List[str] = field(default_factory=list)
|
tags: List[str] = field(default_factory=list)
|
||||||
|
|
||||||
def __post_serialize__(self, dct):
|
def __post_serialize__(self, dct):
|
||||||
dct = super().__post_serialize__(dct)
|
dct = super().__post_serialize__(dct)
|
||||||
if 'freshness' not in dct and self.freshness is None:
|
if "freshness" not in dct and self.freshness is None:
|
||||||
dct['freshness'] = None
|
dct["freshness"] = None
|
||||||
return dct
|
return dct
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class UnparsedSourceDefinition(dbtClassMixin, Replaceable):
|
class UnparsedSourceDefinition(dbtClassMixin, Replaceable):
|
||||||
name: str
|
name: str
|
||||||
description: str = ''
|
description: str = ""
|
||||||
meta: Dict[str, Any] = field(default_factory=dict)
|
meta: Dict[str, Any] = field(default_factory=dict)
|
||||||
database: Optional[str] = None
|
database: Optional[str] = None
|
||||||
schema: Optional[str] = None
|
schema: Optional[str] = None
|
||||||
loader: str = ''
|
loader: str = ""
|
||||||
quoting: Quoting = field(default_factory=Quoting)
|
quoting: Quoting = field(default_factory=Quoting)
|
||||||
freshness: Optional[FreshnessThreshold] = field(
|
freshness: Optional[FreshnessThreshold] = field(default_factory=FreshnessThreshold)
|
||||||
default_factory=FreshnessThreshold
|
|
||||||
)
|
|
||||||
loaded_at_field: Optional[str] = None
|
loaded_at_field: Optional[str] = None
|
||||||
tables: List[UnparsedSourceTableDefinition] = field(default_factory=list)
|
tables: List[UnparsedSourceTableDefinition] = field(default_factory=list)
|
||||||
tags: List[str] = field(default_factory=list)
|
tags: List[str] = field(default_factory=list)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def yaml_key(self) -> 'str':
|
def yaml_key(self) -> "str":
|
||||||
return 'sources'
|
return "sources"
|
||||||
|
|
||||||
def __post_serialize__(self, dct):
|
def __post_serialize__(self, dct):
|
||||||
dct = super().__post_serialize__(dct)
|
dct = super().__post_serialize__(dct)
|
||||||
if 'freshness' not in dct and self.freshness is None:
|
if "freshness" not in dct and self.freshness is None:
|
||||||
dct['freshness'] = None
|
dct["freshness"] = None
|
||||||
return dct
|
return dct
|
||||||
|
|
||||||
|
|
||||||
@@ -275,9 +269,7 @@ class SourceTablePatch(dbtClassMixin):
|
|||||||
loaded_at_field: Optional[str] = None
|
loaded_at_field: Optional[str] = None
|
||||||
identifier: Optional[str] = None
|
identifier: Optional[str] = None
|
||||||
quoting: Quoting = field(default_factory=Quoting)
|
quoting: Quoting = field(default_factory=Quoting)
|
||||||
freshness: Optional[FreshnessThreshold] = field(
|
freshness: Optional[FreshnessThreshold] = field(default_factory=FreshnessThreshold)
|
||||||
default_factory=FreshnessThreshold
|
|
||||||
)
|
|
||||||
external: Optional[ExternalTable] = None
|
external: Optional[ExternalTable] = None
|
||||||
tags: Optional[List[str]] = None
|
tags: Optional[List[str]] = None
|
||||||
tests: Optional[List[TestDef]] = None
|
tests: Optional[List[TestDef]] = None
|
||||||
@@ -285,13 +277,13 @@ class SourceTablePatch(dbtClassMixin):
|
|||||||
|
|
||||||
def to_patch_dict(self) -> Dict[str, Any]:
|
def to_patch_dict(self) -> Dict[str, Any]:
|
||||||
dct = self.to_dict(omit_none=True)
|
dct = self.to_dict(omit_none=True)
|
||||||
remove_keys = ('name')
|
remove_keys = "name"
|
||||||
for key in remove_keys:
|
for key in remove_keys:
|
||||||
if key in dct:
|
if key in dct:
|
||||||
del dct[key]
|
del dct[key]
|
||||||
|
|
||||||
if self.freshness is None:
|
if self.freshness is None:
|
||||||
dct['freshness'] = None
|
dct["freshness"] = None
|
||||||
|
|
||||||
return dct
|
return dct
|
||||||
|
|
||||||
@@ -299,13 +291,13 @@ class SourceTablePatch(dbtClassMixin):
|
|||||||
@dataclass
|
@dataclass
|
||||||
class SourcePatch(dbtClassMixin, Replaceable):
|
class SourcePatch(dbtClassMixin, Replaceable):
|
||||||
name: str = field(
|
name: str = field(
|
||||||
metadata=dict(description='The name of the source to override'),
|
metadata=dict(description="The name of the source to override"),
|
||||||
)
|
)
|
||||||
overrides: str = field(
|
overrides: str = field(
|
||||||
metadata=dict(description='The package of the source to override'),
|
metadata=dict(description="The package of the source to override"),
|
||||||
)
|
)
|
||||||
path: Path = field(
|
path: Path = field(
|
||||||
metadata=dict(description='The path to the patch-defining yml file'),
|
metadata=dict(description="The path to the patch-defining yml file"),
|
||||||
)
|
)
|
||||||
description: Optional[str] = None
|
description: Optional[str] = None
|
||||||
meta: Optional[Dict[str, Any]] = None
|
meta: Optional[Dict[str, Any]] = None
|
||||||
@@ -322,13 +314,13 @@ class SourcePatch(dbtClassMixin, Replaceable):
|
|||||||
|
|
||||||
def to_patch_dict(self) -> Dict[str, Any]:
|
def to_patch_dict(self) -> Dict[str, Any]:
|
||||||
dct = self.to_dict(omit_none=True)
|
dct = self.to_dict(omit_none=True)
|
||||||
remove_keys = ('name', 'overrides', 'tables', 'path')
|
remove_keys = ("name", "overrides", "tables", "path")
|
||||||
for key in remove_keys:
|
for key in remove_keys:
|
||||||
if key in dct:
|
if key in dct:
|
||||||
del dct[key]
|
del dct[key]
|
||||||
|
|
||||||
if self.freshness is None:
|
if self.freshness is None:
|
||||||
dct['freshness'] = None
|
dct["freshness"] = None
|
||||||
|
|
||||||
return dct
|
return dct
|
||||||
|
|
||||||
@@ -360,9 +352,9 @@ class UnparsedDocumentationFile(UnparsedDocumentation):
|
|||||||
# can't use total_ordering decorator here, as str provides an ordering already
|
# can't use total_ordering decorator here, as str provides an ordering already
|
||||||
# and it's not the one we want.
|
# and it's not the one we want.
|
||||||
class Maturity(StrEnum):
|
class Maturity(StrEnum):
|
||||||
low = 'low'
|
low = "low"
|
||||||
medium = 'medium'
|
medium = "medium"
|
||||||
high = 'high'
|
high = "high"
|
||||||
|
|
||||||
def __lt__(self, other):
|
def __lt__(self, other):
|
||||||
if not isinstance(other, Maturity):
|
if not isinstance(other, Maturity):
|
||||||
@@ -387,17 +379,17 @@ class Maturity(StrEnum):
|
|||||||
|
|
||||||
|
|
||||||
class ExposureType(StrEnum):
|
class ExposureType(StrEnum):
|
||||||
Dashboard = 'dashboard'
|
Dashboard = "dashboard"
|
||||||
Notebook = 'notebook'
|
Notebook = "notebook"
|
||||||
Analysis = 'analysis'
|
Analysis = "analysis"
|
||||||
ML = 'ml'
|
ML = "ml"
|
||||||
Application = 'application'
|
Application = "application"
|
||||||
|
|
||||||
|
|
||||||
class MaturityType(StrEnum):
|
class MaturityType(StrEnum):
|
||||||
Low = 'low'
|
Low = "low"
|
||||||
Medium = 'medium'
|
Medium = "medium"
|
||||||
High = 'high'
|
High = "high"
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -411,7 +403,7 @@ class UnparsedExposure(dbtClassMixin, Replaceable):
|
|||||||
name: str
|
name: str
|
||||||
type: ExposureType
|
type: ExposureType
|
||||||
owner: ExposureOwner
|
owner: ExposureOwner
|
||||||
description: str = ''
|
description: str = ""
|
||||||
maturity: Optional[MaturityType] = None
|
maturity: Optional[MaturityType] = None
|
||||||
url: Optional[str] = None
|
url: Optional[str] = None
|
||||||
depends_on: List[str] = field(default_factory=list)
|
depends_on: List[str] = field(default_factory=list)
|
||||||
|
|||||||
@@ -5,24 +5,26 @@ from dbt.logger import GLOBAL_LOGGER as logger # noqa
 from dbt import tracking
 from dbt import ui
 from dbt.dataclass_schema import (
-    dbtClassMixin, ValidationError,
+    dbtClassMixin,
+    ValidationError,
     HyphenatedDbtClassMixin,
     ExtensibleDbtClassMixin,
-    register_pattern, ValidatedStringMixin
+    register_pattern,
+    ValidatedStringMixin,
 )
 from dataclasses import dataclass, field
 from typing import Optional, List, Dict, Union, Any
 from mashumaro.types import SerializableType

-PIN_PACKAGE_URL = 'https://docs.getdbt.com/docs/package-management#section-specifying-package-versions' # noqa
+PIN_PACKAGE_URL = "https://docs.getdbt.com/docs/package-management#section-specifying-package-versions" # noqa
 DEFAULT_SEND_ANONYMOUS_USAGE_STATS = True


 class Name(ValidatedStringMixin):
-    ValidationRegex = r'^[^\d\W]\w*$'
+    ValidationRegex = r"^[^\d\W]\w*$"


-register_pattern(Name, r'^[^\d\W]\w*$')
+register_pattern(Name, r"^[^\d\W]\w*$")


 class SemverString(str, SerializableType):
@@ -30,7 +32,7 @@ class SemverString(str, SerializableType):
|
|||||||
return self
|
return self
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _deserialize(cls, value: str) -> 'SemverString':
|
def _deserialize(cls, value: str) -> "SemverString":
|
||||||
return SemverString(value)
|
return SemverString(value)
|
||||||
|
|
||||||
|
|
||||||
@@ -39,7 +41,7 @@ class SemverString(str, SerializableType):
|
|||||||
# 'semver lite'.
|
# 'semver lite'.
|
||||||
register_pattern(
|
register_pattern(
|
||||||
SemverString,
|
SemverString,
|
||||||
r'^(?:0|[1-9]\d*)\.(?:0|[1-9]\d*)(\.(?:0|[1-9]\d*))?$',
|
r"^(?:0|[1-9]\d*)\.(?:0|[1-9]\d*)(\.(?:0|[1-9]\d*))?$",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -105,8 +107,7 @@ class ProjectPackageMetadata:

     @classmethod
     def from_project(cls, project):
-        return cls(name=project.project_name,
-                   packages=project.packages.packages)
+        return cls(name=project.project_name, packages=project.packages.packages)


 @dataclass
@@ -124,46 +125,46 @@ class RegistryPackageMetadata(
|
|||||||
|
|
||||||
# A list of all the reserved words that packages may not have as names.
|
# A list of all the reserved words that packages may not have as names.
|
||||||
BANNED_PROJECT_NAMES = {
|
BANNED_PROJECT_NAMES = {
|
||||||
'_sql_results',
|
"_sql_results",
|
||||||
'adapter',
|
"adapter",
|
||||||
'api',
|
"api",
|
||||||
'column',
|
"column",
|
||||||
'config',
|
"config",
|
||||||
'context',
|
"context",
|
||||||
'database',
|
"database",
|
||||||
'env',
|
"env",
|
||||||
'env_var',
|
"env_var",
|
||||||
'exceptions',
|
"exceptions",
|
||||||
'execute',
|
"execute",
|
||||||
'flags',
|
"flags",
|
||||||
'fromjson',
|
"fromjson",
|
||||||
'fromyaml',
|
"fromyaml",
|
||||||
'graph',
|
"graph",
|
||||||
'invocation_id',
|
"invocation_id",
|
||||||
'load_agate_table',
|
"load_agate_table",
|
||||||
'load_result',
|
"load_result",
|
||||||
'log',
|
"log",
|
||||||
'model',
|
"model",
|
||||||
'modules',
|
"modules",
|
||||||
'post_hooks',
|
"post_hooks",
|
||||||
'pre_hooks',
|
"pre_hooks",
|
||||||
'ref',
|
"ref",
|
||||||
'render',
|
"render",
|
||||||
'return',
|
"return",
|
||||||
'run_started_at',
|
"run_started_at",
|
||||||
'schema',
|
"schema",
|
||||||
'source',
|
"source",
|
||||||
'sql',
|
"sql",
|
||||||
'sql_now',
|
"sql_now",
|
||||||
'store_result',
|
"store_result",
|
||||||
'store_raw_result',
|
"store_raw_result",
|
||||||
'target',
|
"target",
|
||||||
'this',
|
"this",
|
||||||
'tojson',
|
"tojson",
|
||||||
'toyaml',
|
"toyaml",
|
||||||
'try_or_compiler_error',
|
"try_or_compiler_error",
|
||||||
'var',
|
"var",
|
||||||
'write',
|
"write",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -198,7 +199,7 @@ class Project(HyphenatedDbtClassMixin, Replaceable):
|
|||||||
vars: Optional[Dict[str, Any]] = field(
|
vars: Optional[Dict[str, Any]] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata=dict(
|
metadata=dict(
|
||||||
description='map project names to their vars override dicts',
|
description="map project names to their vars override dicts",
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
packages: List[PackageSpec] = field(default_factory=list)
|
packages: List[PackageSpec] = field(default_factory=list)
|
||||||
@@ -207,7 +208,7 @@ class Project(HyphenatedDbtClassMixin, Replaceable):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def validate(cls, data):
|
def validate(cls, data):
|
||||||
super().validate(data)
|
super().validate(data)
|
||||||
if data['name'] in BANNED_PROJECT_NAMES:
|
if data["name"] in BANNED_PROJECT_NAMES:
|
||||||
raise ValidationError(
|
raise ValidationError(
|
||||||
f"Invalid project name: {data['name']} is a reserved word"
|
f"Invalid project name: {data['name']} is a reserved word"
|
||||||
)
|
)
|
||||||
@@ -235,8 +236,8 @@ class UserConfig(ExtensibleDbtClassMixin, Replaceable, UserConfigContract):
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ProfileConfig(HyphenatedDbtClassMixin, Replaceable):
|
class ProfileConfig(HyphenatedDbtClassMixin, Replaceable):
|
||||||
profile_name: str = field(metadata={'preserve_underscore': True})
|
profile_name: str = field(metadata={"preserve_underscore": True})
|
||||||
target_name: str = field(metadata={'preserve_underscore': True})
|
target_name: str = field(metadata={"preserve_underscore": True})
|
||||||
config: UserConfig
|
config: UserConfig
|
||||||
threads: int
|
threads: int
|
||||||
# TODO: make this a dynamic union of some kind?
|
# TODO: make this a dynamic union of some kind?
|
||||||
@@ -255,7 +256,7 @@ class ConfiguredQuoting(Quoting, Replaceable):
|
|||||||
class Configuration(Project, ProfileConfig):
|
class Configuration(Project, ProfileConfig):
|
||||||
cli_vars: Dict[str, Any] = field(
|
cli_vars: Dict[str, Any] = field(
|
||||||
default_factory=dict,
|
default_factory=dict,
|
||||||
metadata={'preserve_underscore': True},
|
metadata={"preserve_underscore": True},
|
||||||
)
|
)
|
||||||
quoting: Optional[ConfiguredQuoting] = None
|
quoting: Optional[ConfiguredQuoting] = None
|
||||||
|
|
||||||
|
|||||||
@@ -1,7 +1,8 @@
|
|||||||
from collections.abc import Mapping
|
from collections.abc import Mapping
|
||||||
from dataclasses import dataclass, fields
|
from dataclasses import dataclass, fields
|
||||||
from typing import (
|
from typing import (
|
||||||
Optional, Dict,
|
Optional,
|
||||||
|
Dict,
|
||||||
)
|
)
|
||||||
from typing_extensions import Protocol
|
from typing_extensions import Protocol
|
||||||
|
|
||||||
@@ -14,17 +15,17 @@ from dbt.utils import deep_merge
|
|||||||
|
|
||||||
|
|
||||||
class RelationType(StrEnum):
|
class RelationType(StrEnum):
|
||||||
Table = 'table'
|
Table = "table"
|
||||||
View = 'view'
|
View = "view"
|
||||||
CTE = 'cte'
|
CTE = "cte"
|
||||||
MaterializedView = 'materializedview'
|
MaterializedView = "materializedview"
|
||||||
External = 'external'
|
External = "external"
|
||||||
|
|
||||||
|
|
||||||
class ComponentName(StrEnum):
|
class ComponentName(StrEnum):
|
||||||
Database = 'database'
|
Database = "database"
|
||||||
Schema = 'schema'
|
Schema = "schema"
|
||||||
Identifier = 'identifier'
|
Identifier = "identifier"
|
||||||
|
|
||||||
|
|
||||||
class HasQuoting(Protocol):
|
class HasQuoting(Protocol):
|
||||||
@@ -43,12 +44,12 @@ class FakeAPIObject(dbtClassMixin, Replaceable, Mapping):
|
|||||||
raise KeyError(key) from None
|
raise KeyError(key) from None
|
||||||
|
|
||||||
def __iter__(self):
|
def __iter__(self):
|
||||||
deprecations.warn('not-a-dictionary', obj=self)
|
deprecations.warn("not-a-dictionary", obj=self)
|
||||||
for _, name in self._get_fields():
|
for _, name in self._get_fields():
|
||||||
yield name
|
yield name
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
deprecations.warn('not-a-dictionary', obj=self)
|
deprecations.warn("not-a-dictionary", obj=self)
|
||||||
return len(fields(self.__class__))
|
return len(fields(self.__class__))
|
||||||
|
|
||||||
def incorporate(self, **kwargs):
|
def incorporate(self, **kwargs):
|
||||||
@@ -72,8 +73,7 @@ class Policy(FakeAPIObject):
|
|||||||
return self.identifier
|
return self.identifier
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
'Got a key of {}, expected one of {}'
|
"Got a key of {}, expected one of {}".format(key, list(ComponentName))
|
||||||
.format(key, list(ComponentName))
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def replace_dict(self, dct: Dict[ComponentName, bool]):
|
def replace_dict(self, dct: Dict[ComponentName, bool]):
|
||||||
@@ -93,15 +93,15 @@ class Path(FakeAPIObject):
|
|||||||
# handle pesky jinja2.Undefined sneaking in here and messing up render
|
# handle pesky jinja2.Undefined sneaking in here and messing up render
|
||||||
if not isinstance(self.database, (type(None), str)):
|
if not isinstance(self.database, (type(None), str)):
|
||||||
raise CompilationException(
|
raise CompilationException(
|
||||||
'Got an invalid path database: {}'.format(self.database)
|
"Got an invalid path database: {}".format(self.database)
|
||||||
)
|
)
|
||||||
if not isinstance(self.schema, (type(None), str)):
|
if not isinstance(self.schema, (type(None), str)):
|
||||||
raise CompilationException(
|
raise CompilationException(
|
||||||
'Got an invalid path schema: {}'.format(self.schema)
|
"Got an invalid path schema: {}".format(self.schema)
|
||||||
)
|
)
|
||||||
if not isinstance(self.identifier, (type(None), str)):
|
if not isinstance(self.identifier, (type(None), str)):
|
||||||
raise CompilationException(
|
raise CompilationException(
|
||||||
'Got an invalid path identifier: {}'.format(self.identifier)
|
"Got an invalid path identifier: {}".format(self.identifier)
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_lowered_part(self, key: ComponentName) -> Optional[str]:
|
def get_lowered_part(self, key: ComponentName) -> Optional[str]:
|
||||||
@@ -119,8 +119,7 @@ class Path(FakeAPIObject):
|
|||||||
return self.identifier
|
return self.identifier
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
'Got a key of {}, expected one of {}'
|
"Got a key of {}, expected one of {}".format(key, list(ComponentName))
|
||||||
.format(key, list(ComponentName))
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def replace_dict(self, dct: Dict[ComponentName, str]):
|
def replace_dict(self, dct: Dict[ComponentName, str]):
|
||||||
|
|||||||
@@ -1,7 +1,5 @@
|
|||||||
from dbt.contracts.graph.manifest import CompileResultNode
|
from dbt.contracts.graph.manifest import CompileResultNode
|
||||||
from dbt.contracts.graph.unparsed import (
|
from dbt.contracts.graph.unparsed import FreshnessThreshold
|
||||||
FreshnessThreshold
|
|
||||||
)
|
|
||||||
from dbt.contracts.graph.parsed import ParsedSourceDefinition
|
from dbt.contracts.graph.parsed import ParsedSourceDefinition
|
||||||
from dbt.contracts.util import (
|
from dbt.contracts.util import (
|
||||||
BaseArtifactMetadata,
|
BaseArtifactMetadata,
|
||||||
@@ -24,7 +22,13 @@ import agate
|
|||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import (
|
from typing import (
|
||||||
Union, Dict, List, Optional, Any, NamedTuple, Sequence,
|
Union,
|
||||||
|
Dict,
|
||||||
|
List,
|
||||||
|
Optional,
|
||||||
|
Any,
|
||||||
|
NamedTuple,
|
||||||
|
Sequence,
|
||||||
)
|
)
|
||||||
|
|
||||||
from dbt.clients.system import write_json
|
from dbt.clients.system import write_json
|
||||||
@@ -54,7 +58,7 @@ class collect_timing_info:
|
|||||||
def __exit__(self, exc_type, exc_value, traceback):
|
def __exit__(self, exc_type, exc_value, traceback):
|
||||||
self.timing_info.end()
|
self.timing_info.end()
|
||||||
with JsonOnly(), TimingProcessor(self.timing_info):
|
with JsonOnly(), TimingProcessor(self.timing_info):
|
||||||
logger.debug('finished collecting timing info')
|
logger.debug("finished collecting timing info")
|
||||||
|
|
||||||
|
|
||||||
class NodeStatus(StrEnum):
|
class NodeStatus(StrEnum):
|
||||||
@@ -99,8 +103,8 @@ class BaseResult(dbtClassMixin):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def __pre_deserialize__(cls, data):
|
def __pre_deserialize__(cls, data):
|
||||||
data = super().__pre_deserialize__(data)
|
data = super().__pre_deserialize__(data)
|
||||||
if 'message' not in data:
|
if "message" not in data:
|
||||||
data['message'] = None
|
data["message"] = None
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
|
||||||
@@ -112,9 +116,8 @@ class NodeResult(BaseResult):
@dataclass
class RunResult(NodeResult):
agate_table: Optional[agate.Table] = field(
-default=None, metadata={
-'serialize': lambda x: None, 'deserialize': lambda x: None
-}
+default=None,
+metadata={"serialize": lambda x: None, "deserialize": lambda x: None},
)

@property

@@ -157,7 +160,7 @@ def process_run_result(result: RunResult) -> RunResultOutput:
thread_id=result.thread_id,
execution_time=result.execution_time,
message=result.message,
-adapter_response=result.adapter_response
+adapter_response=result.adapter_response,
)


@@ -180,7 +183,7 @@ class RunExecutionResult(


@dataclass
-@schema_version('run-results', 1)
+@schema_version("run-results", 1)
class RunResultsArtifact(ExecutionResult, ArtifactMixin):
results: Sequence[RunResultOutput]
args: Dict[str, Any] = field(default_factory=dict)

@@ -202,7 +205,7 @@ class RunResultsArtifact(ExecutionResult, ArtifactMixin):
metadata=meta,
results=processed_results,
elapsed_time=elapsed_time,
-args=args
+args=args,
)

def write(self, path: str):
@@ -216,15 +219,14 @@ class RunOperationResult(ExecutionResult):

@dataclass
class RunOperationResultMetadata(BaseArtifactMetadata):
-dbt_schema_version: str = field(default_factory=lambda: str(
-RunOperationResultsArtifact.dbt_schema_version
-))
+dbt_schema_version: str = field(
+default_factory=lambda: str(RunOperationResultsArtifact.dbt_schema_version)
+)


@dataclass
-@schema_version('run-operation-result', 1)
+@schema_version("run-operation-result", 1)
class RunOperationResultsArtifact(RunOperationResult, ArtifactMixin):

@classmethod
def from_success(
cls,

@@ -243,6 +245,7 @@ class RunOperationResultsArtifact(RunOperationResult, ArtifactMixin):
success=success,
)


# due to issues with typing.Union collapsing subclasses, this can't subclass
# PartialResult


@@ -261,7 +264,7 @@ class SourceFreshnessResult(NodeResult):


class FreshnessErrorEnum(StrEnum):
-runtime_error = 'runtime error'
+runtime_error = "runtime error"


@dataclass
@@ -291,14 +294,11 @@ class PartialSourceFreshnessResult(NodeResult):
return False


-FreshnessNodeResult = Union[PartialSourceFreshnessResult,
-SourceFreshnessResult]
+FreshnessNodeResult = Union[PartialSourceFreshnessResult, SourceFreshnessResult]
FreshnessNodeOutput = Union[SourceFreshnessRuntimeError, SourceFreshnessOutput]


-def process_freshness_result(
-result: FreshnessNodeResult
-) -> FreshnessNodeOutput:
+def process_freshness_result(result: FreshnessNodeResult) -> FreshnessNodeOutput:
unique_id = result.node.unique_id
if result.status == FreshnessStatus.RuntimeErr:
return SourceFreshnessRuntimeError(

@@ -310,16 +310,15 @@ def process_freshness_result(
# we know that this must be a SourceFreshnessResult
if not isinstance(result, SourceFreshnessResult):
raise InternalException(
-'Got {} instead of a SourceFreshnessResult for a '
-'non-error result in freshness execution!'
-.format(type(result))
+"Got {} instead of a SourceFreshnessResult for a "
+"non-error result in freshness execution!".format(type(result))
)
# if we're here, we must have a non-None freshness threshold
criteria = result.node.freshness
if criteria is None:
raise InternalException(
-'Somehow evaluated a freshness result for a source '
-'that has no freshness criteria!'
+"Somehow evaluated a freshness result for a source "
+"that has no freshness criteria!"
)
return SourceFreshnessOutput(
unique_id=unique_id,
@@ -328,16 +327,14 @@ def process_freshness_result(
max_loaded_at_time_ago_in_s=result.age,
status=result.status,
criteria=criteria,
-adapter_response=result.adapter_response
+adapter_response=result.adapter_response,
)


@dataclass
class FreshnessMetadata(BaseArtifactMetadata):
dbt_schema_version: str = field(
-default_factory=lambda: str(
-FreshnessExecutionResultArtifact.dbt_schema_version
-)
+default_factory=lambda: str(FreshnessExecutionResultArtifact.dbt_schema_version)
)


@@ -358,7 +355,7 @@ class FreshnessResult(ExecutionResult):


@dataclass
-@schema_version('sources', 1)
+@schema_version("sources", 1)
class FreshnessExecutionResultArtifact(
ArtifactMixin,
VersionedSchema,

@@ -380,8 +377,7 @@ class FreshnessExecutionResultArtifact(
Primitive = Union[bool, str, float, None]

CatalogKey = NamedTuple(
-'CatalogKey',
-[('database', Optional[str]), ('schema', str), ('name', str)]
+"CatalogKey", [("database", Optional[str]), ("schema", str), ("name", str)]
)

@@ -450,13 +446,13 @@ class CatalogResults(dbtClassMixin):

def __post_serialize__(self, dct):
dct = super().__post_serialize__(dct)
-if '_compile_results' in dct:
-del dct['_compile_results']
+if "_compile_results" in dct:
+del dct["_compile_results"]
return dct


@dataclass
-@schema_version('catalog', 1)
+@schema_version("catalog", 1)
class CatalogArtifact(CatalogResults, ArtifactMixin):
metadata: CatalogMetadata


@@ -467,8 +463,8 @@ class CatalogArtifact(CatalogResults, ArtifactMixin):
nodes: Dict[str, CatalogTable],
sources: Dict[str, CatalogTable],
compile_results: Optional[Any],
-errors: Optional[List[str]]
-) -> 'CatalogArtifact':
+errors: Optional[List[str]],
+) -> "CatalogArtifact":
meta = CatalogMetadata(generated_at=generated_at)
return cls(
metadata=meta,
@@ -10,7 +10,9 @@ from dbt.dataclass_schema import dbtClassMixin, StrEnum
from dbt.contracts.graph.compiled import CompileResultNode
from dbt.contracts.graph.manifest import WritableManifest
from dbt.contracts.results import (
-RunResult, RunResultsArtifact, TimingInfo,
+RunResult,
+RunResultsArtifact,
+TimingInfo,
CatalogArtifact,
CatalogResults,
ExecutionResult,

@@ -40,10 +42,10 @@ class RPCParameters(dbtClassMixin):
@classmethod
def __pre_deserialize__(cls, data, omit_none=True):
data = super().__pre_deserialize__(data)
-if 'timeout' not in data:
-data['timeout'] = None
-if 'task_tags' not in data:
-data['task_tags'] = None
+if "timeout" not in data:
+data["timeout"] = None
+if "task_tags" not in data:
+data["task_tags"] = None
return data


@@ -161,6 +163,7 @@ class GCParameters(RPCParameters):
will be applied to the task manager before GC starts. By default the
existing gc settings remain.
"""

task_ids: Optional[List[TaskID]] = None
before: Optional[datetime] = None
settings: Optional[GCSettings] = None

@@ -182,6 +185,7 @@ class RPCSourceFreshnessParameters(RPCParameters):
class GetManifestParameters(RPCParameters):
pass


# Outputs

@@ -191,13 +195,13 @@ class RemoteResult(VersionedSchema):


@dataclass
-@schema_version('remote-deps-result', 1)
+@schema_version("remote-deps-result", 1)
class RemoteDepsResult(RemoteResult):
generated_at: datetime = field(default_factory=datetime.utcnow)


@dataclass
-@schema_version('remote-catalog-result', 1)
+@schema_version("remote-catalog-result", 1)
class RemoteCatalogResults(CatalogResults, RemoteResult):
generated_at: datetime = field(default_factory=datetime.utcnow)

@@ -221,7 +225,7 @@ class RemoteCompileResultMixin(RemoteResult):


@dataclass
-@schema_version('remote-compile-result', 1)
+@schema_version("remote-compile-result", 1)
class RemoteCompileResult(RemoteCompileResultMixin):
generated_at: datetime = field(default_factory=datetime.utcnow)

@@ -231,7 +235,7 @@ class RemoteCompileResult(RemoteCompileResultMixin):


@dataclass
-@schema_version('remote-execution-result', 1)
+@schema_version("remote-execution-result", 1)
class RemoteExecutionResult(ExecutionResult, RemoteResult):
results: Sequence[RunResult]
args: Dict[str, Any] = field(default_factory=dict)

@@ -251,7 +255,7 @@ class RemoteExecutionResult(ExecutionResult, RemoteResult):
cls,
base: RunExecutionResult,
logs: List[LogMessage],
-) -> 'RemoteExecutionResult':
+) -> "RemoteExecutionResult":
return cls(
generated_at=base.generated_at,
results=base.results,

@@ -268,7 +272,7 @@ class ResultTable(dbtClassMixin):


@dataclass
-@schema_version('remote-run-operation-result', 1)
+@schema_version("remote-run-operation-result", 1)
class RemoteRunOperationResult(RunOperationResult, RemoteResult):
generated_at: datetime = field(default_factory=datetime.utcnow)

@@ -277,7 +281,7 @@ class RemoteRunOperationResult(RunOperationResult, RemoteResult):
cls,
base: RunOperationResultsArtifact,
logs: List[LogMessage],
-) -> 'RemoteRunOperationResult':
+) -> "RemoteRunOperationResult":
return cls(
generated_at=base.metadata.generated_at,
results=base.results,

@@ -296,15 +300,14 @@ class RemoteRunOperationResult(RunOperationResult, RemoteResult):


@dataclass
-@schema_version('remote-freshness-result', 1)
+@schema_version("remote-freshness-result", 1)
class RemoteFreshnessResult(FreshnessResult, RemoteResult):

@classmethod
def from_local_result(
cls,
base: FreshnessResult,
logs: List[LogMessage],
-) -> 'RemoteFreshnessResult':
+) -> "RemoteFreshnessResult":
return cls(
metadata=base.metadata,
results=base.results,

@@ -318,7 +321,7 @@ class RemoteFreshnessResult(FreshnessResult, RemoteResult):


@dataclass
-@schema_version('remote-run-result', 1)
+@schema_version("remote-run-result", 1)
class RemoteRunResult(RemoteCompileResultMixin):
table: ResultTable
generated_at: datetime = field(default_factory=datetime.utcnow)

@@ -336,14 +339,15 @@ RPCResult = Union[

# GC types


class GCResultState(StrEnum):
-Deleted = 'deleted'  # successful GC
-Missing = 'missing'  # nothing to GC
-Running = 'running'  # can't GC
+Deleted = "deleted"  # successful GC
+Missing = "missing"  # nothing to GC
+Running = "running"  # can't GC


@dataclass
-@schema_version('remote-gc-result', 1)
+@schema_version("remote-gc-result", 1)
class GCResult(RemoteResult):
logs: List[LogMessage] = field(default_factory=list)
deleted: List[TaskID] = field(default_factory=list)
@@ -358,21 +362,20 @@ class GCResult(RemoteResult):
elif state == GCResultState.Deleted:
self.deleted.append(task_id)
else:
-raise InternalException(
-f'Got invalid state in add_result: {state}'
-)
+raise InternalException(f"Got invalid state in add_result: {state}")

# Task management types


class TaskHandlerState(StrEnum):
-NotStarted = 'not started'
-Initializing = 'initializing'
-Running = 'running'
-Success = 'success'
-Error = 'error'
-Killed = 'killed'
-Failed = 'failed'
+NotStarted = "not started"
+Initializing = "initializing"
+Running = "running"
+Success = "success"
+Error = "error"
+Killed = "killed"
+Failed = "failed"

def __lt__(self, other) -> bool:
"""A logical ordering for TaskHandlerState:

@@ -380,7 +383,7 @@ class TaskHandlerState(StrEnum):
NotStarted < Initializing < Running < (Success, Error, Killed, Failed)
"""
if not isinstance(other, TaskHandlerState):
-raise TypeError('cannot compare to non-TaskHandlerState')
+raise TypeError("cannot compare to non-TaskHandlerState")
order = (self.NotStarted, self.Initializing, self.Running)
smaller = set()
for value in order:

@@ -392,13 +395,11 @@ class TaskHandlerState(StrEnum):

def __le__(self, other) -> bool:
# so that ((Success <= Error) is True)
-return ((self < other) or
-(self == other) or
-(self.finished and other.finished))
+return (self < other) or (self == other) or (self.finished and other.finished)

def __gt__(self, other) -> bool:
if not isinstance(other, TaskHandlerState):
-raise TypeError('cannot compare to non-TaskHandlerState')
+raise TypeError("cannot compare to non-TaskHandlerState")
order = (self.NotStarted, self.Initializing, self.Running)
smaller = set()
for value in order:

@@ -409,9 +410,7 @@ class TaskHandlerState(StrEnum):

def __ge__(self, other) -> bool:
# so that ((Success <= Error) is True)
-return ((self > other) or
-(self == other) or
-(self.finished and other.finished))
+return (self > other) or (self == other) or (self.finished and other.finished)

@property
def finished(self) -> bool:

@@ -430,7 +429,7 @@ class TaskTiming(dbtClassMixin):
@classmethod
def __pre_deserialize__(cls, data):
data = super().__pre_deserialize__(data)
-for field_name in ('start', 'end', 'elapsed'):
+for field_name in ("start", "end", "elapsed"):
if field_name not in data:
data[field_name] = None
return data
@@ -447,27 +446,27 @@ class TaskRow(TaskTiming):


@dataclass
-@schema_version('remote-ps-result', 1)
+@schema_version("remote-ps-result", 1)
class PSResult(RemoteResult):
rows: List[TaskRow]


class KillResultStatus(StrEnum):
-Missing = 'missing'
-NotStarted = 'not_started'
-Killed = 'killed'
-Finished = 'finished'
+Missing = "missing"
+NotStarted = "not_started"
+Killed = "killed"
+Finished = "finished"


@dataclass
-@schema_version('remote-kill-result', 1)
+@schema_version("remote-kill-result", 1)
class KillResult(RemoteResult):
state: KillResultStatus = KillResultStatus.Missing
logs: List[LogMessage] = field(default_factory=list)


@dataclass
-@schema_version('remote-manifest-result', 1)
+@schema_version("remote-manifest-result", 1)
class GetManifestResult(RemoteResult):
manifest: Optional[WritableManifest] = None

@@ -498,29 +497,28 @@ class PollResult(RemoteResult, TaskTiming):
@classmethod
def __pre_deserialize__(cls, data):
data = super().__pre_deserialize__(data)
-for field_name in ('start', 'end', 'elapsed'):
+for field_name in ("start", "end", "elapsed"):
if field_name not in data:
data[field_name] = None
return data


@dataclass
-@schema_version('poll-remote-deps-result', 1)
+@schema_version("poll-remote-deps-result", 1)
class PollRemoteEmptyCompleteResult(PollResult, RemoteResult):
state: TaskHandlerState = field(
-metadata=restrict_to(TaskHandlerState.Success,
-TaskHandlerState.Failed),
+metadata=restrict_to(TaskHandlerState.Success, TaskHandlerState.Failed),
)
generated_at: datetime = field(default_factory=datetime.utcnow)

@classmethod
def from_result(
-cls: Type['PollRemoteEmptyCompleteResult'],
+cls: Type["PollRemoteEmptyCompleteResult"],
base: RemoteDepsResult,
tags: TaskTags,
timing: TaskTiming,
logs: List[LogMessage],
-) -> 'PollRemoteEmptyCompleteResult':
+) -> "PollRemoteEmptyCompleteResult":
return cls(
logs=logs,
tags=tags,
@@ -528,12 +526,12 @@ class PollRemoteEmptyCompleteResult(PollResult, RemoteResult):
start=timing.start,
end=timing.end,
elapsed=timing.elapsed,
-generated_at=base.generated_at
+generated_at=base.generated_at,
)


@dataclass
-@schema_version('poll-remote-killed-result', 1)
+@schema_version("poll-remote-killed-result", 1)
class PollKilledResult(PollResult):
state: TaskHandlerState = field(
metadata=restrict_to(TaskHandlerState.Killed),

@@ -541,24 +539,23 @@ class PollKilledResult(PollResult):


@dataclass
-@schema_version('poll-remote-execution-result', 1)
+@schema_version("poll-remote-execution-result", 1)
class PollExecuteCompleteResult(
RemoteExecutionResult,
PollResult,
):
state: TaskHandlerState = field(
-metadata=restrict_to(TaskHandlerState.Success,
-TaskHandlerState.Failed),
+metadata=restrict_to(TaskHandlerState.Success, TaskHandlerState.Failed),
)

@classmethod
def from_result(
-cls: Type['PollExecuteCompleteResult'],
+cls: Type["PollExecuteCompleteResult"],
base: RemoteExecutionResult,
tags: TaskTags,
timing: TaskTiming,
logs: List[LogMessage],
-) -> 'PollExecuteCompleteResult':
+) -> "PollExecuteCompleteResult":
return cls(
results=base.results,
elapsed_time=base.elapsed_time,

@@ -573,24 +570,23 @@ class PollExecuteCompleteResult(


@dataclass
-@schema_version('poll-remote-compile-result', 1)
+@schema_version("poll-remote-compile-result", 1)
class PollCompileCompleteResult(
RemoteCompileResult,
PollResult,
):
state: TaskHandlerState = field(
-metadata=restrict_to(TaskHandlerState.Success,
-TaskHandlerState.Failed),
+metadata=restrict_to(TaskHandlerState.Success, TaskHandlerState.Failed),
)

@classmethod
def from_result(
-cls: Type['PollCompileCompleteResult'],
+cls: Type["PollCompileCompleteResult"],
base: RemoteCompileResult,
tags: TaskTags,
timing: TaskTiming,
logs: List[LogMessage],
-) -> 'PollCompileCompleteResult':
+) -> "PollCompileCompleteResult":
return cls(
raw_sql=base.raw_sql,
compiled_sql=base.compiled_sql,
@@ -602,29 +598,28 @@ class PollCompileCompleteResult(
start=timing.start,
end=timing.end,
elapsed=timing.elapsed,
-generated_at=base.generated_at
+generated_at=base.generated_at,
)


@dataclass
-@schema_version('poll-remote-run-result', 1)
+@schema_version("poll-remote-run-result", 1)
class PollRunCompleteResult(
RemoteRunResult,
PollResult,
):
state: TaskHandlerState = field(
-metadata=restrict_to(TaskHandlerState.Success,
-TaskHandlerState.Failed),
+metadata=restrict_to(TaskHandlerState.Success, TaskHandlerState.Failed),
)

@classmethod
def from_result(
-cls: Type['PollRunCompleteResult'],
+cls: Type["PollRunCompleteResult"],
base: RemoteRunResult,
tags: TaskTags,
timing: TaskTiming,
logs: List[LogMessage],
-) -> 'PollRunCompleteResult':
+) -> "PollRunCompleteResult":
return cls(
raw_sql=base.raw_sql,
compiled_sql=base.compiled_sql,

@@ -637,29 +632,28 @@ class PollRunCompleteResult(
start=timing.start,
end=timing.end,
elapsed=timing.elapsed,
-generated_at=base.generated_at
+generated_at=base.generated_at,
)


@dataclass
-@schema_version('poll-remote-run-operation-result', 1)
+@schema_version("poll-remote-run-operation-result", 1)
class PollRunOperationCompleteResult(
RemoteRunOperationResult,
PollResult,
):
state: TaskHandlerState = field(
-metadata=restrict_to(TaskHandlerState.Success,
-TaskHandlerState.Failed),
+metadata=restrict_to(TaskHandlerState.Success, TaskHandlerState.Failed),
)

@classmethod
def from_result(
-cls: Type['PollRunOperationCompleteResult'],
+cls: Type["PollRunOperationCompleteResult"],
base: RemoteRunOperationResult,
tags: TaskTags,
timing: TaskTiming,
logs: List[LogMessage],
-) -> 'PollRunOperationCompleteResult':
+) -> "PollRunOperationCompleteResult":
return cls(
success=base.success,
results=base.results,
@@ -675,21 +669,20 @@ class PollRunOperationCompleteResult(


@dataclass
-@schema_version('poll-remote-catalog-result', 1)
+@schema_version("poll-remote-catalog-result", 1)
class PollCatalogCompleteResult(RemoteCatalogResults, PollResult):
state: TaskHandlerState = field(
-metadata=restrict_to(TaskHandlerState.Success,
-TaskHandlerState.Failed),
+metadata=restrict_to(TaskHandlerState.Success, TaskHandlerState.Failed),
)

@classmethod
def from_result(
-cls: Type['PollCatalogCompleteResult'],
+cls: Type["PollCatalogCompleteResult"],
base: RemoteCatalogResults,
tags: TaskTags,
timing: TaskTiming,
logs: List[LogMessage],
-) -> 'PollCatalogCompleteResult':
+) -> "PollCatalogCompleteResult":
return cls(
nodes=base.nodes,
sources=base.sources,

@@ -706,27 +699,26 @@ class PollCatalogCompleteResult(RemoteCatalogResults, PollResult):


@dataclass
-@schema_version('poll-remote-in-progress-result', 1)
+@schema_version("poll-remote-in-progress-result", 1)
class PollInProgressResult(PollResult):
pass


@dataclass
-@schema_version('poll-remote-get-manifest-result', 1)
+@schema_version("poll-remote-get-manifest-result", 1)
class PollGetManifestResult(GetManifestResult, PollResult):
state: TaskHandlerState = field(
-metadata=restrict_to(TaskHandlerState.Success,
-TaskHandlerState.Failed),
+metadata=restrict_to(TaskHandlerState.Success, TaskHandlerState.Failed),
)

@classmethod
def from_result(
-cls: Type['PollGetManifestResult'],
+cls: Type["PollGetManifestResult"],
base: GetManifestResult,
tags: TaskTags,
timing: TaskTiming,
logs: List[LogMessage],
-) -> 'PollGetManifestResult':
+) -> "PollGetManifestResult":
return cls(
manifest=base.manifest,
logs=logs,

@@ -739,21 +731,20 @@ class PollGetManifestResult(GetManifestResult, PollResult):


@dataclass
-@schema_version('poll-remote-freshness-result', 1)
+@schema_version("poll-remote-freshness-result", 1)
class PollFreshnessResult(RemoteFreshnessResult, PollResult):
state: TaskHandlerState = field(
-metadata=restrict_to(TaskHandlerState.Success,
-TaskHandlerState.Failed),
+metadata=restrict_to(TaskHandlerState.Success, TaskHandlerState.Failed),
)

@classmethod
def from_result(
-cls: Type['PollFreshnessResult'],
+cls: Type["PollFreshnessResult"],
base: RemoteFreshnessResult,
tags: TaskTags,
timing: TaskTiming,
logs: List[LogMessage],
-) -> 'PollFreshnessResult':
+) -> "PollFreshnessResult":
return cls(
logs=logs,
tags=tags,
@@ -766,18 +757,19 @@ class PollFreshnessResult(RemoteFreshnessResult, PollResult):
elapsed_time=base.elapsed_time,
)


# Manifest parsing types


class ManifestStatus(StrEnum):
-Init = 'init'
-Compiling = 'compiling'
-Ready = 'ready'
-Error = 'error'
+Init = "init"
+Compiling = "compiling"
+Ready = "ready"
+Error = "error"


@dataclass
-@schema_version('remote-status-result', 1)
+@schema_version("remote-status-result", 1)
class LastParse(RemoteResult):
state: ManifestStatus = ManifestStatus.Init
logs: List[LogMessage] = field(default_factory=list)

@@ -8,7 +8,7 @@ from typing import List, Dict, Any, Union
class SelectorDefinition(dbtClassMixin):
name: str
definition: Union[str, Dict[str, Any]]
-description: str = ''
+description: str = ""


@dataclass

@@ -9,7 +9,7 @@ class PreviousState:
self.path: Path = path
self.manifest: Optional[WritableManifest] = None

-manifest_path = self.path / 'manifest.json'
+manifest_path = self.path / "manifest.json"
if manifest_path.exists() and manifest_path.is_file():
try:
self.manifest = WritableManifest.read(str(manifest_path))
@@ -1,9 +1,7 @@
import dataclasses
import os
from datetime import datetime
-from typing import (
-List, Tuple, ClassVar, Type, TypeVar, Dict, Any, Optional
-)
+from typing import List, Tuple, ClassVar, Type, TypeVar, Dict, Any, Optional

from dbt.clients.system import write_json, read_json
from dbt.exceptions import (

@@ -57,9 +55,7 @@ class Mergeable(Replaceable):

class Writable:
def write(self, path: str):
-write_json(
-path, self.to_dict(omit_none=False)  # type: ignore
-)
+write_json(path, self.to_dict(omit_none=False))  # type: ignore


class AdditionalPropertiesMixin:

@@ -68,6 +64,7 @@ class AdditionalPropertiesMixin:
The underlying class definition must include a type definition for a field
named '_extra' that is of type `Dict[str, Any]`.
"""

ADDITIONAL_PROPERTIES = True

# This takes attributes in the dictionary that are

@@ -86,10 +83,10 @@ class AdditionalPropertiesMixin:
cls_keys = cls._get_field_names()
new_dict = {}
for key, value in data.items():
-if key not in cls_keys and key != '_extra':
-if '_extra' not in new_dict:
-new_dict['_extra'] = {}
-new_dict['_extra'][key] = value
+if key not in cls_keys and key != "_extra":
+if "_extra" not in new_dict:
+new_dict["_extra"] = {}
+new_dict["_extra"][key] = value
else:
new_dict[key] = value
data = new_dict

@@ -99,8 +96,8 @@ class AdditionalPropertiesMixin:
def __post_serialize__(self, dct):
data = super().__post_serialize__(dct)
data.update(self.extra)
-if '_extra' in data:
-del data['_extra']
+if "_extra" in data:
+del data["_extra"]
return data

def replace(self, **kwargs):

@@ -126,8 +123,8 @@ class Readable:
return cls.from_dict(data)  # type: ignore


-BASE_SCHEMAS_URL = 'https://schemas.getdbt.com/'
-SCHEMA_PATH = 'dbt/{name}/v{version}.json'
+BASE_SCHEMAS_URL = "https://schemas.getdbt.com/"
+SCHEMA_PATH = "dbt/{name}/v{version}.json"


@dataclasses.dataclass
@@ -137,24 +134,22 @@ class SchemaVersion:

@property
def path(self) -> str:
-return SCHEMA_PATH.format(
-name=self.name,
-version=self.version
-)
+return SCHEMA_PATH.format(name=self.name, version=self.version)

def __str__(self) -> str:
return BASE_SCHEMAS_URL + self.path


-SCHEMA_VERSION_KEY = 'dbt_schema_version'
+SCHEMA_VERSION_KEY = "dbt_schema_version"


-METADATA_ENV_PREFIX = 'DBT_ENV_CUSTOM_ENV_'
+METADATA_ENV_PREFIX = "DBT_ENV_CUSTOM_ENV_"


def get_metadata_env() -> Dict[str, str]:
return {
-k[len(METADATA_ENV_PREFIX):]: v for k, v in os.environ.items()
+k[len(METADATA_ENV_PREFIX) :]: v
+for k, v in os.environ.items()
if k.startswith(METADATA_ENV_PREFIX)
}

@@ -163,12 +158,8 @@ def get_metadata_env() -> Dict[str, str]:
class BaseArtifactMetadata(dbtClassMixin):
dbt_schema_version: str
dbt_version: str = __version__
-generated_at: datetime = dataclasses.field(
-default_factory=datetime.utcnow
-)
-invocation_id: Optional[str] = dataclasses.field(
-default_factory=get_invocation_id
-)
+generated_at: datetime = dataclasses.field(default_factory=datetime.utcnow)
+invocation_id: Optional[str] = dataclasses.field(default_factory=get_invocation_id)
env: Dict[str, str] = dataclasses.field(default_factory=get_metadata_env)


@@ -179,6 +170,7 @@ def schema_version(name: str, version: int):
version=version,
)
return cls

return inner


@@ -190,11 +182,11 @@ class VersionedSchema(dbtClassMixin):
def json_schema(cls, embeddable: bool = False) -> Dict[str, Any]:
result = super().json_schema(embeddable=embeddable)
if not embeddable:
-result['$id'] = str(cls.dbt_schema_version)
+result["$id"] = str(cls.dbt_schema_version)
return result


-T = TypeVar('T', bound='ArtifactMixin')
+T = TypeVar("T", bound="ArtifactMixin")


# metadata should really be a Generic[T_M] where T_M is a TypeVar bound to

@@ -208,6 +200,4 @@ class ArtifactMixin(VersionedSchema, Writable, Readable):
def validate(cls, data):
super().validate(data)
if cls.dbt_schema_version is None:
-raise InternalException(
-'Cannot call from_dict with no schema version!'
-)
+raise InternalException("Cannot call from_dict with no schema version!")
@@ -1,5 +1,7 @@
from typing import (
-Type, ClassVar, cast,
+Type,
+ClassVar,
+cast,
)
import re
from dataclasses import fields

@@ -11,9 +13,7 @@ from hologram import JsonSchemaMixin, FieldEncoder, ValidationError

# type: ignore
from mashumaro import DataClassDictMixin
-from mashumaro.config import (
-TO_DICT_ADD_OMIT_NONE_FLAG, BaseConfig as MashBaseConfig
-)
+from mashumaro.config import TO_DICT_ADD_OMIT_NONE_FLAG, BaseConfig as MashBaseConfig
from mashumaro.types import SerializableType, SerializationStrategy


@@ -26,9 +26,7 @@ class DateTimeSerialization(SerializationStrategy):
return out

def deserialize(self, value):
-return (
-value if isinstance(value, datetime) else parse(cast(str, value))
-)
+return value if isinstance(value, datetime) else parse(cast(str, value))


# This class pulls in both JsonSchemaMixin from Hologram and

@@ -60,8 +58,8 @@ class dbtClassMixin(DataClassDictMixin, JsonSchemaMixin):
if self._hyphenated:
new_dict = {}
for key in dct:
-if '_' in key:
-new_key = key.replace('_', '-')
+if "_" in key:
+new_key = key.replace("_", "-")
new_dict[new_key] = dct[key]
else:
new_dict[key] = dct[key]

@@ -76,8 +74,8 @@ class dbtClassMixin(DataClassDictMixin, JsonSchemaMixin):
if cls._hyphenated:
new_dict = {}
for key in data:
-if '-' in key:
-new_key = key.replace('-', '_')
+if "-" in key:
+new_key = key.replace("-", "_")
new_dict[new_key] = data[key]
else:
new_dict[key] = data[key]

@@ -89,16 +87,16 @@ class dbtClassMixin(DataClassDictMixin, JsonSchemaMixin):
# hologram and in mashumaro.
def _local_to_dict(self, **kwargs):
args = {}
-if 'omit_none' in kwargs:
-args['omit_none'] = kwargs['omit_none']
+if "omit_none" in kwargs:
+args["omit_none"] = kwargs["omit_none"]
return self.to_dict(**args)


class ValidatedStringMixin(str, SerializableType):
-ValidationRegex = ''
+ValidationRegex = ""

@classmethod
-def _deserialize(cls, value: str) -> 'ValidatedStringMixin':
+def _deserialize(cls, value: str) -> "ValidatedStringMixin":
cls.validate(value)
return ValidatedStringMixin(value)

@@ -14,39 +14,31 @@ class DBTDeprecation:
def name(self) -> str:
if self._name is not None:
return self._name
-raise NotImplementedError(
-'name not implemented for {}'.format(self)
-)
+raise NotImplementedError("name not implemented for {}".format(self))

def track_deprecation_warn(self) -> None:
if dbt.tracking.active_user is not None:
-dbt.tracking.track_deprecation_warn({
-"deprecation_name": self.name
-})
+dbt.tracking.track_deprecation_warn({"deprecation_name": self.name})

@property
def description(self) -> str:
if self._description is not None:
return self._description
-raise NotImplementedError(
-'description not implemented for {}'.format(self)
-)
+raise NotImplementedError("description not implemented for {}".format(self))

def show(self, *args, **kwargs) -> None:
if self.name not in active_deprecations:
desc = self.description.format(**kwargs)
-msg = ui.line_wrap_message(
-desc, prefix='* Deprecation Warning: '
-)
+msg = ui.line_wrap_message(desc, prefix="* Deprecation Warning: ")
dbt.exceptions.warn_or_error(msg)
self.track_deprecation_warn()
active_deprecations.add(self.name)


class MaterializationReturnDeprecation(DBTDeprecation):
-_name = 'materialization-return'
+_name = "materialization-return"

-_description = '''\
+_description = """\
The materialization ("{materialization}") did not explicitly return a list
of relations to add to the cache. By default the target relation will be
added, but this behavior will be removed in a future version of dbt.

@@ -56,22 +48,22 @@ class MaterializationReturnDeprecation(DBTDeprecation):
For more information, see:

https://docs.getdbt.com/v0.15/docs/creating-new-materializations#section-6-returning-relations
-'''
+"""


class NotADictionaryDeprecation(DBTDeprecation):
-_name = 'not-a-dictionary'
+_name = "not-a-dictionary"

-_description = '''\
+_description = """\
The object ("{obj}") was used as a dictionary. In a future version of dbt
this capability will be removed from objects of this type.
-'''
+"""


class ColumnQuotingDeprecation(DBTDeprecation):
-_name = 'column-quoting-unset'
+_name = "column-quoting-unset"

-_description = '''\
+_description = """\
The quote_columns parameter was not set for seeds, so the default value of
False was chosen. The default will change to True in a future release.

@@ -80,13 +72,13 @@ class ColumnQuotingDeprecation(DBTDeprecation):
For more information, see:

https://docs.getdbt.com/v0.15/docs/seeds#section-specify-column-quoting
-'''
+"""


class ModelsKeyNonModelDeprecation(DBTDeprecation):
-_name = 'models-key-mismatch'
+_name = "models-key-mismatch"

-_description = '''\
+_description = """\
"{node.name}" is a {node.resource_type} node, but it is specified in
the {patch.yaml_key} section of {patch.original_file_path}.

@@ -96,25 +88,25 @@ class ModelsKeyNonModelDeprecation(DBTDeprecation):
the {expected_key} key instead.

This warning will become an error in a future release.
-'''
+"""


class ExecuteMacrosReleaseDeprecation(DBTDeprecation):
-_name = 'execute-macro-release'
-_description = '''\
+_name = "execute-macro-release"
+_description = """\
The "release" argument to execute_macro is now ignored, and will be removed
in a future relase of dbt. At that time, providing a `release` argument
will result in an error.
-'''
+"""


class AdapterMacroDeprecation(DBTDeprecation):
-_name = 'adapter-macro'
-_description = '''\
+_name = "adapter-macro"
+_description = """\
The "adapter_macro" macro has been deprecated. Instead, use the
`adapter.dispatch` method to find a macro and call the result.
adapter_macro was called for: {macro_name}
-'''
+"""


_adapter_renamed_description = """\

@@ -128,11 +120,11 @@ Documentation for {new_name} can be found here:


def renamed_method(old_name: str, new_name: str):

class AdapterDeprecationWarning(DBTDeprecation):
-_name = 'adapter:{}'.format(old_name)
-_description = _adapter_renamed_description.format(old_name=old_name,
-new_name=new_name)
+_name = "adapter:{}".format(old_name)
+_description = _adapter_renamed_description.format(
+old_name=old_name, new_name=new_name
+)

dep = AdapterDeprecationWarning()
deprecations_list.append(dep)

@@ -142,9 +134,7 @@ def renamed_method(old_name: str, new_name: str):
def warn(name, *args, **kwargs):
if name not in deprecations:
# this should (hopefully) never happen
-raise RuntimeError(
-"Error showing deprecation warning: {}".format(name)
-)
+raise RuntimeError("Error showing deprecation warning: {}".format(name))

deprecations[name].show(*args, **kwargs)

@@ -163,9 +153,7 @@ deprecations_list: List[DBTDeprecation] = [
AdapterMacroDeprecation(),
]

-deprecations: Dict[str, DBTDeprecation] = {
-d.name: d for d in deprecations_list
-}
+deprecations: Dict[str, DBTDeprecation] = {d.name: d for d in deprecations_list}


def reset_deprecations():
@@ -22,12 +22,12 @@ def downloads_directory():
# the user might have set an environment variable. Set it to that, and do
# not remove it when finished.
if DOWNLOADS_PATH is None:
-DOWNLOADS_PATH = os.getenv('DBT_DOWNLOADS_DIR')
+DOWNLOADS_PATH = os.getenv("DBT_DOWNLOADS_DIR")
remove_downloads = False
# if we are making a per-run temp directory, remove it at the end of
# successful runs
if DOWNLOADS_PATH is None:
-DOWNLOADS_PATH = tempfile.mkdtemp(prefix='dbt-downloads-')
+DOWNLOADS_PATH = tempfile.mkdtemp(prefix="dbt-downloads-")
remove_downloads = True

system.make_directory(DOWNLOADS_PATH)

@@ -62,7 +62,7 @@ class PinnedPackage(BasePackage):
if not version:
return self.name

-return '{}@{}'.format(self.name, version)
+return "{}@{}".format(self.name, version)

@abc.abstractmethod
def get_version(self) -> Optional[str]:

@@ -94,8 +94,8 @@ class PinnedPackage(BasePackage):
return os.path.join(project.modules_path, dest_dirname)


-SomePinned = TypeVar('SomePinned', bound=PinnedPackage)
-SomeUnpinned = TypeVar('SomeUnpinned', bound='UnpinnedPackage')
+SomePinned = TypeVar("SomePinned", bound=PinnedPackage)
+SomeUnpinned = TypeVar("SomeUnpinned", bound="UnpinnedPackage")


class UnpinnedPackage(Generic[SomePinned], BasePackage):
@@ -8,18 +8,16 @@ from dbt.contracts.project import (
     ProjectPackageMetadata,
     GitPackage,
 )
-from dbt.deps.base import PinnedPackage, UnpinnedPackage, get_downloads_path
-from dbt.exceptions import (
-    ExecutableError, warn_or_error, raise_dependency_error
-)
+from dbt.deps import PinnedPackage, UnpinnedPackage, get_downloads_path
+from dbt.exceptions import ExecutableError, warn_or_error, raise_dependency_error
 from dbt.logger import GLOBAL_LOGGER as logger
 from dbt import ui

-PIN_PACKAGE_URL = 'https://docs.getdbt.com/docs/package-management#section-specifying-package-versions'  # noqa
+PIN_PACKAGE_URL = "https://docs.getdbt.com/docs/package-management#section-specifying-package-versions"  # noqa


 def md5sum(s: str):
-    return hashlib.md5(s.encode('latin-1')).hexdigest()
+    return hashlib.md5(s.encode("latin-1")).hexdigest()


 class GitPackageMixin:
@@ -32,13 +30,11 @@ class GitPackageMixin:
         return self.git

     def source_type(self) -> str:
-        return 'git'
+        return "git"


 class GitPinnedPackage(GitPackageMixin, PinnedPackage):
-    def __init__(
-        self, git: str, revision: str, warn_unpinned: bool = True
-    ) -> None:
+    def __init__(self, git: str, revision: str, warn_unpinned: bool = True) -> None:
         super().__init__(git)
         self.revision = revision
         self.warn_unpinned = warn_unpinned
@@ -48,15 +44,15 @@ class GitPinnedPackage(GitPackageMixin, PinnedPackage):
         return self.revision

     def nice_version_name(self):
-        if self.revision == 'HEAD':
-            return 'HEAD (default branch)'
+        if self.revision == "HEAD":
+            return "HEAD (default branch)"
         else:
-            return 'revision {}'.format(self.revision)
+            return "revision {}".format(self.revision)

     def unpinned_msg(self):
-        if self.revision == 'HEAD':
-            return 'not pinned, using HEAD (default branch)'
-        elif self.revision in ('main', 'master'):
+        if self.revision == "HEAD":
+            return "not pinned, using HEAD (default branch)"
+        elif self.revision in ("main", "master"):
             return f'pinned to the "{self.revision}" branch'
         else:
             return None
@@ -68,15 +64,17 @@ class GitPinnedPackage(GitPackageMixin, PinnedPackage):
         the path to the checked out directory."""
         try:
             dir_ = git.clone_and_checkout(
-                self.git, get_downloads_path(), branch=self.revision,
-                dirname=self._checkout_name
+                self.git,
+                get_downloads_path(),
+                branch=self.revision,
+                dirname=self._checkout_name,
             )
         except ExecutableError as exc:
-            if exc.cmd and exc.cmd[0] == 'git':
+            if exc.cmd and exc.cmd[0] == "git":
                 logger.error(
-                    'Make sure git is installed on your machine. More '
-                    'information: '
-                    'https://docs.getdbt.com/docs/package-management'
+                    "Make sure git is installed on your machine. More "
+                    "information: "
+                    "https://docs.getdbt.com/docs/package-management"
                 )
             raise
         return os.path.join(get_downloads_path(), dir_)
@@ -87,9 +85,10 @@ class GitPinnedPackage(GitPackageMixin, PinnedPackage):
         if self.unpinned_msg() and self.warn_unpinned:
             warn_or_error(
                 'The git package "{}" \n\tis {}.\n\tThis can introduce '
-                'breaking changes into your project without warning!\n\nSee {}'
-                .format(self.git, self.unpinned_msg(), PIN_PACKAGE_URL),
-                log_fmt=ui.yellow('WARNING: {}')
+                "breaking changes into your project without warning!\n\nSee {}".format(
+                    self.git, self.unpinned_msg(), PIN_PACKAGE_URL
+                ),
+                log_fmt=ui.yellow("WARNING: {}"),
             )
         loaded = Project.from_project_root(path, renderer)
         return ProjectPackageMetadata.from_project(loaded)
@@ -114,26 +113,21 @@ class GitUnpinnedPackage(GitPackageMixin, UnpinnedPackage[GitPinnedPackage]):
         self.warn_unpinned = warn_unpinned

     @classmethod
-    def from_contract(
-        cls, contract: GitPackage
-    ) -> 'GitUnpinnedPackage':
+    def from_contract(cls, contract: GitPackage) -> "GitUnpinnedPackage":
         revisions = contract.get_revisions()

         # we want to map None -> True
         warn_unpinned = contract.warn_unpinned is not False
-        return cls(git=contract.git, revisions=revisions,
-                   warn_unpinned=warn_unpinned)
+        return cls(git=contract.git, revisions=revisions, warn_unpinned=warn_unpinned)

     def all_names(self) -> List[str]:
-        if self.git.endswith('.git'):
+        if self.git.endswith(".git"):
             other = self.git[:-4]
         else:
-            other = self.git + '.git'
+            other = self.git + ".git"
         return [self.git, other]

-    def incorporate(
-        self, other: 'GitUnpinnedPackage'
-    ) -> 'GitUnpinnedPackage':
+    def incorporate(self, other: "GitUnpinnedPackage") -> "GitUnpinnedPackage":
         warn_unpinned = self.warn_unpinned and other.warn_unpinned

         return GitUnpinnedPackage(
@@ -145,13 +139,13 @@ class GitUnpinnedPackage(GitPackageMixin, UnpinnedPackage[GitPinnedPackage]):
     def resolved(self) -> GitPinnedPackage:
         requested = set(self.revisions)
         if len(requested) == 0:
-            requested = {'HEAD'}
+            requested = {"HEAD"}
         elif len(requested) > 1:
             raise_dependency_error(
-                'git dependencies should contain exactly one version. '
-                '{} contains: {}'.format(self.git, requested))
+                "git dependencies should contain exactly one version. "
+                "{} contains: {}".format(self.git, requested)
+            )

         return GitPinnedPackage(
-            git=self.git, revision=requested.pop(),
-            warn_unpinned=self.warn_unpinned
+            git=self.git, revision=requested.pop(), warn_unpinned=self.warn_unpinned
         )
@@ -1,7 +1,7 @@
 import shutil

 from dbt.clients import system
-from dbt.deps.base import PinnedPackage, UnpinnedPackage
+from dbt.deps import PinnedPackage, UnpinnedPackage
 from dbt.contracts.project import (
     ProjectPackageMetadata,
     LocalPackage,
@@ -19,7 +19,7 @@ class LocalPackageMixin:
         return self.local

     def source_type(self):
-        return 'local'
+        return "local"


 class LocalPinnedPackage(LocalPackageMixin, PinnedPackage):
@@ -30,7 +30,7 @@ class LocalPinnedPackage(LocalPackageMixin, PinnedPackage):
         return None

     def nice_version_name(self):
-        return '<local @ {}>'.format(self.local)
+        return "<local @ {}>".format(self.local)

     def resolve_path(self, project):
         return system.resolve_path_from_base(
@@ -39,9 +39,7 @@ class LocalPinnedPackage(LocalPackageMixin, PinnedPackage):
         )

     def _fetch_metadata(self, project, renderer):
-        loaded = project.from_project_root(
-            self.resolve_path(project), renderer
-        )
+        loaded = project.from_project_root(self.resolve_path(project), renderer)
         return ProjectPackageMetadata.from_project(loaded)

     def install(self, project, renderer):
@@ -57,27 +55,22 @@ class LocalPinnedPackage(LocalPackageMixin, PinnedPackage):
             system.remove_file(dest_path)

         if can_create_symlink:
-            logger.debug(' Creating symlink to local dependency.')
+            logger.debug(" Creating symlink to local dependency.")
             system.make_symlink(src_path, dest_path)

         else:
-            logger.debug(' Symlinks are not available on this '
-                         'OS, copying dependency.')
+            logger.debug(
+                " Symlinks are not available on this " "OS, copying dependency."
+            )
             shutil.copytree(src_path, dest_path)


-class LocalUnpinnedPackage(
-    LocalPackageMixin, UnpinnedPackage[LocalPinnedPackage]
-):
+class LocalUnpinnedPackage(LocalPackageMixin, UnpinnedPackage[LocalPinnedPackage]):
     @classmethod
-    def from_contract(
-        cls, contract: LocalPackage
-    ) -> 'LocalUnpinnedPackage':
+    def from_contract(cls, contract: LocalPackage) -> "LocalUnpinnedPackage":
         return cls(local=contract.local)

-    def incorporate(
-        self, other: 'LocalUnpinnedPackage'
-    ) -> 'LocalUnpinnedPackage':
+    def incorporate(self, other: "LocalUnpinnedPackage") -> "LocalUnpinnedPackage":
         return LocalUnpinnedPackage(local=self.local)

     def resolved(self) -> LocalPinnedPackage:
@@ -7,7 +7,7 @@ from dbt.contracts.project import (
     RegistryPackageMetadata,
     RegistryPackage,
 )
-from dbt.deps.base import PinnedPackage, UnpinnedPackage, get_downloads_path
+from dbt.deps import PinnedPackage, UnpinnedPackage, get_downloads_path
 from dbt.exceptions import (
     package_version_not_found,
     VersionsNotCompatibleException,
@@ -26,7 +26,7 @@ class RegistryPackageMixin:
         return self.package

     def source_type(self) -> str:
-        return 'hub'
+        return "hub"


 class RegistryPinnedPackage(RegistryPackageMixin, PinnedPackage):
@@ -39,13 +39,13 @@ class RegistryPinnedPackage(RegistryPackageMixin, PinnedPackage):
         return self.package

     def source_type(self):
-        return 'hub'
+        return "hub"

     def get_version(self):
         return self.version

     def nice_version_name(self):
-        return 'version {}'.format(self.version)
+        return "version {}".format(self.version)

     def _fetch_metadata(self, project, renderer) -> RegistryPackageMetadata:
         dct = registry.package_version(self.package, self.version)
@@ -54,10 +54,8 @@ class RegistryPinnedPackage(RegistryPackageMixin, PinnedPackage):
     def install(self, project, renderer):
         metadata = self.fetch_metadata(project, renderer)

-        tar_name = '{}.{}.tar.gz'.format(self.package, self.version)
-        tar_path = os.path.realpath(
-            os.path.join(get_downloads_path(), tar_name)
-        )
+        tar_name = "{}.{}.tar.gz".format(self.package, self.version)
+        tar_path = os.path.realpath(os.path.join(get_downloads_path(), tar_name))
         system.make_directory(os.path.dirname(tar_path))

         download_url = metadata.downloads.tarball
@@ -70,9 +68,7 @@ class RegistryPinnedPackage(RegistryPackageMixin, PinnedPackage):
 class RegistryUnpinnedPackage(
     RegistryPackageMixin, UnpinnedPackage[RegistryPinnedPackage]
 ):
-    def __init__(
-        self, package: str, versions: List[semver.VersionSpecifier]
-    ) -> None:
+    def __init__(self, package: str, versions: List[semver.VersionSpecifier]) -> None:
         super().__init__(package)
         self.versions = versions

@@ -82,20 +78,15 @@ class RegistryUnpinnedPackage(
             package_not_found(self.package)

     @classmethod
-    def from_contract(
-        cls, contract: RegistryPackage
-    ) -> 'RegistryUnpinnedPackage':
+    def from_contract(cls, contract: RegistryPackage) -> "RegistryUnpinnedPackage":
         raw_version = contract.get_versions()

-        versions = [
-            semver.VersionSpecifier.from_version_string(v)
-            for v in raw_version
-        ]
+        versions = [semver.VersionSpecifier.from_version_string(v) for v in raw_version]
         return cls(package=contract.package, versions=versions)

     def incorporate(
-        self, other: 'RegistryUnpinnedPackage'
-    ) -> 'RegistryUnpinnedPackage':
+        self, other: "RegistryUnpinnedPackage"
+    ) -> "RegistryUnpinnedPackage":
         return RegistryUnpinnedPackage(
             package=self.package,
             versions=self.versions + other.versions,
@@ -106,8 +97,7 @@ class RegistryUnpinnedPackage(
         try:
             range_ = semver.reduce_versions(*self.versions)
         except VersionsNotCompatibleException as e:
-            new_msg = ('Version error for package {}: {}'
-                       .format(self.name, e))
+            new_msg = "Version error for package {}: {}".format(self.name, e)
             raise DependencyException(new_msg) from e

         available = registry.get_available_versions(self.package)
@@ -6,7 +6,7 @@ from dbt.exceptions import raise_dependency_error, InternalException
 from dbt.context.target import generate_target_context
 from dbt.config import Project, RuntimeConfig
 from dbt.config.renderer import DbtProjectYamlRenderer
-from dbt.deps.base import BasePackage, PinnedPackage, UnpinnedPackage
+from dbt.deps import BasePackage, PinnedPackage, UnpinnedPackage
 from dbt.deps.local import LocalUnpinnedPackage
 from dbt.deps.git import GitUnpinnedPackage
 from dbt.deps.registry import RegistryUnpinnedPackage
@@ -49,12 +49,10 @@ class PackageListing:
         key_str: str = self._pick_key(key)
         self.packages[key_str] = value

-    def _mismatched_types(
-        self, old: UnpinnedPackage, new: UnpinnedPackage
-    ) -> NoReturn:
+    def _mismatched_types(self, old: UnpinnedPackage, new: UnpinnedPackage) -> NoReturn:
         raise_dependency_error(
-            f'Cannot incorporate {new} ({new.__class__.__name__}) in {old} '
-            f'({old.__class__.__name__}): mismatched types'
+            f"Cannot incorporate {new} ({new.__class__.__name__}) in {old} "
+            f"({old.__class__.__name__}): mismatched types"
         )

     def incorporate(self, package: UnpinnedPackage):
@@ -78,14 +76,14 @@ class PackageListing:
                 pkg = RegistryUnpinnedPackage.from_contract(contract)
             else:
                 raise InternalException(
-                    'Invalid package type {}'.format(type(contract))
+                    "Invalid package type {}".format(type(contract))
                 )
             self.incorporate(pkg)

     @classmethod
     def from_contracts(
-        cls: Type['PackageListing'], src: List[PackageContract]
-    ) -> 'PackageListing':
+        cls: Type["PackageListing"], src: List[PackageContract]
+    ) -> "PackageListing":
         self = cls({})
         self.update_from(src)
         return self
@@ -108,14 +106,14 @@ def _check_for_duplicate_project_names(
         if project_name in seen:
             raise_dependency_error(
                 f'Found duplicate project "{project_name}". This occurs when '
-                'a dependency has the same project name as some other '
-                'dependency.'
+                "a dependency has the same project name as some other "
+                "dependency."
             )
         elif project_name == config.project_name:
             raise_dependency_error(
-                'Found a dependency with the same name as the root project '
+                "Found a dependency with the same name as the root project "
                 f'"{project_name}". Package names must be unique in a project.'
-                ' Please rename one of these packages.'
+                " Please rename one of these packages."
             )
         seen.add(project_name)

File diff suppressed because it is too large
@@ -1,6 +1,7 @@
 import os
 import multiprocessing
-if os.name != 'nt':
+
+if os.name != "nt":
     # https://bugs.python.org/issue41567
     import multiprocessing.popen_spawn_posix  # type: ignore
 from pathlib import Path
@@ -23,7 +24,7 @@ def env_set_truthy(key: str) -> Optional[str]:
     otherwise.
     """
     value = os.getenv(key)
-    if not value or value.lower() in ('0', 'false', 'f'):
+    if not value or value.lower() in ("0", "false", "f"):
         return None
     return value

@@ -36,24 +37,23 @@ def env_set_path(key: str) -> Optional[Path]:
     return Path(value)


-SINGLE_THREADED_WEBSERVER = env_set_truthy('DBT_SINGLE_THREADED_WEBSERVER')
-SINGLE_THREADED_HANDLER = env_set_truthy('DBT_SINGLE_THREADED_HANDLER')
-MACRO_DEBUGGING = env_set_truthy('DBT_MACRO_DEBUGGING')
-DEFER_MODE = env_set_truthy('DBT_DEFER_TO_STATE')
-ARTIFACT_STATE_PATH = env_set_path('DBT_ARTIFACT_STATE_PATH')
+SINGLE_THREADED_WEBSERVER = env_set_truthy("DBT_SINGLE_THREADED_WEBSERVER")
+SINGLE_THREADED_HANDLER = env_set_truthy("DBT_SINGLE_THREADED_HANDLER")
+MACRO_DEBUGGING = env_set_truthy("DBT_MACRO_DEBUGGING")
+DEFER_MODE = env_set_truthy("DBT_DEFER_TO_STATE")
+ARTIFACT_STATE_PATH = env_set_path("DBT_ARTIFACT_STATE_PATH")


 def _get_context():
     # TODO: change this back to use fork() on linux when we have made that safe
-    return multiprocessing.get_context('spawn')
+    return multiprocessing.get_context("spawn")


 MP_CONTEXT = _get_context()


 def reset():
-    global STRICT_MODE, FULL_REFRESH, USE_CACHE, WARN_ERROR, TEST_NEW_PARSER, \
-        WRITE_JSON, PARTIAL_PARSE, MP_CONTEXT, USE_COLORS
+    global STRICT_MODE, FULL_REFRESH, USE_CACHE, WARN_ERROR, TEST_NEW_PARSER, WRITE_JSON, PARTIAL_PARSE, MP_CONTEXT, USE_COLORS

     STRICT_MODE = False
     FULL_REFRESH = False
@@ -67,26 +67,22 @@ def reset():


 def set_from_args(args):
-    global STRICT_MODE, FULL_REFRESH, USE_CACHE, WARN_ERROR, TEST_NEW_PARSER, \
-        WRITE_JSON, PARTIAL_PARSE, MP_CONTEXT, USE_COLORS
+    global STRICT_MODE, FULL_REFRESH, USE_CACHE, WARN_ERROR, TEST_NEW_PARSER, WRITE_JSON, PARTIAL_PARSE, MP_CONTEXT, USE_COLORS

-    USE_CACHE = getattr(args, 'use_cache', USE_CACHE)
+    USE_CACHE = getattr(args, "use_cache", USE_CACHE)

-    FULL_REFRESH = getattr(args, 'full_refresh', FULL_REFRESH)
-    STRICT_MODE = getattr(args, 'strict', STRICT_MODE)
-    WARN_ERROR = (
-        STRICT_MODE or
-        getattr(args, 'warn_error', STRICT_MODE or WARN_ERROR)
-    )
+    FULL_REFRESH = getattr(args, "full_refresh", FULL_REFRESH)
+    STRICT_MODE = getattr(args, "strict", STRICT_MODE)
+    WARN_ERROR = STRICT_MODE or getattr(args, "warn_error", STRICT_MODE or WARN_ERROR)

-    TEST_NEW_PARSER = getattr(args, 'test_new_parser', TEST_NEW_PARSER)
-    WRITE_JSON = getattr(args, 'write_json', WRITE_JSON)
-    PARTIAL_PARSE = getattr(args, 'partial_parse', None)
+    TEST_NEW_PARSER = getattr(args, "test_new_parser", TEST_NEW_PARSER)
+    WRITE_JSON = getattr(args, "write_json", WRITE_JSON)
+    PARTIAL_PARSE = getattr(args, "partial_parse", None)
     MP_CONTEXT = _get_context()

     # The use_colors attribute will always have a value because it is assigned
     # None by default from the add_mutually_exclusive_group function
-    use_colors_override = getattr(args, 'use_colors')
+    use_colors_override = getattr(args, "use_colors")

     if use_colors_override is not None:
         USE_COLORS = use_colors_override
@@ -2,9 +2,7 @@
 import itertools
 from dbt.clients.yaml_helper import yaml, Loader, Dumper  # noqa: F401

-from typing import (
-    Dict, List, Optional, Tuple, Any, Union
-)
+from typing import Dict, List, Optional, Tuple, Any, Union

 from dbt.contracts.selection import SelectorDefinition, SelectorFile
 from dbt.exceptions import InternalException, ValidationException
@@ -17,21 +15,17 @@ from .selector_spec import (
     SelectionCriteria,
 )

-INTERSECTION_DELIMITER = ','
+INTERSECTION_DELIMITER = ","

-DEFAULT_INCLUDES: List[str] = ['fqn:*', 'source:*', 'exposure:*']
+DEFAULT_INCLUDES: List[str] = ["fqn:*", "source:*", "exposure:*"]
 DEFAULT_EXCLUDES: List[str] = []
-DATA_TEST_SELECTOR: str = 'test_type:data'
-SCHEMA_TEST_SELECTOR: str = 'test_type:schema'
+DATA_TEST_SELECTOR: str = "test_type:data"
+SCHEMA_TEST_SELECTOR: str = "test_type:schema"


-def parse_union(
-    components: List[str], expect_exists: bool
-) -> SelectionUnion:
+def parse_union(components: List[str], expect_exists: bool) -> SelectionUnion:
     # turn ['a b', 'c'] -> ['a', 'b', 'c']
-    raw_specs = itertools.chain.from_iterable(
-        r.split(' ') for r in components
-    )
+    raw_specs = itertools.chain.from_iterable(r.split(" ") for r in components)
     union_components: List[SelectionSpec] = []

     # ['a', 'b', 'c,d'] -> union('a', 'b', intersection('c', 'd'))
@@ -40,11 +34,13 @@ def parse_union(
             SelectionCriteria.from_single_spec(part)
             for part in raw_spec.split(INTERSECTION_DELIMITER)
         ]
-        union_components.append(SelectionIntersection(
-            components=intersection_components,
-            expect_exists=expect_exists,
-            raw=raw_spec,
-        ))
+        union_components.append(
+            SelectionIntersection(
+                components=intersection_components,
+                expect_exists=expect_exists,
+                raw=raw_spec,
+            )
+        )

     return SelectionUnion(
         components=union_components,
@@ -78,9 +74,7 @@ def parse_test_selectors(
     union_components = []

     if data:
-        union_components.append(
-            SelectionCriteria.from_single_spec(DATA_TEST_SELECTOR)
-        )
+        union_components.append(SelectionCriteria.from_single_spec(DATA_TEST_SELECTOR))
     if schema:
         union_components.append(
             SelectionCriteria.from_single_spec(SCHEMA_TEST_SELECTOR)
@@ -98,27 +92,21 @@ def parse_test_selectors(
         raw=[DATA_TEST_SELECTOR, SCHEMA_TEST_SELECTOR],
     )

-    return SelectionIntersection(
-        components=[base, intersect_with], expect_exists=True
-    )
+    return SelectionIntersection(components=[base, intersect_with], expect_exists=True)


 RawDefinition = Union[str, Dict[str, Any]]


-def _get_list_dicts(
-    dct: Dict[str, Any], key: str
-) -> List[RawDefinition]:
+def _get_list_dicts(dct: Dict[str, Any], key: str) -> List[RawDefinition]:
     result: List[RawDefinition] = []
     if key not in dct:
         raise InternalException(
-            f'Expected to find key {key} in dict, only found {list(dct)}'
+            f"Expected to find key {key} in dict, only found {list(dct)}"
         )
     values = dct[key]
     if not isinstance(values, list):
-        raise ValidationException(
-            f'Invalid value for key "{key}". Expected a list.'
-        )
+        raise ValidationException(f'Invalid value for key "{key}". Expected a list.')
     for value in values:
         if isinstance(value, dict):
             for value_key in value:
@@ -133,36 +121,31 @@ def _get_list_dicts(
         else:
             raise ValidationException(
                 f'Invalid value type {type(value)} in key "{key}", expected '
-                f'dict or str (value: {value}).'
+                f"dict or str (value: {value})."
             )

     return result


 def _parse_exclusions(definition) -> Optional[SelectionSpec]:
-    exclusions = _get_list_dicts(definition, 'exclude')
-    parsed_exclusions = [
-        parse_from_definition(excl) for excl in exclusions
-    ]
+    exclusions = _get_list_dicts(definition, "exclude")
+    parsed_exclusions = [parse_from_definition(excl) for excl in exclusions]
     if len(parsed_exclusions) == 1:
         return parsed_exclusions[0]
     elif len(parsed_exclusions) > 1:
-        return SelectionUnion(
-            components=parsed_exclusions,
-            raw=exclusions
-        )
+        return SelectionUnion(components=parsed_exclusions, raw=exclusions)
     else:
         return None


 def _parse_include_exclude_subdefs(
-    definitions: List[RawDefinition]
+    definitions: List[RawDefinition],
 ) -> Tuple[List[SelectionSpec], Optional[SelectionSpec]]:
     include_parts: List[SelectionSpec] = []
     diff_arg: Optional[SelectionSpec] = None

     for definition in definitions:
-        if isinstance(definition, dict) and 'exclude' in definition:
+        if isinstance(definition, dict) and "exclude" in definition:
             # do not allow multiple exclude: defs at the same level
             if diff_arg is not None:
                 yaml_sel_cfg = yaml.dump(definition)
@@ -178,7 +161,7 @@ def _parse_include_exclude_subdefs(


 def parse_union_definition(definition: Dict[str, Any]) -> SelectionSpec:
-    union_def_parts = _get_list_dicts(definition, 'union')
+    union_def_parts = _get_list_dicts(definition, "union")
     include, exclude = _parse_include_exclude_subdefs(union_def_parts)

     union = SelectionUnion(components=include)
@@ -187,16 +170,11 @@ def parse_union_definition(definition: Dict[str, Any]) -> SelectionSpec:
         union.raw = definition
         return union
     else:
-        return SelectionDifference(
-            components=[union, exclude],
-            raw=definition
-        )
+        return SelectionDifference(components=[union, exclude], raw=definition)


-def parse_intersection_definition(
-    definition: Dict[str, Any]
-) -> SelectionSpec:
-    intersection_def_parts = _get_list_dicts(definition, 'intersection')
+def parse_intersection_definition(definition: Dict[str, Any]) -> SelectionSpec:
+    intersection_def_parts = _get_list_dicts(definition, "intersection")
     include, exclude = _parse_include_exclude_subdefs(intersection_def_parts)
     intersection = SelectionIntersection(components=include)

@@ -204,10 +182,7 @@ def parse_intersection_definition(
         intersection.raw = definition
         return intersection
     else:
-        return SelectionDifference(
-            components=[intersection, exclude],
-            raw=definition
-        )
+        return SelectionDifference(components=[intersection, exclude], raw=definition)


 def parse_dict_definition(definition: Dict[str, Any]) -> SelectionSpec:
@@ -221,14 +196,14 @@ def parse_dict_definition(definition: Dict[str, Any]) -> SelectionSpec:
             f'"{type(key)}" ({key})'
         )
         dct = {
-            'method': key,
-            'value': value,
+            "method": key,
+            "value": value,
         }
-    elif 'method' in definition and 'value' in definition:
+    elif "method" in definition and "value" in definition:
         dct = definition
-        if 'exclude' in definition:
+        if "exclude" in definition:
             diff_arg = _parse_exclusions(definition)
-            dct = {k: v for k, v in dct.items() if k != 'exclude'}
+            dct = {k: v for k, v in dct.items() if k != "exclude"}
     else:
         raise ValidationException(
             f'Expected either 1 key or else "method" '
@@ -243,13 +218,14 @@ def parse_dict_definition(definition: Dict[str, Any]) -> SelectionSpec:
     return SelectionDifference(components=[base, diff_arg])


-def parse_from_definition(
-    definition: RawDefinition, rootlevel=False
-) -> SelectionSpec:
+def parse_from_definition(definition: RawDefinition, rootlevel=False) -> SelectionSpec:

-    if (isinstance(definition, dict) and
-            ('union' in definition or 'intersection' in definition) and
-            rootlevel and len(definition) > 1):
+    if (
+        isinstance(definition, dict)
+        and ("union" in definition or "intersection" in definition)
+        and rootlevel
+        and len(definition) > 1
+    ):
         keys = ",".join(definition.keys())
         raise ValidationException(
             f"Only a single 'union' or 'intersection' key is allowed "
@@ -257,25 +233,24 @@ def parse_from_definition(
         )
     if isinstance(definition, str):
         return SelectionCriteria.from_single_spec(definition)
-    elif 'union' in definition:
+    elif "union" in definition:
         return parse_union_definition(definition)
-    elif 'intersection' in definition:
+    elif "intersection" in definition:
         return parse_intersection_definition(definition)
     elif isinstance(definition, dict):
         return parse_dict_definition(definition)
     else:
         raise ValidationException(
-            f'Expected to find union, intersection, str or dict, instead '
-            f'found {type(definition)}: {definition}'
+            f"Expected to find union, intersection, str or dict, instead "
+            f"found {type(definition)}: {definition}"
         )


-def parse_from_selectors_definition(
-    source: SelectorFile
-) -> Dict[str, SelectionSpec]:
+def parse_from_selectors_definition(source: SelectorFile) -> Dict[str, SelectionSpec]:
     result: Dict[str, SelectionSpec] = {}
     selector: SelectorDefinition
     for selector in source.selectors:
-        result[selector.name] = parse_from_definition(selector.definition,
-                                                      rootlevel=True)
+        result[selector.name] = parse_from_definition(
+            selector.definition, rootlevel=True
+        )
     return result
@@ -1,17 +1,16 @@
-from typing import (
-    Set, Iterable, Iterator, Optional, NewType
-)
+from typing import Set, Iterable, Iterator, Optional, NewType
 import networkx as nx  # type: ignore

 from dbt.exceptions import InternalException

-UniqueId = NewType('UniqueId', str)
+UniqueId = NewType("UniqueId", str)


 class Graph:
     """A wrapper around the networkx graph that understands SelectionCriteria
     and how they interact with the graph.
     """

     def __init__(self, graph):
         self.graph = graph

@@ -29,12 +28,11 @@ class Graph:
     ) -> Set[UniqueId]:
         """Returns all nodes having a path to `node` in `graph`"""
         if not self.graph.has_node(node):
-            raise InternalException(f'Node {node} not found in the graph!')
+            raise InternalException(f"Node {node} not found in the graph!")
         with nx.utils.reversed(self.graph):
-            anc = nx.single_source_shortest_path_length(G=self.graph,
-                                                        source=node,
-                                                        cutoff=max_depth)\
-                .keys()
+            anc = nx.single_source_shortest_path_length(
+                G=self.graph, source=node, cutoff=max_depth
+            ).keys()
         return anc - {node}

     def descendants(
@@ -42,16 +40,13 @@ class Graph:
     ) -> Set[UniqueId]:
         """Returns all nodes reachable from `node` in `graph`"""
         if not self.graph.has_node(node):
-            raise InternalException(f'Node {node} not found in the graph!')
-        des = nx.single_source_shortest_path_length(G=self.graph,
-                                                    source=node,
-                                                    cutoff=max_depth)\
-            .keys()
+            raise InternalException(f"Node {node} not found in the graph!")
+        des = nx.single_source_shortest_path_length(
+            G=self.graph, source=node, cutoff=max_depth
+        ).keys()
         return des - {node}

-    def select_childrens_parents(
-        self, selected: Set[UniqueId]
-    ) -> Set[UniqueId]:
+    def select_childrens_parents(self, selected: Set[UniqueId]) -> Set[UniqueId]:
         ancestors_for = self.select_children(selected) | selected
         return self.select_parents(ancestors_for) | ancestors_for

@@ -77,7 +72,7 @@ class Graph:
         successors.update(self.graph.successors(node))
         return successors

-    def get_subset_graph(self, selected: Iterable[UniqueId]) -> 'Graph':
+    def get_subset_graph(self, selected: Iterable[UniqueId]) -> "Graph":
         """Create and return a new graph that is a shallow copy of the graph,
         but with only the nodes in include_nodes. Transitive edges across
         removed nodes are preserved as explicit new edges.
@@ -98,7 +93,7 @@ class Graph:
         )
         return Graph(new_graph)

-    def subgraph(self, nodes: Iterable[UniqueId]) -> 'Graph':
+    def subgraph(self, nodes: Iterable[UniqueId]) -> "Graph":
         return Graph(self.graph.subgraph(nodes))

     def get_dependent_nodes(self, node: UniqueId):
@@ -1,8 +1,6 @@
 import threading
 from queue import PriorityQueue
-from typing import (
-    Dict, Set, Optional
-)
+from typing import Dict, Set, Optional

 import networkx as nx  # type: ignore

@@ -21,9 +19,8 @@ class GraphQueue:
     that separate threads do not call `.empty()` or `__len__()` and `.get()` at
     the same time, as there is an unlocked race!
     """
-    def __init__(
-        self, graph: nx.DiGraph, manifest: Manifest, selected: Set[UniqueId]
-    ):
+
+    def __init__(self, graph: nx.DiGraph, manifest: Manifest, selected: Set[UniqueId]):
         self.graph = graph
         self.manifest = manifest
         self._selected = selected
@@ -75,10 +72,13 @@ class GraphQueue:
         """
         scores = {}
         for node in self.graph.nodes():
-            score = -1 * len([
-                d for d in nx.descendants(self.graph, node)
-                if self._include_in_cost(d)
-            ])
+            score = -1 * len(
+                [
+                    d
+                    for d in nx.descendants(self.graph, node)
+                    if self._include_in_cost(d)
+                ]
+            )
             scores[node] = score
         return scores

@@ -1,4 +1,3 @@
-
 from typing import Set, List, Optional

 from .graph import Graph, UniqueId
@@ -25,14 +24,13 @@ def get_package_names(nodes):
 def alert_non_existence(raw_spec, nodes):
     if len(nodes) == 0:
         warn_or_error(
-            f"The selection criterion '{str(raw_spec)}' does not match"
-            f" any nodes"
+            f"The selection criterion '{str(raw_spec)}' does not match" f" any nodes"
         )


 class NodeSelector(MethodManager):
-    """The node selector is aware of the graph and manifest,
-    """
+    """The node selector is aware of the graph and manifest,"""
+
     def __init__(
         self,
         graph: Graph,
@@ -45,13 +43,16 @@ class NodeSelector(MethodManager):
         # build a subgraph containing only non-empty, enabled nodes and enabled
         # sources.
         graph_members = {
-            unique_id for unique_id in self.full_graph.nodes()
+            unique_id
+            for unique_id in self.full_graph.nodes()
             if self._is_graph_member(unique_id)
         }
         self.graph = self.full_graph.subgraph(graph_members)

     def select_included(
-        self, included_nodes: Set[UniqueId], spec: SelectionCriteria,
+        self,
+        included_nodes: Set[UniqueId],
+        spec: SelectionCriteria,
     ) -> Set[UniqueId]:
         """Select the explicitly included nodes, using the given spec. Return
         the selected set of unique IDs.
@@ -116,10 +117,7 @@ class NodeSelector(MethodManager):
         if isinstance(spec, SelectionCriteria):
             result = self.get_nodes_from_criteria(spec)
         else:
-            node_selections = [
-                self.select_nodes(component)
-                for component in spec
-            ]
+            node_selections = [self.select_nodes(component) for component in spec]
             result = spec.combined(node_selections)
         if spec.expect_exists:
             alert_non_existence(spec.raw, result)
@@ -149,18 +147,14 @@ class NodeSelector(MethodManager):
         elif unique_id in self.manifest.exposures:
             node = self.manifest.exposures[unique_id]
         else:
-            raise InternalException(
-                f'Node {unique_id} not found in the manifest!'
-            )
+            raise InternalException(f"Node {unique_id} not found in the manifest!")
         return self.node_is_match(node)

     def filter_selection(self, selected: Set[UniqueId]) -> Set[UniqueId]:
         """Return the subset of selected nodes that is a match for this
         selector.
         """
-        return {
-            unique_id for unique_id in selected if self._is_match(unique_id)
-        }
+        return {unique_id for unique_id in selected if self._is_match(unique_id)}

     def expand_selection(self, selected: Set[UniqueId]) -> Set[UniqueId]:
         """Perform selector-specific expansion."""
@@ -31,28 +31,28 @@ from dbt.node_types import NodeType
|
|||||||
from dbt.ui import warning_tag
|
from dbt.ui import warning_tag
|
||||||
|
|
||||||
|
|
||||||
SELECTOR_GLOB = '*'
|
SELECTOR_GLOB = "*"
|
||||||
SELECTOR_DELIMITER = ':'
|
SELECTOR_DELIMITER = ":"
|
||||||
|
|
||||||
|
|
||||||
class MethodName(StrEnum):
|
class MethodName(StrEnum):
|
||||||
FQN = 'fqn'
|
FQN = "fqn"
|
||||||
Tag = 'tag'
|
Tag = "tag"
|
||||||
Source = 'source'
|
Source = "source"
|
||||||
Path = 'path'
|
Path = "path"
|
||||||
Package = 'package'
|
Package = "package"
|
||||||
Config = 'config'
|
Config = "config"
|
||||||
TestName = 'test_name'
|
TestName = "test_name"
|
||||||
TestType = 'test_type'
|
TestType = "test_type"
|
||||||
ResourceType = 'resource_type'
|
ResourceType = "resource_type"
|
||||||
State = 'state'
|
State = "state"
|
||||||
Exposure = 'exposure'
|
Exposure = "exposure"
|
||||||
|
|
||||||
|
|
||||||
def is_selected_node(real_node, node_selector):
|
def is_selected_node(real_node, node_selector):
|
||||||
for i, selector_part in enumerate(node_selector):
|
for i, selector_part in enumerate(node_selector):
|
||||||
|
|
||||||
is_last = (i == len(node_selector) - 1)
|
is_last = i == len(node_selector) - 1
|
||||||
|
|
||||||
# if we hit a GLOB, then this node is selected
|
# if we hit a GLOB, then this node is selected
|
||||||
if selector_part == SELECTOR_GLOB:
|
if selector_part == SELECTOR_GLOB:
|
||||||
@@ -83,15 +83,14 @@ class SelectorMethod(metaclass=abc.ABCMeta):
|
|||||||
self,
|
self,
|
||||||
manifest: Manifest,
|
manifest: Manifest,
|
||||||
previous_state: Optional[PreviousState],
|
previous_state: Optional[PreviousState],
|
||||||
arguments: List[str]
|
arguments: List[str],
|
||||||
):
|
):
|
||||||
self.manifest: Manifest = manifest
|
self.manifest: Manifest = manifest
|
||||||
self.previous_state = previous_state
|
self.previous_state = previous_state
|
||||||
self.arguments: List[str] = arguments
|
self.arguments: List[str] = arguments
|
||||||
|
|
||||||
def parsed_nodes(
|
def parsed_nodes(
|
||||||
self,
|
self, included_nodes: Set[UniqueId]
|
||||||
included_nodes: Set[UniqueId]
|
|
||||||
) -> Iterator[Tuple[UniqueId, ManifestNode]]:
|
) -> Iterator[Tuple[UniqueId, ManifestNode]]:
|
||||||
|
|
||||||
for key, node in self.manifest.nodes.items():
|
for key, node in self.manifest.nodes.items():
|
||||||
@@ -101,8 +100,7 @@ class SelectorMethod(metaclass=abc.ABCMeta):
|
|||||||
yield unique_id, node
|
yield unique_id, node
|
||||||
|
|
||||||
def source_nodes(
|
def source_nodes(
|
||||||
self,
|
self, included_nodes: Set[UniqueId]
|
||||||
included_nodes: Set[UniqueId]
|
|
||||||
) -> Iterator[Tuple[UniqueId, ParsedSourceDefinition]]:
|
) -> Iterator[Tuple[UniqueId, ParsedSourceDefinition]]:
|
||||||
|
|
||||||
for key, source in self.manifest.sources.items():
|
for key, source in self.manifest.sources.items():
|
||||||
@@ -112,8 +110,7 @@ class SelectorMethod(metaclass=abc.ABCMeta):
|
|||||||
yield unique_id, source
|
yield unique_id, source
|
||||||
|
|
||||||
def exposure_nodes(
|
def exposure_nodes(
|
||||||
self,
|
self, included_nodes: Set[UniqueId]
|
||||||
included_nodes: Set[UniqueId]
|
|
||||||
) -> Iterator[Tuple[UniqueId, ParsedExposure]]:
|
) -> Iterator[Tuple[UniqueId, ParsedExposure]]:
|
||||||
|
|
||||||
for key, exposure in self.manifest.exposures.items():
|
for key, exposure in self.manifest.exposures.items():
|
||||||
@@ -123,26 +120,28 @@ class SelectorMethod(metaclass=abc.ABCMeta):
|
|||||||
yield unique_id, exposure
|
yield unique_id, exposure
|
||||||
|
|
||||||
def all_nodes(
|
def all_nodes(
|
||||||
self,
|
self, included_nodes: Set[UniqueId]
|
||||||
included_nodes: Set[UniqueId]
|
|
||||||
) -> Iterator[Tuple[UniqueId, SelectorTarget]]:
|
) -> Iterator[Tuple[UniqueId, SelectorTarget]]:
|
||||||
yield from chain(self.parsed_nodes(included_nodes),
|
yield from chain(
|
||||||
|
self.parsed_nodes(included_nodes),
|
||||||
self.source_nodes(included_nodes),
|
self.source_nodes(included_nodes),
|
||||||
self.exposure_nodes(included_nodes))
|
self.exposure_nodes(included_nodes),
|
||||||
|
)
|
||||||
|
|
||||||
def configurable_nodes(
|
def configurable_nodes(
|
||||||
self,
|
self, included_nodes: Set[UniqueId]
|
||||||
included_nodes: Set[UniqueId]
|
|
||||||
) -> Iterator[Tuple[UniqueId, CompileResultNode]]:
|
) -> Iterator[Tuple[UniqueId, CompileResultNode]]:
|
||||||
yield from chain(self.parsed_nodes(included_nodes),
|
yield from chain(
|
||||||
self.source_nodes(included_nodes))
|
self.parsed_nodes(included_nodes), self.source_nodes(included_nodes)
|
||||||
|
)
|
||||||
|
|
||||||
def non_source_nodes(
|
def non_source_nodes(
|
||||||
self,
|
self,
|
||||||
included_nodes: Set[UniqueId],
|
included_nodes: Set[UniqueId],
|
||||||
) -> Iterator[Tuple[UniqueId, Union[ParsedExposure, ManifestNode]]]:
|
) -> Iterator[Tuple[UniqueId, Union[ParsedExposure, ManifestNode]]]:
|
||||||
yield from chain(self.parsed_nodes(included_nodes),
|
yield from chain(
|
||||||
self.exposure_nodes(included_nodes))
|
self.parsed_nodes(included_nodes), self.exposure_nodes(included_nodes)
|
||||||
|
)
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def search(
|
def search(
|
||||||
@@ -150,7 +149,7 @@ class SelectorMethod(metaclass=abc.ABCMeta):
|
|||||||
included_nodes: Set[UniqueId],
|
included_nodes: Set[UniqueId],
|
||||||
selector: str,
|
selector: str,
|
||||||
) -> Iterator[UniqueId]:
|
) -> Iterator[UniqueId]:
|
||||||
raise NotImplementedError('subclasses should implement this')
|
raise NotImplementedError("subclasses should implement this")
|
||||||
|
|
||||||
|
|
||||||
class QualifiedNameSelectorMethod(SelectorMethod):
|
class QualifiedNameSelectorMethod(SelectorMethod):
|
||||||
@@ -216,7 +215,7 @@ class SourceSelectorMethod(SelectorMethod):
|
|||||||
self, included_nodes: Set[UniqueId], selector: str
|
self, included_nodes: Set[UniqueId], selector: str
|
||||||
) -> Iterator[UniqueId]:
|
) -> Iterator[UniqueId]:
|
||||||
"""yields nodes from included are the specified source."""
|
"""yields nodes from included are the specified source."""
|
||||||
parts = selector.split('.')
|
parts = selector.split(".")
|
||||||
target_package = SELECTOR_GLOB
|
target_package = SELECTOR_GLOB
|
||||||
if len(parts) == 1:
|
if len(parts) == 1:
|
||||||
target_source, target_table = parts[0], None
|
target_source, target_table = parts[0], None
|
||||||
@@ -227,9 +226,9 @@ class SourceSelectorMethod(SelectorMethod):
|
|||||||
else: # len(parts) > 3 or len(parts) == 0
|
else: # len(parts) > 3 or len(parts) == 0
|
||||||
msg = (
|
msg = (
|
||||||
'Invalid source selector value "{}". Sources must be of the '
|
'Invalid source selector value "{}". Sources must be of the '
|
||||||
'form `${{source_name}}`, '
|
"form `${{source_name}}`, "
|
||||||
'`${{source_name}}.${{target_name}}`, or '
|
"`${{source_name}}.${{target_name}}`, or "
|
||||||
'`${{package_name}}.${{source_name}}.${{target_name}}'
|
"`${{package_name}}.${{source_name}}.${{target_name}}"
|
||||||
).format(selector)
|
).format(selector)
|
||||||
raise RuntimeException(msg)
|
raise RuntimeException(msg)
|
||||||
|
|
||||||
@@ -248,7 +247,7 @@ class ExposureSelectorMethod(SelectorMethod):
|
|||||||
def search(
|
def search(
|
||||||
self, included_nodes: Set[UniqueId], selector: str
|
self, included_nodes: Set[UniqueId], selector: str
|
||||||
) -> Iterator[UniqueId]:
|
) -> Iterator[UniqueId]:
|
||||||
parts = selector.split('.')
|
parts = selector.split(".")
|
||||||
target_package = SELECTOR_GLOB
|
target_package = SELECTOR_GLOB
|
||||||
if len(parts) == 1:
|
if len(parts) == 1:
|
||||||
target_name = parts[0]
|
target_name = parts[0]
|
||||||
@@ -257,8 +256,8 @@ class ExposureSelectorMethod(SelectorMethod):
|
|||||||
else:
|
else:
|
||||||
msg = (
|
msg = (
|
||||||
'Invalid exposure selector value "{}". Exposures must be of '
|
'Invalid exposure selector value "{}". Exposures must be of '
|
||||||
'the form ${{exposure_name}} or '
|
"the form ${{exposure_name}} or "
|
||||||
'${{exposure_package.exposure_name}}'
|
"${{exposure_package.exposure_name}}"
|
||||||
).format(selector)
|
).format(selector)
|
||||||
raise RuntimeException(msg)
|
raise RuntimeException(msg)
|
||||||
|
|
||||||
@@ -275,9 +274,7 @@ class PathSelectorMethod(SelectorMethod):
|
|||||||
def search(
|
def search(
|
||||||
self, included_nodes: Set[UniqueId], selector: str
|
self, included_nodes: Set[UniqueId], selector: str
|
||||||
) -> Iterator[UniqueId]:
|
) -> Iterator[UniqueId]:
|
||||||
"""Yields nodes from inclucded that match the given path.
|
"""Yields nodes from inclucded that match the given path."""
|
||||||
|
|
||||||
"""
|
|
||||||
# use '.' and not 'root' for easy comparison
|
# use '.' and not 'root' for easy comparison
|
||||||
root = Path.cwd()
|
root = Path.cwd()
|
||||||
paths = set(p.relative_to(root) for p in root.glob(selector))
|
paths = set(p.relative_to(root) for p in root.glob(selector))
|
||||||
@@ -336,7 +333,7 @@ class ConfigSelectorMethod(SelectorMethod):
         parts = self.arguments
         # special case: if the user wanted to compare test severity,
         # make the comparison case-insensitive
-        if parts == ['severity']:
+        if parts == ["severity"]:
             selector = CaseInsensitive(selector)
 
         # search sources is kind of useless now source configs only have
@@ -382,14 +379,13 @@ class TestTypeSelectorMethod(SelectorMethod):
         self, included_nodes: Set[UniqueId], selector: str
     ) -> Iterator[UniqueId]:
         search_types: Tuple[Type, ...]
-        if selector == 'schema':
+        if selector == "schema":
             search_types = (ParsedSchemaTestNode, CompiledSchemaTestNode)
-        elif selector == 'data':
+        elif selector == "data":
             search_types = (ParsedDataTestNode, CompiledDataTestNode)
         else:
             raise RuntimeException(
-                f'Invalid test type selector {selector}: expected "data" or '
-                '"schema"'
+                f'Invalid test type selector {selector}: expected "data" or ' '"schema"'
             )
 
         for node, real_node in self.parsed_nodes(included_nodes):
@@ -405,25 +401,23 @@ class StateSelectorMethod(SelectorMethod):
     def _macros_modified(self) -> List[str]:
         # we checked in the caller!
         if self.previous_state is None or self.previous_state.manifest is None:
-            raise InternalException(
-                'No comparison manifest in _macros_modified'
-            )
+            raise InternalException("No comparison manifest in _macros_modified")
         old_macros = self.previous_state.manifest.macros
         new_macros = self.manifest.macros
 
         modified = []
         for uid, macro in new_macros.items():
-            name = f'{macro.package_name}.{macro.name}'
+            name = f"{macro.package_name}.{macro.name}"
             if uid in old_macros:
                 old_macro = old_macros[uid]
                 if macro.macro_sql != old_macro.macro_sql:
-                    modified.append(f'{name} changed')
+                    modified.append(f"{name} changed")
             else:
-                modified.append(f'{name} added')
+                modified.append(f"{name} added")
 
         for uid, macro in old_macros.items():
             if uid not in new_macros:
-                modified.append(f'{macro.package_name}.{macro.name} removed')
+                modified.append(f"{macro.package_name}.{macro.name} removed")
 
         return modified[:3]
 
@@ -437,12 +431,14 @@ class StateSelectorMethod(SelectorMethod):
         if self.macros_were_modified is None:
             self.macros_were_modified = self._macros_modified()
             if self.macros_were_modified:
-                log_str = ', '.join(self.macros_were_modified)
-                logger.warning(warning_tag(
-                    f'During a state comparison, dbt detected a change in '
-                    f'macros. This will not be marked as a modification. Some '
-                    f'macros: {log_str}'
-                ))
+                log_str = ", ".join(self.macros_were_modified)
+                logger.warning(
+                    warning_tag(
+                        f"During a state comparison, dbt detected a change in "
+                        f"macros. This will not be marked as a modification. Some "
+                        f"macros: {log_str}"
+                    )
+                )
 
         return not new.same_contents(old)  # type: ignore
 
@@ -458,12 +454,12 @@ class StateSelectorMethod(SelectorMethod):
     ) -> Iterator[UniqueId]:
         if self.previous_state is None or self.previous_state.manifest is None:
             raise RuntimeException(
-                'Got a state selector method, but no comparison manifest'
+                "Got a state selector method, but no comparison manifest"
             )
 
         state_checks = {
-            'modified': self.check_modified,
-            'new': self.check_new,
+            "modified": self.check_modified,
+            "new": self.check_new,
         }
         if selector in state_checks:
             checker = state_checks[selector]
@@ -517,7 +513,7 @@ class MethodManager:
         if method not in self.SELECTOR_METHODS:
             raise InternalException(
                 f'Method name "{method}" is a valid node selection '
-                f'method name, but it is not handled'
+                f"method name, but it is not handled"
             )
         cls: Type[SelectorMethod] = self.SELECTOR_METHODS[method]
         return cls(self.manifest, self.previous_state, method_arguments)
@@ -3,23 +3,21 @@ import re
 from abc import ABCMeta, abstractmethod
 from dataclasses import dataclass
 
-from typing import (
-    Set, Iterator, List, Optional, Dict, Union, Any, Iterable, Tuple
-)
+from typing import Set, Iterator, List, Optional, Dict, Union, Any, Iterable, Tuple
 from .graph import UniqueId
 from .selector_methods import MethodName
 from dbt.exceptions import RuntimeException, InvalidSelectorException
 
 
 RAW_SELECTOR_PATTERN = re.compile(
-    r'\A'
-    r'(?P<childrens_parents>(\@))?'
-    r'(?P<parents>((?P<parents_depth>(\d*))\+))?'
-    r'((?P<method>([\w.]+)):)?(?P<value>(.*?))'
-    r'(?P<children>(\+(?P<children_depth>(\d*))))?'
-    r'\Z'
+    r"\A"
+    r"(?P<childrens_parents>(\@))?"
+    r"(?P<parents>((?P<parents_depth>(\d*))\+))?"
+    r"((?P<method>([\w.]+)):)?(?P<value>(.*?))"
+    r"(?P<children>(\+(?P<children_depth>(\d*))))?"
+    r"\Z"
 )
-SELECTOR_METHOD_SEPARATOR = '.'
+SELECTOR_METHOD_SEPARATOR = "."
 
 
 def _probably_path(value: str):
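To make the selector grammar above concrete, here is a small illustrative sketch (not part of the diff) showing how RAW_SELECTOR_PATTERN splits a node selector into its named groups; the example spec `2+tag:nightly+3` is hypothetical:

```
import re

# Same pattern as in the hunk above, quoted verbatim.
RAW_SELECTOR_PATTERN = re.compile(
    r"\A"
    r"(?P<childrens_parents>(\@))?"
    r"(?P<parents>((?P<parents_depth>(\d*))\+))?"
    r"((?P<method>([\w.]+)):)?(?P<value>(.*?))"
    r"(?P<children>(\+(?P<children_depth>(\d*))))?"
    r"\Z"
)

match = RAW_SELECTOR_PATTERN.match("2+tag:nightly+3")
assert match is not None
groups = match.groupdict()
print(groups["parents_depth"], groups["method"], groups["value"], groups["children_depth"])
# -> 2 tag nightly 3
```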
@@ -43,15 +41,15 @@ def _match_to_int(match: Dict[str, str], key: str) -> Optional[int]:
|
|||||||
return int(raw)
|
return int(raw)
|
||||||
except ValueError as exc:
|
except ValueError as exc:
|
||||||
raise RuntimeException(
|
raise RuntimeException(
|
||||||
f'Invalid node spec - could not handle parent depth {raw}'
|
f"Invalid node spec - could not handle parent depth {raw}"
|
||||||
) from exc
|
) from exc
|
||||||
|
|
||||||
|
|
||||||
SelectionSpec = Union[
|
SelectionSpec = Union[
|
||||||
'SelectionCriteria',
|
"SelectionCriteria",
|
||||||
'SelectionIntersection',
|
"SelectionIntersection",
|
||||||
'SelectionDifference',
|
"SelectionDifference",
|
||||||
'SelectionUnion',
|
"SelectionUnion",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@@ -71,7 +69,7 @@ class SelectionCriteria:
|
|||||||
if self.children and self.childrens_parents:
|
if self.children and self.childrens_parents:
|
||||||
raise RuntimeException(
|
raise RuntimeException(
|
||||||
f'Invalid node spec {self.raw} - "@" prefix and "+" suffix '
|
f'Invalid node spec {self.raw} - "@" prefix and "+" suffix '
|
||||||
'are incompatible'
|
"are incompatible"
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -82,12 +80,10 @@ class SelectionCriteria:
|
|||||||
return MethodName.FQN
|
return MethodName.FQN
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def parse_method(
|
def parse_method(cls, groupdict: Dict[str, Any]) -> Tuple[MethodName, List[str]]:
|
||||||
cls, groupdict: Dict[str, Any]
|
raw_method = groupdict.get("method")
|
||||||
) -> Tuple[MethodName, List[str]]:
|
|
||||||
raw_method = groupdict.get('method')
|
|
||||||
if raw_method is None:
|
if raw_method is None:
|
||||||
return cls.default_method(groupdict['value']), []
|
return cls.default_method(groupdict["value"]), []
|
||||||
|
|
||||||
method_parts: List[str] = raw_method.split(SELECTOR_METHOD_SEPARATOR)
|
method_parts: List[str] = raw_method.split(SELECTOR_METHOD_SEPARATOR)
|
||||||
try:
|
try:
|
||||||
@@ -104,24 +100,22 @@ class SelectionCriteria:
|
|||||||
@classmethod
|
@classmethod
|
||||||
def selection_criteria_from_dict(
|
def selection_criteria_from_dict(
|
||||||
cls, raw: Any, dct: Dict[str, Any]
|
cls, raw: Any, dct: Dict[str, Any]
|
||||||
) -> 'SelectionCriteria':
|
) -> "SelectionCriteria":
|
||||||
if 'value' not in dct:
|
if "value" not in dct:
|
||||||
raise RuntimeException(
|
raise RuntimeException(f'Invalid node spec "{raw}" - no search value!')
|
||||||
f'Invalid node spec "{raw}" - no search value!'
|
|
||||||
)
|
|
||||||
method_name, method_arguments = cls.parse_method(dct)
|
method_name, method_arguments = cls.parse_method(dct)
|
||||||
|
|
||||||
parents_depth = _match_to_int(dct, 'parents_depth')
|
parents_depth = _match_to_int(dct, "parents_depth")
|
||||||
children_depth = _match_to_int(dct, 'children_depth')
|
children_depth = _match_to_int(dct, "children_depth")
|
||||||
return cls(
|
return cls(
|
||||||
raw=raw,
|
raw=raw,
|
||||||
method=method_name,
|
method=method_name,
|
||||||
method_arguments=method_arguments,
|
method_arguments=method_arguments,
|
||||||
value=dct['value'],
|
value=dct["value"],
|
||||||
childrens_parents=bool(dct.get('childrens_parents')),
|
childrens_parents=bool(dct.get("childrens_parents")),
|
||||||
parents=bool(dct.get('parents')),
|
parents=bool(dct.get("parents")),
|
||||||
parents_depth=parents_depth,
|
parents_depth=parents_depth,
|
||||||
children=bool(dct.get('children')),
|
children=bool(dct.get("children")),
|
||||||
children_depth=children_depth,
|
children_depth=children_depth,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -129,24 +123,24 @@ class SelectionCriteria:
|
|||||||
def dict_from_single_spec(cls, raw: str):
|
def dict_from_single_spec(cls, raw: str):
|
||||||
result = RAW_SELECTOR_PATTERN.match(raw)
|
result = RAW_SELECTOR_PATTERN.match(raw)
|
||||||
if result is None:
|
if result is None:
|
||||||
return {'error': 'Invalid selector spec'}
|
return {"error": "Invalid selector spec"}
|
||||||
dct: Dict[str, Any] = result.groupdict()
|
dct: Dict[str, Any] = result.groupdict()
|
||||||
method_name, method_arguments = cls.parse_method(dct)
|
method_name, method_arguments = cls.parse_method(dct)
|
||||||
meth_name = str(method_name)
|
meth_name = str(method_name)
|
||||||
if method_arguments:
|
if method_arguments:
|
||||||
meth_name = meth_name + '.' + '.'.join(method_arguments)
|
meth_name = meth_name + "." + ".".join(method_arguments)
|
||||||
dct['method'] = meth_name
|
dct["method"] = meth_name
|
||||||
dct = {k: v for k, v in dct.items() if (v is not None and v != '')}
|
dct = {k: v for k, v in dct.items() if (v is not None and v != "")}
|
||||||
if 'childrens_parents' in dct:
|
if "childrens_parents" in dct:
|
||||||
dct['childrens_parents'] = bool(dct.get('childrens_parents'))
|
dct["childrens_parents"] = bool(dct.get("childrens_parents"))
|
||||||
if 'parents' in dct:
|
if "parents" in dct:
|
||||||
dct['parents'] = bool(dct.get('parents'))
|
dct["parents"] = bool(dct.get("parents"))
|
||||||
if 'children' in dct:
|
if "children" in dct:
|
||||||
dct['children'] = bool(dct.get('children'))
|
dct["children"] = bool(dct.get("children"))
|
||||||
return dct
|
return dct
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_single_spec(cls, raw: str) -> 'SelectionCriteria':
|
def from_single_spec(cls, raw: str) -> "SelectionCriteria":
|
||||||
result = RAW_SELECTOR_PATTERN.match(raw)
|
result = RAW_SELECTOR_PATTERN.match(raw)
|
||||||
if result is None:
|
if result is None:
|
||||||
# bad spec!
|
# bad spec!
|
||||||
@@ -175,9 +169,7 @@ class BaseSelectionGroup(Iterable[SelectionSpec], metaclass=ABCMeta):
|
|||||||
self,
|
self,
|
||||||
selections: List[Set[UniqueId]],
|
selections: List[Set[UniqueId]],
|
||||||
) -> Set[UniqueId]:
|
) -> Set[UniqueId]:
|
||||||
raise NotImplementedError(
|
raise NotImplementedError("_combine_selections not implemented!")
|
||||||
'_combine_selections not implemented!'
|
|
||||||
)
|
|
||||||
|
|
||||||
def combined(self, selections: List[Set[UniqueId]]) -> Set[UniqueId]:
|
def combined(self, selections: List[Set[UniqueId]]) -> Set[UniqueId]:
|
||||||
if not selections:
|
if not selections:
|
||||||
|
|||||||
@@ -5,7 +5,9 @@ from pathlib import Path
|
|||||||
from typing import Tuple, AbstractSet, Union
|
from typing import Tuple, AbstractSet, Union
|
||||||
|
|
||||||
from dbt.dataclass_schema import (
|
from dbt.dataclass_schema import (
|
||||||
dbtClassMixin, ValidationError, StrEnum,
|
dbtClassMixin,
|
||||||
|
ValidationError,
|
||||||
|
StrEnum,
|
||||||
)
|
)
|
||||||
from hologram import FieldEncoder, JsonDict
|
from hologram import FieldEncoder, JsonDict
|
||||||
from mashumaro.types import SerializableType
|
from mashumaro.types import SerializableType
|
||||||
@@ -13,11 +15,11 @@ from mashumaro.types import SerializableType
 
 class Port(int, SerializableType):
     @classmethod
-    def _deserialize(cls, value: Union[int, str]) -> 'Port':
+    def _deserialize(cls, value: Union[int, str]) -> "Port":
         try:
             value = int(value)
         except ValueError:
-            raise ValidationError(f'Cannot encode {value} into port number')
+            raise ValidationError(f"Cannot encode {value} into port number")
 
         return Port(value)
 
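As a quick illustration of the Port type above (a sketch, not part of the diff): it subclasses int so it can round-trip through serialization, coercing string values and rejecting anything non-numeric. ValidationError here is the one imported from dbt.dataclass_schema in the earlier hunk.

```
# Illustrative only; mirrors the Port._deserialize logic shown above.
port = Port._deserialize("5432")
assert isinstance(port, int) and port == 5432

try:
    Port._deserialize("not-a-port")
except ValidationError as exc:
    print(exc)  # roughly: Cannot encode not-a-port into port number
```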
@@ -28,7 +30,7 @@ class Port(int, SerializableType):
|
|||||||
class PortEncoder(FieldEncoder):
|
class PortEncoder(FieldEncoder):
|
||||||
@property
|
@property
|
||||||
def json_schema(self):
|
def json_schema(self):
|
||||||
return {'type': 'integer', 'minimum': 0, 'maximum': 65535}
|
return {"type": "integer", "minimum": 0, "maximum": 65535}
|
||||||
|
|
||||||
|
|
||||||
class TimeDeltaFieldEncoder(FieldEncoder[timedelta]):
|
class TimeDeltaFieldEncoder(FieldEncoder[timedelta]):
|
||||||
@@ -44,12 +46,12 @@ class TimeDeltaFieldEncoder(FieldEncoder[timedelta]):
|
|||||||
return timedelta(seconds=value)
|
return timedelta(seconds=value)
|
||||||
except TypeError:
|
except TypeError:
|
||||||
raise ValidationError(
|
raise ValidationError(
|
||||||
'cannot encode {} into timedelta'.format(value)
|
"cannot encode {} into timedelta".format(value)
|
||||||
) from None
|
) from None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def json_schema(self) -> JsonDict:
|
def json_schema(self) -> JsonDict:
|
||||||
return {'type': 'number'}
|
return {"type": "number"}
|
||||||
|
|
||||||
|
|
||||||
class PathEncoder(FieldEncoder):
|
class PathEncoder(FieldEncoder):
|
||||||
@@ -63,16 +65,16 @@ class PathEncoder(FieldEncoder):
|
|||||||
return Path(value)
|
return Path(value)
|
||||||
except TypeError:
|
except TypeError:
|
||||||
raise ValidationError(
|
raise ValidationError(
|
||||||
'cannot encode {} into timedelta'.format(value)
|
"cannot encode {} into timedelta".format(value)
|
||||||
) from None
|
) from None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def json_schema(self) -> JsonDict:
|
def json_schema(self) -> JsonDict:
|
||||||
return {'type': 'string'}
|
return {"type": "string"}
|
||||||
|
|
||||||
|
|
||||||
class NVEnum(StrEnum):
|
class NVEnum(StrEnum):
|
||||||
novalue = 'novalue'
|
novalue = "novalue"
|
||||||
|
|
||||||
def __eq__(self, other):
|
def __eq__(self, other):
|
||||||
return isinstance(other, NVEnum)
|
return isinstance(other, NVEnum)
|
||||||
@@ -81,14 +83,17 @@ class NVEnum(StrEnum):
|
|||||||
@dataclass
|
@dataclass
|
||||||
class NoValue(dbtClassMixin):
|
class NoValue(dbtClassMixin):
|
||||||
"""Sometimes, you want a way to say none that isn't None"""
|
"""Sometimes, you want a way to say none that isn't None"""
|
||||||
|
|
||||||
novalue: NVEnum = NVEnum.novalue
|
novalue: NVEnum = NVEnum.novalue
|
||||||
|
|
||||||
|
|
||||||
dbtClassMixin.register_field_encoders({
|
dbtClassMixin.register_field_encoders(
|
||||||
|
{
|
||||||
Port: PortEncoder(),
|
Port: PortEncoder(),
|
||||||
timedelta: TimeDeltaFieldEncoder(),
|
timedelta: TimeDeltaFieldEncoder(),
|
||||||
Path: PathEncoder(),
|
Path: PathEncoder(),
|
||||||
})
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
FQNPath = Tuple[str, ...]
|
FQNPath = Tuple[str, ...]
|
||||||
|
|||||||
@@ -5,8 +5,8 @@ from typing import Union, Dict, Any
|
|||||||
|
|
||||||
|
|
||||||
class ModelHookType(StrEnum):
|
class ModelHookType(StrEnum):
|
||||||
PreHook = 'pre-hook'
|
PreHook = "pre-hook"
|
||||||
PostHook = 'post-hook'
|
PostHook = "post-hook"
|
||||||
|
|
||||||
|
|
||||||
def get_hook_dict(source: Union[str, Dict[str, Any]]) -> Dict[str, Any]:
|
def get_hook_dict(source: Union[str, Dict[str, Any]]) -> Dict[str, Any]:
|
||||||
@@ -18,4 +18,4 @@ def get_hook_dict(source: Union[str, Dict[str, Any]]) -> Dict[str, Any]:
     try:
         return json.loads(source)
     except ValueError:
-        return {'sql': source}
+        return {"sql": source}
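A short sketch (not part of the diff) of how the hook parsing above behaves: JSON strings are decoded, while plain SQL strings are wrapped in a dict under the "sql" key. The dict pass-through branch is an assumption based on the function's signature; only the string-handling branch is visible in the hunk.

```
import json
from typing import Any, Dict, Union

# Simplified re-statement of get_hook_dict for illustration.
def get_hook_dict(source: Union[str, Dict[str, Any]]) -> Dict[str, Any]:
    if isinstance(source, dict):  # assumed pass-through for already-parsed hooks
        return source
    try:
        return json.loads(source)
    except ValueError:
        return {"sql": source}

print(get_hook_dict('{"sql": "grant select on {{ this }} to role reporter"}'))
# {'sql': 'grant select on {{ this }} to role reporter'}
print(get_hook_dict("vacuum {{ this }}"))
# {'sql': 'vacuum {{ this }}'}
```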
@@ -1,7 +1,6 @@
|
|||||||
import os
|
import os
|
||||||
|
|
||||||
PACKAGE_PATH = os.path.dirname(__file__)
|
PACKAGE_PATH = os.path.dirname(__file__)
|
||||||
PROJECT_NAME = 'dbt'
|
PROJECT_NAME = "dbt"
|
||||||
|
|
||||||
DOCS_INDEX_FILE_PATH = os.path.normpath(
|
DOCS_INDEX_FILE_PATH = os.path.normpath(os.path.join(PACKAGE_PATH, "..", "index.html"))
|
||||||
os.path.join(PACKAGE_PATH, '..', "index.html"))
|
|
||||||
|
|||||||
@@ -287,4 +287,3 @@
|
|||||||
{% macro set_sql_header(config) -%}
|
{% macro set_sql_header(config) -%}
|
||||||
{{ config.set('sql_header', caller()) }}
|
{{ config.set('sql_header', caller()) }}
|
||||||
{%- endmacro %}
|
{%- endmacro %}
|
||||||
|
|
||||||
|
|||||||
@@ -23,5 +23,3 @@
|
|||||||
values ({{ insert_cols_csv }})
|
values ({{ insert_cols_csv }})
|
||||||
;
|
;
|
||||||
{% endmacro %}
|
{% endmacro %}
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
ProfileConfigDocs = 'https://docs.getdbt.com/docs/configure-your-profile'
|
ProfileConfigDocs = "https://docs.getdbt.com/docs/configure-your-profile"
|
||||||
SnowflakeQuotingDocs = 'https://docs.getdbt.com/v0.10/docs/configuring-quoting'
|
SnowflakeQuotingDocs = "https://docs.getdbt.com/v0.10/docs/configuring-quoting"
|
||||||
IncrementalDocs = 'https://docs.getdbt.com/docs/configuring-incremental-models'
|
IncrementalDocs = "https://docs.getdbt.com/docs/configuring-incremental-models"
|
||||||
BigQueryNewPartitionBy = 'https://docs.getdbt.com/docs/upgrading-to-0-16-0'
|
BigQueryNewPartitionBy = "https://docs.getdbt.com/docs/upgrading-to-0-16-0"
|
||||||
|
|||||||
@@ -26,21 +26,21 @@ colorama_wrap = True
|
|||||||
colorama.init(wrap=colorama_wrap)
|
colorama.init(wrap=colorama_wrap)
|
||||||
|
|
||||||
|
|
||||||
if sys.platform == 'win32' and not os.getenv('TERM'):
|
if sys.platform == "win32" and not os.getenv("TERM"):
|
||||||
colorama_wrap = False
|
colorama_wrap = False
|
||||||
colorama_stdout = colorama.AnsiToWin32(sys.stdout).stream
|
colorama_stdout = colorama.AnsiToWin32(sys.stdout).stream
|
||||||
|
|
||||||
elif sys.platform == 'win32':
|
elif sys.platform == "win32":
|
||||||
colorama_wrap = False
|
colorama_wrap = False
|
||||||
|
|
||||||
colorama.init(wrap=colorama_wrap)
|
colorama.init(wrap=colorama_wrap)
|
||||||
|
|
||||||
|
|
||||||
STDOUT_LOG_FORMAT = '{record.message}'
|
STDOUT_LOG_FORMAT = "{record.message}"
|
||||||
DEBUG_LOG_FORMAT = (
|
DEBUG_LOG_FORMAT = (
|
||||||
'{record.time:%Y-%m-%d %H:%M:%S.%f%z} '
|
"{record.time:%Y-%m-%d %H:%M:%S.%f%z} "
|
||||||
'({record.thread_name}): '
|
"({record.thread_name}): "
|
||||||
'{record.message}'
|
"{record.message}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -94,6 +94,7 @@ class JsonFormatter(LogMessageFormatter):
|
|||||||
"""Return a the record converted to LogMessage's JSON form"""
|
"""Return a the record converted to LogMessage's JSON form"""
|
||||||
# utils imports exceptions which imports logger...
|
# utils imports exceptions which imports logger...
|
||||||
import dbt.utils
|
import dbt.utils
|
||||||
|
|
||||||
log_message = super().__call__(record, handler)
|
log_message = super().__call__(record, handler)
|
||||||
dct = log_message.to_dict(omit_none=True)
|
dct = log_message.to_dict(omit_none=True)
|
||||||
return json.dumps(dct, cls=dbt.utils.JSONEncoder)
|
return json.dumps(dct, cls=dbt.utils.JSONEncoder)
|
||||||
@@ -117,9 +118,7 @@ class FormatterMixin:
|
|||||||
self.format_string = self._text_format_string
|
self.format_string = self._text_format_string
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
raise NotImplementedError(
|
raise NotImplementedError("reset() not implemented in FormatterMixin subclass")
|
||||||
'reset() not implemented in FormatterMixin subclass'
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class OutputHandler(logbook.StreamHandler, FormatterMixin):
|
class OutputHandler(logbook.StreamHandler, FormatterMixin):
|
||||||
@@ -164,9 +163,9 @@ class OutputHandler(logbook.StreamHandler, FormatterMixin):
|
|||||||
if record.level < self.level:
|
if record.level < self.level:
|
||||||
return False
|
return False
|
||||||
text_mode = self.formatter_class is logbook.StringFormatter
|
text_mode = self.formatter_class is logbook.StringFormatter
|
||||||
if text_mode and record.extra.get('json_only', False):
|
if text_mode and record.extra.get("json_only", False):
|
||||||
return False
|
return False
|
||||||
elif not text_mode and record.extra.get('text_only', False):
|
elif not text_mode and record.extra.get("text_only", False):
|
||||||
return False
|
return False
|
||||||
else:
|
else:
|
||||||
return True
|
return True
|
||||||
@@ -177,7 +176,7 @@ def _redirect_std_logging():
|
|||||||
|
|
||||||
|
|
||||||
def _root_channel(record: logbook.LogRecord) -> str:
|
def _root_channel(record: logbook.LogRecord) -> str:
|
||||||
return record.channel.split('.')[0]
|
return record.channel.split(".")[0]
|
||||||
|
|
||||||
|
|
||||||
class Relevel(logbook.Processor):
|
class Relevel(logbook.Processor):
|
||||||
@@ -195,7 +194,7 @@ class Relevel(logbook.Processor):
|
|||||||
def process(self, record):
|
def process(self, record):
|
||||||
if _root_channel(record) in self.allowed:
|
if _root_channel(record) in self.allowed:
|
||||||
return
|
return
|
||||||
record.extra['old_level'] = record.level
|
record.extra["old_level"] = record.level
|
||||||
# suppress logs at/below our min level by lowering them to NOTSET
|
# suppress logs at/below our min level by lowering them to NOTSET
|
||||||
if record.level < self.min_level:
|
if record.level < self.min_level:
|
||||||
record.level = logbook.NOTSET
|
record.level = logbook.NOTSET
|
||||||
@@ -207,12 +206,12 @@ class Relevel(logbook.Processor):
|
|||||||
|
|
||||||
class JsonOnly(logbook.Processor):
|
class JsonOnly(logbook.Processor):
|
||||||
def process(self, record):
|
def process(self, record):
|
||||||
record.extra['json_only'] = True
|
record.extra["json_only"] = True
|
||||||
|
|
||||||
|
|
||||||
class TextOnly(logbook.Processor):
|
class TextOnly(logbook.Processor):
|
||||||
def process(self, record):
|
def process(self, record):
|
||||||
record.extra['text_only'] = True
|
record.extra["text_only"] = True
|
||||||
|
|
||||||
|
|
||||||
class TimingProcessor(logbook.Processor):
|
class TimingProcessor(logbook.Processor):
|
||||||
@@ -222,8 +221,7 @@ class TimingProcessor(logbook.Processor):
|
|||||||
|
|
||||||
def process(self, record):
|
def process(self, record):
|
||||||
if self.timing_info is not None:
|
if self.timing_info is not None:
|
||||||
record.extra['timing_info'] = self.timing_info.to_dict(
|
record.extra["timing_info"] = self.timing_info.to_dict(omit_none=True)
|
||||||
omit_none=True)
|
|
||||||
|
|
||||||
|
|
||||||
class DbtProcessState(logbook.Processor):
|
class DbtProcessState(logbook.Processor):
|
||||||
@@ -233,11 +231,10 @@ class DbtProcessState(logbook.Processor):
|
|||||||
|
|
||||||
def process(self, record):
|
def process(self, record):
|
||||||
overwrite = (
|
overwrite = (
|
||||||
'run_state' not in record.extra or
|
"run_state" not in record.extra or record.extra["run_state"] == "internal"
|
||||||
record.extra['run_state'] == 'internal'
|
|
||||||
)
|
)
|
||||||
if overwrite:
|
if overwrite:
|
||||||
record.extra['run_state'] = self.value
|
record.extra["run_state"] = self.value
|
||||||
|
|
||||||
|
|
||||||
class DbtModelState(logbook.Processor):
|
class DbtModelState(logbook.Processor):
|
||||||
@@ -251,7 +248,7 @@ class DbtModelState(logbook.Processor):
|
|||||||
|
|
||||||
class DbtStatusMessage(logbook.Processor):
|
class DbtStatusMessage(logbook.Processor):
|
||||||
def process(self, record):
|
def process(self, record):
|
||||||
record.extra['is_status_message'] = True
|
record.extra["is_status_message"] = True
|
||||||
|
|
||||||
|
|
||||||
class UniqueID(logbook.Processor):
|
class UniqueID(logbook.Processor):
|
||||||
@@ -260,7 +257,7 @@ class UniqueID(logbook.Processor):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
def process(self, record):
|
def process(self, record):
|
||||||
record.extra['unique_id'] = self.unique_id
|
record.extra["unique_id"] = self.unique_id
|
||||||
|
|
||||||
|
|
||||||
class NodeCount(logbook.Processor):
|
class NodeCount(logbook.Processor):
|
||||||
@@ -269,7 +266,7 @@ class NodeCount(logbook.Processor):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
def process(self, record):
|
def process(self, record):
|
||||||
record.extra['node_count'] = self.node_count
|
record.extra["node_count"] = self.node_count
|
||||||
|
|
||||||
|
|
||||||
class NodeMetadata(logbook.Processor):
|
class NodeMetadata(logbook.Processor):
|
||||||
@@ -289,26 +286,26 @@ class NodeMetadata(logbook.Processor):
|
|||||||
|
|
||||||
def process(self, record):
|
def process(self, record):
|
||||||
self.process_keys(record)
|
self.process_keys(record)
|
||||||
record.extra['node_index'] = self.index
|
record.extra["node_index"] = self.index
|
||||||
|
|
||||||
|
|
||||||
class ModelMetadata(NodeMetadata):
|
class ModelMetadata(NodeMetadata):
|
||||||
def mapping_keys(self):
|
def mapping_keys(self):
|
||||||
return [
|
return [
|
||||||
('alias', 'node_alias'),
|
("alias", "node_alias"),
|
||||||
('schema', 'node_schema'),
|
("schema", "node_schema"),
|
||||||
('database', 'node_database'),
|
("database", "node_database"),
|
||||||
('original_file_path', 'node_path'),
|
("original_file_path", "node_path"),
|
||||||
('name', 'node_name'),
|
("name", "node_name"),
|
||||||
('resource_type', 'resource_type'),
|
("resource_type", "resource_type"),
|
||||||
('depends_on_nodes', 'depends_on'),
|
("depends_on_nodes", "depends_on"),
|
||||||
]
|
]
|
||||||
|
|
||||||
def process_config(self, record):
|
def process_config(self, record):
|
||||||
if hasattr(self.node, 'config'):
|
if hasattr(self.node, "config"):
|
||||||
materialized = getattr(self.node.config, 'materialized', None)
|
materialized = getattr(self.node.config, "materialized", None)
|
||||||
if materialized is not None:
|
if materialized is not None:
|
||||||
record.extra['node_materialized'] = materialized
|
record.extra["node_materialized"] = materialized
|
||||||
|
|
||||||
def process(self, record):
|
def process(self, record):
|
||||||
super().process(record)
|
super().process(record)
|
||||||
@@ -318,8 +315,8 @@ class ModelMetadata(NodeMetadata):
|
|||||||
class HookMetadata(NodeMetadata):
|
class HookMetadata(NodeMetadata):
|
||||||
def mapping_keys(self):
|
def mapping_keys(self):
|
||||||
return [
|
return [
|
||||||
('name', 'node_name'),
|
("name", "node_name"),
|
||||||
('resource_type', 'resource_type'),
|
("resource_type", "resource_type"),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@@ -333,30 +330,31 @@ class TimestampNamed(logbook.Processor):
|
|||||||
record.extra[self.name] = datetime.utcnow().isoformat()
|
record.extra[self.name] = datetime.utcnow().isoformat()
|
||||||
|
|
||||||
|
|
||||||
logger = logbook.Logger('dbt')
|
logger = logbook.Logger("dbt")
|
||||||
# provide this for the cache, disabled by default
|
# provide this for the cache, disabled by default
|
||||||
CACHE_LOGGER = logbook.Logger('dbt.cache')
|
CACHE_LOGGER = logbook.Logger("dbt.cache")
|
||||||
CACHE_LOGGER.disable()
|
CACHE_LOGGER.disable()
|
||||||
|
|
||||||
warnings.filterwarnings("ignore", category=ResourceWarning,
|
warnings.filterwarnings(
|
||||||
message="unclosed.*<socket.socket.*>")
|
"ignore", category=ResourceWarning, message="unclosed.*<socket.socket.*>"
|
||||||
|
)
|
||||||
|
|
||||||
initialized = False
|
initialized = False
|
||||||
|
|
||||||
|
|
||||||
def make_log_dir_if_missing(log_dir):
|
def make_log_dir_if_missing(log_dir):
|
||||||
import dbt.clients.system
|
import dbt.clients.system
|
||||||
|
|
||||||
dbt.clients.system.make_directory(log_dir)
|
dbt.clients.system.make_directory(log_dir)
|
||||||
|
|
||||||
|
|
||||||
class DebugWarnings(logbook.compat.redirected_warnings):
|
class DebugWarnings(logbook.compat.redirected_warnings):
|
||||||
"""Log warnings, except send them to 'debug' instead of 'warning' level.
|
"""Log warnings, except send them to 'debug' instead of 'warning' level."""
|
||||||
"""
|
|
||||||
|
|
||||||
def make_record(self, message, exception, filename, lineno):
|
def make_record(self, message, exception, filename, lineno):
|
||||||
rv = super().make_record(message, exception, filename, lineno)
|
rv = super().make_record(message, exception, filename, lineno)
|
||||||
rv.level = logbook.DEBUG
|
rv.level = logbook.DEBUG
|
||||||
rv.extra['from_warnings'] = True
|
rv.extra["from_warnings"] = True
|
||||||
return rv
|
return rv
|
||||||
|
|
||||||
|
|
||||||
@@ -408,14 +406,14 @@ class DelayedFileHandler(logbook.RotatingFileHandler, FormatterMixin):
|
|||||||
if self.disabled:
|
if self.disabled:
|
||||||
return
|
return
|
||||||
|
|
||||||
assert not self.initialized, 'set_path called after being set'
|
assert not self.initialized, "set_path called after being set"
|
||||||
|
|
||||||
if log_dir is None:
|
if log_dir is None:
|
||||||
self.disabled = True
|
self.disabled = True
|
||||||
return
|
return
|
||||||
|
|
||||||
make_log_dir_if_missing(log_dir)
|
make_log_dir_if_missing(log_dir)
|
||||||
log_path = os.path.join(log_dir, 'dbt.log')
|
log_path = os.path.join(log_dir, "dbt.log")
|
||||||
self._super_init(log_path)
|
self._super_init(log_path)
|
||||||
self._replay_buffered()
|
self._replay_buffered()
|
||||||
self._log_path = log_path
|
self._log_path = log_path
|
||||||
@@ -435,8 +433,9 @@ class DelayedFileHandler(logbook.RotatingFileHandler, FormatterMixin):
|
|||||||
FormatterMixin.__init__(self, DEBUG_LOG_FORMAT)
|
FormatterMixin.__init__(self, DEBUG_LOG_FORMAT)
|
||||||
|
|
||||||
def _replay_buffered(self):
|
def _replay_buffered(self):
|
||||||
assert self._msg_buffer is not None, \
|
assert (
|
||||||
'_msg_buffer should never be None in _replay_buffered'
|
self._msg_buffer is not None
|
||||||
|
), "_msg_buffer should never be None in _replay_buffered"
|
||||||
for record in self._msg_buffer:
|
for record in self._msg_buffer:
|
||||||
super().emit(record)
|
super().emit(record)
|
||||||
self._msg_buffer = None
|
self._msg_buffer = None
|
||||||
@@ -445,7 +444,7 @@ class DelayedFileHandler(logbook.RotatingFileHandler, FormatterMixin):
|
|||||||
msg = super().format(record)
|
msg = super().format(record)
|
||||||
subbed = str(msg)
|
subbed = str(msg)
|
||||||
for escape_sequence in dbt.ui.COLORS.values():
|
for escape_sequence in dbt.ui.COLORS.values():
|
||||||
subbed = subbed.replace(escape_sequence, '')
|
subbed = subbed.replace(escape_sequence, "")
|
||||||
return subbed
|
return subbed
|
||||||
|
|
||||||
def emit(self, record: logbook.LogRecord):
|
def emit(self, record: logbook.LogRecord):
|
||||||
@@ -457,11 +456,13 @@ class DelayedFileHandler(logbook.RotatingFileHandler, FormatterMixin):
|
|||||||
elif self.initialized:
|
elif self.initialized:
|
||||||
super().emit(record)
|
super().emit(record)
|
||||||
else:
|
else:
|
||||||
assert self._msg_buffer is not None, \
|
assert (
|
||||||
'_msg_buffer should never be None if _log_path is set'
|
self._msg_buffer is not None
|
||||||
|
), "_msg_buffer should never be None if _log_path is set"
|
||||||
self._msg_buffer.append(record)
|
self._msg_buffer.append(record)
|
||||||
assert len(self._msg_buffer) < self._bufmax, \
|
assert (
|
||||||
'too many messages received before initilization!'
|
len(self._msg_buffer) < self._bufmax
|
||||||
|
), "too many messages received before initilization!"
|
||||||
|
|
||||||
|
|
||||||
class LogManager(logbook.NestedSetup):
|
class LogManager(logbook.NestedSetup):
|
||||||
@@ -471,19 +472,21 @@ class LogManager(logbook.NestedSetup):
|
|||||||
self._null_handler = logbook.NullHandler()
|
self._null_handler = logbook.NullHandler()
|
||||||
self._output_handler = OutputHandler(self.stdout)
|
self._output_handler = OutputHandler(self.stdout)
|
||||||
self._file_handler = DelayedFileHandler()
|
self._file_handler = DelayedFileHandler()
|
||||||
self._relevel_processor = Relevel(allowed=['dbt', 'werkzeug'])
|
self._relevel_processor = Relevel(allowed=["dbt", "werkzeug"])
|
||||||
self._state_processor = DbtProcessState('internal')
|
self._state_processor = DbtProcessState("internal")
|
||||||
# keep track of wheter we've already entered to decide if we should
|
# keep track of wheter we've already entered to decide if we should
|
||||||
# be actually pushing. This allows us to log in main() and also
|
# be actually pushing. This allows us to log in main() and also
|
||||||
# support entering dbt execution via handle_and_check.
|
# support entering dbt execution via handle_and_check.
|
||||||
self._stack_depth = 0
|
self._stack_depth = 0
|
||||||
super().__init__([
|
super().__init__(
|
||||||
|
[
|
||||||
self._null_handler,
|
self._null_handler,
|
||||||
self._output_handler,
|
self._output_handler,
|
||||||
self._file_handler,
|
self._file_handler,
|
||||||
self._relevel_processor,
|
self._relevel_processor,
|
||||||
self._state_processor,
|
self._state_processor,
|
||||||
])
|
]
|
||||||
|
)
|
||||||
|
|
||||||
def push_application(self):
|
def push_application(self):
|
||||||
self._stack_depth += 1
|
self._stack_depth += 1
|
||||||
@@ -499,8 +502,7 @@ class LogManager(logbook.NestedSetup):
|
|||||||
self.add_handler(logbook.NullHandler())
|
self.add_handler(logbook.NullHandler())
|
||||||
|
|
||||||
def add_handler(self, handler):
|
def add_handler(self, handler):
|
||||||
"""add an handler to the log manager that runs before the file handler.
|
"""add an handler to the log manager that runs before the file handler."""
|
||||||
"""
|
|
||||||
self.objects.append(handler)
|
self.objects.append(handler)
|
||||||
|
|
||||||
# this is used by `dbt ls` to allow piping stdout to jq, etc
|
# this is used by `dbt ls` to allow piping stdout to jq, etc
|
||||||
@@ -558,8 +560,7 @@ log_manager = LogManager()
|
|||||||
|
|
||||||
|
|
||||||
def log_cache_events(flag):
|
def log_cache_events(flag):
|
||||||
"""Set the cache logger to propagate its messages based on the given flag.
|
"""Set the cache logger to propagate its messages based on the given flag."""
|
||||||
"""
|
|
||||||
# the flag is True if we should log, and False if we shouldn't, so disabled
|
# the flag is True if we should log, and False if we shouldn't, so disabled
|
||||||
# is the inverse.
|
# is the inverse.
|
||||||
CACHE_LOGGER.disabled = not flag
|
CACHE_LOGGER.disabled = not flag
|
||||||
@@ -583,7 +584,7 @@ class ListLogHandler(LogMessageHandler):
|
|||||||
level: int = logbook.NOTSET,
|
level: int = logbook.NOTSET,
|
||||||
filter: Callable = None,
|
filter: Callable = None,
|
||||||
bubble: bool = False,
|
bubble: bool = False,
|
||||||
lst: Optional[List[LogMessage]] = None
|
lst: Optional[List[LogMessage]] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__(level, filter, bubble)
|
super().__init__(level, filter, bubble)
|
||||||
if lst is None:
|
if lst is None:
|
||||||
@@ -592,7 +593,7 @@ class ListLogHandler(LogMessageHandler):
|
|||||||
|
|
||||||
def should_handle(self, record):
|
def should_handle(self, record):
|
||||||
"""Only ever emit dbt-sourced log messages to the ListHandler."""
|
"""Only ever emit dbt-sourced log messages to the ListHandler."""
|
||||||
if _root_channel(record) != 'dbt':
|
if _root_channel(record) != "dbt":
|
||||||
return False
|
return False
|
||||||
return super().should_handle(record)
|
return super().should_handle(record)
|
||||||
|
|
||||||
@@ -609,28 +610,27 @@ def _env_log_level(var_name: str) -> int:
         return logging.ERROR
 
 
-LOG_LEVEL_GOOGLE = _env_log_level('DBT_GOOGLE_DEBUG_LOGGING')
-LOG_LEVEL_SNOWFLAKE = _env_log_level('DBT_SNOWFLAKE_CONNECTOR_DEBUG_LOGGING')
-LOG_LEVEL_BOTOCORE = _env_log_level('DBT_BOTOCORE_DEBUG_LOGGING')
-LOG_LEVEL_HTTP = _env_log_level('DBT_HTTP_DEBUG_LOGGING')
-LOG_LEVEL_WERKZEUG = _env_log_level('DBT_WERKZEUG_DEBUG_LOGGING')
+LOG_LEVEL_GOOGLE = _env_log_level("DBT_GOOGLE_DEBUG_LOGGING")
+LOG_LEVEL_SNOWFLAKE = _env_log_level("DBT_SNOWFLAKE_CONNECTOR_DEBUG_LOGGING")
+LOG_LEVEL_BOTOCORE = _env_log_level("DBT_BOTOCORE_DEBUG_LOGGING")
+LOG_LEVEL_HTTP = _env_log_level("DBT_HTTP_DEBUG_LOGGING")
+LOG_LEVEL_WERKZEUG = _env_log_level("DBT_WERKZEUG_DEBUG_LOGGING")
 
-logging.getLogger('botocore').setLevel(LOG_LEVEL_BOTOCORE)
-logging.getLogger('requests').setLevel(LOG_LEVEL_HTTP)
-logging.getLogger('urllib3').setLevel(LOG_LEVEL_HTTP)
-logging.getLogger('google').setLevel(LOG_LEVEL_GOOGLE)
-logging.getLogger('snowflake.connector').setLevel(LOG_LEVEL_SNOWFLAKE)
+logging.getLogger("botocore").setLevel(LOG_LEVEL_BOTOCORE)
+logging.getLogger("requests").setLevel(LOG_LEVEL_HTTP)
+logging.getLogger("urllib3").setLevel(LOG_LEVEL_HTTP)
+logging.getLogger("google").setLevel(LOG_LEVEL_GOOGLE)
+logging.getLogger("snowflake.connector").setLevel(LOG_LEVEL_SNOWFLAKE)
 
-logging.getLogger('parsedatetime').setLevel(logging.ERROR)
-logging.getLogger('werkzeug').setLevel(LOG_LEVEL_WERKZEUG)
+logging.getLogger("parsedatetime").setLevel(logging.ERROR)
+logging.getLogger("werkzeug").setLevel(LOG_LEVEL_WERKZEUG)
 
 
 def list_handler(
     lst: Optional[List[LogMessage]],
     level=logbook.NOTSET,
 ) -> ContextManager:
-    """Return a context manager that temporarly attaches a list to the logger.
-    """
+    """Return a context manager that temporarly attaches a list to the logger."""
     return ListLogHandler(lst=lst, level=level, bubble=True)
 
 
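The list_handler helper above is documented as returning a context manager that temporarily attaches a list to the logger. A usage sketch follows; the import names are taken from the hunks above, while the attach/detach mechanics are assumed to come from logbook's handler stack:

```
from typing import List

# `logger`, `list_handler`, and `LogMessage` are module-level names visible in
# the dbt.logger hunks above.
from dbt.logger import LogMessage, list_handler, logger

captured: List[LogMessage] = []
with list_handler(captured):  # assumption: the returned handler works as a context manager
    logger.info("hello from inside the handler")

# `captured` should now hold the structured records emitted inside the block.
print(len(captured))
```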
660
core/dbt/main.py
File diff suppressed because it is too large
@@ -4,20 +4,20 @@ from dbt.dataclass_schema import StrEnum
|
|||||||
|
|
||||||
|
|
||||||
class NodeType(StrEnum):
|
class NodeType(StrEnum):
|
||||||
Model = 'model'
|
Model = "model"
|
||||||
Analysis = 'analysis'
|
Analysis = "analysis"
|
||||||
Test = 'test'
|
Test = "test"
|
||||||
Snapshot = 'snapshot'
|
Snapshot = "snapshot"
|
||||||
Operation = 'operation'
|
Operation = "operation"
|
||||||
Seed = 'seed'
|
Seed = "seed"
|
||||||
RPCCall = 'rpc'
|
RPCCall = "rpc"
|
||||||
Documentation = 'docs'
|
Documentation = "docs"
|
||||||
Source = 'source'
|
Source = "source"
|
||||||
Macro = 'macro'
|
Macro = "macro"
|
||||||
Exposure = 'exposure'
|
Exposure = "exposure"
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def executable(cls) -> List['NodeType']:
|
def executable(cls) -> List["NodeType"]:
|
||||||
return [
|
return [
|
||||||
cls.Model,
|
cls.Model,
|
||||||
cls.Test,
|
cls.Test,
|
||||||
@@ -30,7 +30,7 @@ class NodeType(StrEnum):
|
|||||||
]
|
]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def refable(cls) -> List['NodeType']:
|
def refable(cls) -> List["NodeType"]:
|
||||||
return [
|
return [
|
||||||
cls.Model,
|
cls.Model,
|
||||||
cls.Seed,
|
cls.Seed,
|
||||||
@@ -38,7 +38,7 @@ class NodeType(StrEnum):
|
|||||||
]
|
]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def documentable(cls) -> List['NodeType']:
|
def documentable(cls) -> List["NodeType"]:
|
||||||
return [
|
return [
|
||||||
cls.Model,
|
cls.Model,
|
||||||
cls.Seed,
|
cls.Seed,
|
||||||
@@ -46,16 +46,16 @@
             cls.Source,
             cls.Macro,
             cls.Analysis,
-            cls.Exposure
+            cls.Exposure,
         ]
 
     def pluralize(self) -> str:
-        if self == 'analysis':
-            return 'analyses'
+        if self == "analysis":
+            return "analyses"
         else:
-            return f'{self}s'
+            return f"{self}s"
 
 
 class RunHookType(StrEnum):
-    Start = 'on-run-start'
-    End = 'on-run-end'
+    Start = "on-run-start"
+    End = "on-run-end"
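A small sketch (not part of the diff) of the NodeType.pluralize behavior shown above; because NodeType is a StrEnum, each member formats as its string value, and only "analysis" gets an irregular plural:

```
# Illustrative only; values come straight from the enum members in this file.
print(NodeType.Model.pluralize())     # models
print(NodeType.Analysis.pluralize())  # analyses
print(NodeType.Exposure.pluralize())  # exposures
```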
@@ -11,6 +11,14 @@ from .seeds import SeedParser # noqa
|
|||||||
from .snapshots import SnapshotParser # noqa
|
from .snapshots import SnapshotParser # noqa
|
||||||
|
|
||||||
from . import ( # noqa
|
from . import ( # noqa
|
||||||
analysis, base, data_test, docs, hooks, macros, models, results, schemas,
|
analysis,
|
||||||
snapshots
|
base,
|
||||||
|
data_test,
|
||||||
|
docs,
|
||||||
|
hooks,
|
||||||
|
macros,
|
||||||
|
models,
|
||||||
|
results,
|
||||||
|
schemas,
|
||||||
|
snapshots,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -8,9 +8,7 @@ from dbt.parser.search import FilesystemSearcher, FileBlock
|
|||||||
|
|
||||||
class AnalysisParser(SimpleSQLParser[ParsedAnalysisNode]):
|
class AnalysisParser(SimpleSQLParser[ParsedAnalysisNode]):
|
||||||
def get_paths(self):
|
def get_paths(self):
|
||||||
return FilesystemSearcher(
|
return FilesystemSearcher(self.project, self.project.analysis_paths, ".sql")
|
||||||
self.project, self.project.analysis_paths, '.sql'
|
|
||||||
)
|
|
||||||
|
|
||||||
def parse_from_dict(self, dct, validate=True) -> ParsedAnalysisNode:
|
def parse_from_dict(self, dct, validate=True) -> ParsedAnalysisNode:
|
||||||
if validate:
|
if validate:
|
||||||
@@ -23,4 +21,4 @@ class AnalysisParser(SimpleSQLParser[ParsedAnalysisNode]):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_compiled_path(cls, block: FileBlock):
|
def get_compiled_path(cls, block: FileBlock):
|
||||||
return os.path.join('analysis', block.path.relative_path)
|
return os.path.join("analysis", block.path.relative_path)
|
||||||
|
|||||||
@@ -1,9 +1,7 @@
|
|||||||
import abc
|
import abc
|
||||||
import itertools
|
import itertools
|
||||||
import os
|
import os
|
||||||
from typing import (
|
from typing import List, Dict, Any, Iterable, Generic, TypeVar
|
||||||
List, Dict, Any, Iterable, Generic, TypeVar
|
|
||||||
)
|
|
||||||
|
|
||||||
from dbt.dataclass_schema import ValidationError
|
from dbt.dataclass_schema import ValidationError
|
||||||
|
|
||||||
@@ -17,17 +15,15 @@ from dbt.context.providers import (
|
|||||||
from dbt.adapters.factory import get_adapter
|
from dbt.adapters.factory import get_adapter
|
||||||
from dbt.clients.jinja import get_rendered
|
from dbt.clients.jinja import get_rendered
|
||||||
from dbt.config import Project, RuntimeConfig
|
from dbt.config import Project, RuntimeConfig
|
||||||
from dbt.context.context_config import (
|
from dbt.context.context_config import ContextConfig
|
||||||
ContextConfig
|
from dbt.contracts.files import SourceFile, FilePath, FileHash
|
||||||
)
|
|
||||||
from dbt.contracts.files import (
|
|
||||||
SourceFile, FilePath, FileHash
|
|
||||||
)
|
|
||||||
from dbt.contracts.graph.manifest import MacroManifest
|
from dbt.contracts.graph.manifest import MacroManifest
|
||||||
from dbt.contracts.graph.parsed import HasUniqueID
|
from dbt.contracts.graph.parsed import HasUniqueID
|
||||||
from dbt.contracts.graph.unparsed import UnparsedNode
|
from dbt.contracts.graph.unparsed import UnparsedNode
|
||||||
from dbt.exceptions import (
|
from dbt.exceptions import (
|
||||||
CompilationException, validator_error_message, InternalException
|
CompilationException,
|
||||||
|
validator_error_message,
|
||||||
|
InternalException,
|
||||||
)
|
)
|
||||||
from dbt import hooks
|
from dbt import hooks
|
||||||
from dbt.node_types import NodeType
|
from dbt.node_types import NodeType
|
||||||
@@ -37,14 +33,14 @@ from dbt.parser.search import FileBlock
|
|||||||
# internally, the parser may store a less-restrictive type that will be
|
# internally, the parser may store a less-restrictive type that will be
|
||||||
# transformed into the final type. But it will have to be derived from
|
# transformed into the final type. But it will have to be derived from
|
||||||
# ParsedNode to be operable.
|
# ParsedNode to be operable.
|
||||||
FinalValue = TypeVar('FinalValue', bound=HasUniqueID)
|
FinalValue = TypeVar("FinalValue", bound=HasUniqueID)
|
||||||
IntermediateValue = TypeVar('IntermediateValue', bound=HasUniqueID)
|
IntermediateValue = TypeVar("IntermediateValue", bound=HasUniqueID)
|
||||||
|
|
||||||
IntermediateNode = TypeVar('IntermediateNode', bound=Any)
|
IntermediateNode = TypeVar("IntermediateNode", bound=Any)
|
||||||
FinalNode = TypeVar('FinalNode', bound=ManifestNodes)
|
FinalNode = TypeVar("FinalNode", bound=ManifestNodes)
|
||||||
|
|
||||||
|
|
||||||
ConfiguredBlockType = TypeVar('ConfiguredBlockType', bound=FileBlock)
|
ConfiguredBlockType = TypeVar("ConfiguredBlockType", bound=FileBlock)
|
||||||
|
|
||||||
|
|
||||||
class BaseParser(Generic[FinalValue]):
|
class BaseParser(Generic[FinalValue]):
|
||||||
@@ -73,9 +69,9 @@ class BaseParser(Generic[FinalValue]):
 
     def generate_unique_id(self, resource_name: str) -> str:
         """Returns a unique identifier for a resource"""
-        return "{}.{}.{}".format(self.resource_type,
-                                 self.project.project_name,
-                                 resource_name)
+        return "{}.{}.{}".format(
+            self.resource_type, self.project.project_name, resource_name
+        )
 
     def load_file(
         self,
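To illustrate the identifier format produced by generate_unique_id above (a sketch; the project and model names are hypothetical): the resource type, project name, and resource name are joined with dots.

```
# Hypothetical values for illustration.
resource_type = "model"
project_name = "jaffle_shop"
resource_name = "orders"

unique_id = "{}.{}.{}".format(resource_type, project_name, resource_name)
print(unique_id)  # model.jaffle_shop.orders
```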
@@ -89,7 +85,7 @@ class BaseParser(Generic[FinalValue]):
|
|||||||
if set_contents:
|
if set_contents:
|
||||||
source_file.contents = file_contents.strip()
|
source_file.contents = file_contents.strip()
|
||||||
else:
|
else:
|
||||||
source_file.contents = ''
|
source_file.contents = ""
|
||||||
return source_file
|
return source_file
|
||||||
|
|
||||||
|
|
||||||
@@ -108,8 +104,7 @@ class Parser(BaseParser[FinalValue], Generic[FinalValue]):
|
|||||||
|
|
||||||
class RelationUpdate:
|
class RelationUpdate:
|
||||||
def __init__(
|
def __init__(
|
||||||
self, config: RuntimeConfig, macro_manifest: MacroManifest,
|
self, config: RuntimeConfig, macro_manifest: MacroManifest, component: str
|
||||||
component: str
|
|
||||||
) -> None:
|
) -> None:
|
||||||
macro = macro_manifest.find_generate_macro_by_name(
|
macro = macro_manifest.find_generate_macro_by_name(
|
||||||
component=component,
|
component=component,
|
||||||
@@ -117,7 +112,7 @@ class RelationUpdate:
|
|||||||
)
|
)
|
||||||
if macro is None:
|
if macro is None:
|
||||||
raise InternalException(
|
raise InternalException(
|
||||||
f'No macro with name generate_{component}_name found'
|
f"No macro with name generate_{component}_name found"
|
||||||
)
|
)
|
||||||
|
|
||||||
root_context = generate_generate_component_name_macro(
|
root_context = generate_generate_component_name_macro(
|
||||||
@@ -126,9 +121,7 @@ class RelationUpdate:
|
|||||||
self.updater = MacroGenerator(macro, root_context)
|
self.updater = MacroGenerator(macro, root_context)
|
||||||
self.component = component
|
self.component = component
|
||||||
|
|
||||||
def __call__(
|
def __call__(self, parsed_node: Any, config_dict: Dict[str, Any]) -> None:
|
||||||
self, parsed_node: Any, config_dict: Dict[str, Any]
|
|
||||||
) -> None:
|
|
||||||
override = config_dict.get(self.component)
|
override = config_dict.get(self.component)
|
||||||
new_value = self.updater(override, parsed_node)
|
new_value = self.updater(override, parsed_node)
|
||||||
if isinstance(new_value, str):
|
if isinstance(new_value, str):
|
||||||
@@ -150,16 +143,13 @@ class ConfiguredParser(
|
|||||||
super().__init__(results, project, root_project, macro_manifest)
|
super().__init__(results, project, root_project, macro_manifest)
|
||||||
|
|
||||||
self._update_node_database = RelationUpdate(
|
self._update_node_database = RelationUpdate(
|
||||||
macro_manifest=macro_manifest, config=root_project,
|
macro_manifest=macro_manifest, config=root_project, component="database"
|
||||||
component='database'
|
|
||||||
)
|
)
|
||||||
self._update_node_schema = RelationUpdate(
|
self._update_node_schema = RelationUpdate(
|
||||||
macro_manifest=macro_manifest, config=root_project,
|
macro_manifest=macro_manifest, config=root_project, component="schema"
|
||||||
component='schema'
|
|
||||||
)
|
)
|
||||||
self._update_node_alias = RelationUpdate(
|
self._update_node_alias = RelationUpdate(
|
||||||
macro_manifest=macro_manifest, config=root_project,
|
macro_manifest=macro_manifest, config=root_project, component="alias"
|
||||||
component='alias'
|
|
||||||
)
|
)
|
||||||
|
|
||||||
@abc.abstractclassmethod
|
@abc.abstractclassmethod
|
||||||
@@ -206,7 +196,11 @@ class ConfiguredParser(
|
|||||||
config[key] = [hooks.get_hook_dict(h) for h in config[key]]
|
config[key] = [hooks.get_hook_dict(h) for h in config[key]]
|
||||||
|
|
||||||
def _create_error_node(
|
def _create_error_node(
|
||||||
self, name: str, path: str, original_file_path: str, raw_sql: str,
|
self,
|
||||||
|
name: str,
|
||||||
|
path: str,
|
||||||
|
original_file_path: str,
|
||||||
|
raw_sql: str,
|
||||||
) -> UnparsedNode:
|
) -> UnparsedNode:
|
||||||
"""If we hit an error before we've actually parsed a node, provide some
|
"""If we hit an error before we've actually parsed a node, provide some
|
||||||
level of useful information by attaching this to the exception.
|
level of useful information by attaching this to the exception.
|
||||||
@@ -239,20 +233,20 @@ class ConfiguredParser(
|
|||||||
if name is None:
|
if name is None:
|
||||||
name = block.name
|
name = block.name
|
||||||
dct = {
|
dct = {
|
||||||
'alias': name,
|
"alias": name,
|
||||||
'schema': self.default_schema,
|
"schema": self.default_schema,
|
||||||
'database': self.default_database,
|
"database": self.default_database,
|
||||||
'fqn': fqn,
|
"fqn": fqn,
|
||||||
'name': name,
|
"name": name,
|
||||||
'root_path': self.project.project_root,
|
"root_path": self.project.project_root,
|
||||||
'resource_type': self.resource_type,
|
"resource_type": self.resource_type,
|
||||||
'path': path,
|
"path": path,
|
||||||
'original_file_path': block.path.original_file_path,
|
"original_file_path": block.path.original_file_path,
|
||||||
'package_name': self.project.project_name,
|
"package_name": self.project.project_name,
|
||||||
'raw_sql': block.contents,
|
"raw_sql": block.contents,
|
||||||
'unique_id': self.generate_unique_id(name),
|
"unique_id": self.generate_unique_id(name),
|
||||||
'config': self.config_dict(config),
|
"config": self.config_dict(config),
|
||||||
'checksum': block.file.checksum.to_dict(omit_none=True),
|
"checksum": block.file.checksum.to_dict(omit_none=True),
|
||||||
}
|
}
|
||||||
dct.update(kwargs)
|
dct.update(kwargs)
|
||||||
try:
|
try:
|
||||||
@@ -290,9 +284,7 @@ class ConfiguredParser(
|
|||||||
|
|
||||||
# this goes through the process of rendering, but just throws away
|
# this goes through the process of rendering, but just throws away
|
||||||
# the rendered result. The "macro capture" is the point?
|
# the rendered result. The "macro capture" is the point?
|
||||||
get_rendered(
|
get_rendered(parsed_node.raw_sql, context, parsed_node, capture_macros=True)
|
||||||
parsed_node.raw_sql, context, parsed_node, capture_macros=True
|
|
||||||
)
|
|
||||||
|
|
||||||
# This is taking the original config for the node, converting it to a dict,
|
# This is taking the original config for the node, converting it to a dict,
|
||||||
# updating the config with new config passed in, then re-creating the
|
# updating the config with new config passed in, then re-creating the
|
||||||
@@ -324,12 +316,10 @@ class ConfiguredParser(
|
|||||||
config_dict = config.build_config_dict()
|
config_dict = config.build_config_dict()
|
||||||
|
|
||||||
# Set tags on node provided in config blocks
|
# Set tags on node provided in config blocks
|
||||||
model_tags = config_dict.get('tags', [])
|
model_tags = config_dict.get("tags", [])
|
||||||
parsed_node.tags.extend(model_tags)
|
parsed_node.tags.extend(model_tags)
|
||||||
|
|
||||||
parsed_node.unrendered_config = config.build_config_dict(
|
parsed_node.unrendered_config = config.build_config_dict(rendered=False)
|
||||||
rendered=False
|
|
||||||
)
|
|
||||||
|
|
||||||
# do this once before we parse the node database/schema/alias, so
|
# do this once before we parse the node database/schema/alias, so
|
||||||
# parsed_node.config is what it would be if they did nothing
|
# parsed_node.config is what it would be if they did nothing
|
||||||
@@ -338,8 +328,9 @@ class ConfiguredParser(
|
|||||||
|
|
||||||
# at this point, we've collected our hooks. Use the node context to
|
# at this point, we've collected our hooks. Use the node context to
|
||||||
# render each hook and collect refs/sources
|
# render each hook and collect refs/sources
|
||||||
hooks = list(itertools.chain(parsed_node.config.pre_hook,
|
hooks = list(
|
||||||
parsed_node.config.post_hook))
|
itertools.chain(parsed_node.config.pre_hook, parsed_node.config.post_hook)
|
||||||
|
)
|
||||||
# skip context rebuilding if there aren't any hooks
|
# skip context rebuilding if there aren't any hooks
|
||||||
if not hooks:
|
if not hooks:
|
||||||
return
|
return
|
||||||
@@ -362,20 +353,18 @@ class ConfiguredParser(
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise InternalException(
|
raise InternalException(
|
||||||
f'Got an unexpected project version={config_version}, '
|
f"Got an unexpected project version={config_version}, " f"expected 2"
|
||||||
f'expected 2'
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def config_dict(
|
def config_dict(
|
||||||
self, config: ContextConfig,
|
self,
|
||||||
|
config: ContextConfig,
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
config_dict = config.build_config_dict(base=True)
|
config_dict = config.build_config_dict(base=True)
|
||||||
self._mangle_hooks(config_dict)
|
self._mangle_hooks(config_dict)
|
||||||
return config_dict
|
return config_dict
|
||||||
|
|
||||||
def render_update(
|
def render_update(self, node: IntermediateNode, config: ContextConfig) -> None:
|
||||||
self, node: IntermediateNode, config: ContextConfig
|
|
||||||
) -> None:
|
|
||||||
try:
|
try:
|
||||||
self.render_with_context(node, config)
|
self.render_with_context(node, config)
|
||||||
self.update_parsed_node(node, config)
|
self.update_parsed_node(node, config)
|
||||||
@@ -418,7 +407,7 @@ class ConfiguredParser(
|
|||||||
|
|
||||||
class SimpleParser(
|
class SimpleParser(
|
||||||
ConfiguredParser[ConfiguredBlockType, FinalNode, FinalNode],
|
ConfiguredParser[ConfiguredBlockType, FinalNode, FinalNode],
|
||||||
Generic[ConfiguredBlockType, FinalNode]
|
Generic[ConfiguredBlockType, FinalNode],
|
||||||
):
|
):
|
||||||
def transform(self, node):
|
def transform(self, node):
|
||||||
return node
|
return node
|
||||||
@@ -426,14 +415,12 @@ class SimpleParser(
|
|||||||
|
|
||||||
class SQLParser(
|
class SQLParser(
|
||||||
ConfiguredParser[FileBlock, IntermediateNode, FinalNode],
|
ConfiguredParser[FileBlock, IntermediateNode, FinalNode],
|
||||||
Generic[IntermediateNode, FinalNode]
|
Generic[IntermediateNode, FinalNode],
|
||||||
):
|
):
|
||||||
def parse_file(self, file_block: FileBlock) -> None:
|
def parse_file(self, file_block: FileBlock) -> None:
|
||||||
self.parse_node(file_block)
|
self.parse_node(file_block)
|
||||||
|
|
||||||
|
|
||||||
class SimpleSQLParser(
|
class SimpleSQLParser(SQLParser[FinalNode, FinalNode]):
|
||||||
SQLParser[FinalNode, FinalNode]
|
|
||||||
):
|
|
||||||
def transform(self, node):
|
def transform(self, node):
|
||||||
return node
|
return node
|
||||||
|
|||||||
@@ -7,9 +7,7 @@ from dbt.utils import get_pseudo_test_path
|
|||||||
|
|
||||||
class DataTestParser(SimpleSQLParser[ParsedDataTestNode]):
|
class DataTestParser(SimpleSQLParser[ParsedDataTestNode]):
|
||||||
def get_paths(self):
|
def get_paths(self):
|
||||||
return FilesystemSearcher(
|
return FilesystemSearcher(self.project, self.project.test_paths, ".sql")
|
||||||
self.project, self.project.test_paths, '.sql'
|
|
||||||
)
|
|
||||||
|
|
||||||
def parse_from_dict(self, dct, validate=True) -> ParsedDataTestNode:
|
def parse_from_dict(self, dct, validate=True) -> ParsedDataTestNode:
|
||||||
if validate:
|
if validate:
|
||||||
@@ -21,11 +19,10 @@ class DataTestParser(SimpleSQLParser[ParsedDataTestNode]):
|
|||||||
return NodeType.Test
|
return NodeType.Test
|
||||||
|
|
||||||
def transform(self, node):
|
def transform(self, node):
|
||||||
if 'data' not in node.tags:
|
if "data" not in node.tags:
|
||||||
node.tags.append('data')
|
node.tags.append("data")
|
||||||
return node
|
return node
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_compiled_path(cls, block: FileBlock):
|
def get_compiled_path(cls, block: FileBlock):
|
||||||
return get_pseudo_test_path(block.name, block.path.relative_path,
|
return get_pseudo_test_path(block.name, block.path.relative_path, "data_test")
|
||||||
'data_test')
|
|
||||||
|
|||||||
@@ -7,11 +7,14 @@ from dbt.contracts.graph.parsed import ParsedDocumentation
|
|||||||
from dbt.node_types import NodeType
|
from dbt.node_types import NodeType
|
||||||
from dbt.parser.base import Parser
|
from dbt.parser.base import Parser
|
||||||
from dbt.parser.search import (
|
from dbt.parser.search import (
|
||||||
BlockContents, FileBlock, FilesystemSearcher, BlockSearcher
|
BlockContents,
|
||||||
|
FileBlock,
|
||||||
|
FilesystemSearcher,
|
||||||
|
BlockSearcher,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
SHOULD_PARSE_RE = re.compile(r'{[{%]')
|
SHOULD_PARSE_RE = re.compile(r"{[{%]")
|
||||||
|
|
||||||
|
|
||||||
class DocumentationParser(Parser[ParsedDocumentation]):
|
class DocumentationParser(Parser[ParsedDocumentation]):
|
||||||
@@ -19,7 +22,7 @@ class DocumentationParser(Parser[ParsedDocumentation]):
|
|||||||
return FilesystemSearcher(
|
return FilesystemSearcher(
|
||||||
project=self.project,
|
project=self.project,
|
||||||
relative_dirs=self.project.docs_paths,
|
relative_dirs=self.project.docs_paths,
|
||||||
extension='.md',
|
extension=".md",
|
||||||
)
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@@ -33,11 +36,9 @@ class DocumentationParser(Parser[ParsedDocumentation]):
|
|||||||
def generate_unique_id(self, resource_name: str) -> str:
|
def generate_unique_id(self, resource_name: str) -> str:
|
||||||
# because docs are in their own graph namespace, node type doesn't
|
# because docs are in their own graph namespace, node type doesn't
|
||||||
# need to be part of the unique ID.
|
# need to be part of the unique ID.
|
||||||
return '{}.{}'.format(self.project.project_name, resource_name)
|
return "{}.{}".format(self.project.project_name, resource_name)
|
||||||
|
|
||||||
def parse_block(
|
def parse_block(self, block: BlockContents) -> Iterable[ParsedDocumentation]:
|
||||||
self, block: BlockContents
|
|
||||||
) -> Iterable[ParsedDocumentation]:
|
|
||||||
unique_id = self.generate_unique_id(block.name)
|
unique_id = self.generate_unique_id(block.name)
|
||||||
contents = get_rendered(block.contents, {}).strip()
|
contents = get_rendered(block.contents, {}).strip()
|
||||||
|
|
||||||
@@ -55,7 +56,7 @@ class DocumentationParser(Parser[ParsedDocumentation]):
|
|||||||
def parse_file(self, file_block: FileBlock):
|
def parse_file(self, file_block: FileBlock):
|
||||||
searcher: Iterable[BlockContents] = BlockSearcher(
|
searcher: Iterable[BlockContents] = BlockSearcher(
|
||||||
source=[file_block],
|
source=[file_block],
|
||||||
allowed_blocks={'docs'},
|
allowed_blocks={"docs"},
|
||||||
source_tag_factory=BlockContents,
|
source_tag_factory=BlockContents,
|
||||||
)
|
)
|
||||||
for block in searcher:
|
for block in searcher:
|
||||||
|
|||||||
@@ -24,7 +24,7 @@ class HookBlock(FileBlock):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def name(self):
|
def name(self):
|
||||||
return '{}-{!s}-{!s}'.format(self.project, self.hook_type, self.index)
|
return "{}-{!s}-{!s}".format(self.project, self.hook_type, self.index)
|
||||||
|
|
||||||
|
|
||||||
class HookSearcher(Iterable[HookBlock]):
|
class HookSearcher(Iterable[HookBlock]):
|
||||||
@@ -33,9 +33,7 @@ class HookSearcher(Iterable[HookBlock]):
|
|||||||
self.source_file = source_file
|
self.source_file = source_file
|
||||||
self.hook_type = hook_type
|
self.hook_type = hook_type
|
||||||
|
|
||||||
def _hook_list(
|
def _hook_list(self, hooks: Union[str, List[str], Tuple[str, ...]]) -> List[str]:
|
||||||
self, hooks: Union[str, List[str], Tuple[str, ...]]
|
|
||||||
) -> List[str]:
|
|
||||||
if isinstance(hooks, tuple):
|
if isinstance(hooks, tuple):
|
||||||
hooks = list(hooks)
|
hooks = list(hooks)
|
||||||
elif not isinstance(hooks, list):
|
elif not isinstance(hooks, list):
|
||||||
@@ -49,8 +47,9 @@ class HookSearcher(Iterable[HookBlock]):
|
|||||||
hooks = self.project.on_run_end
|
hooks = self.project.on_run_end
|
||||||
else:
|
else:
|
||||||
raise InternalException(
|
raise InternalException(
|
||||||
'hook_type must be one of "{}" or "{}" (got {})'
|
'hook_type must be one of "{}" or "{}" (got {})'.format(
|
||||||
.format(RunHookType.Start, RunHookType.End, self.hook_type)
|
RunHookType.Start, RunHookType.End, self.hook_type
|
||||||
|
)
|
||||||
)
|
)
|
||||||
return self._hook_list(hooks)
|
return self._hook_list(hooks)
|
||||||
|
|
||||||
@@ -73,8 +72,8 @@ class HookParser(SimpleParser[HookBlock, ParsedHookNode]):
|
|||||||
def get_paths(self) -> List[FilePath]:
|
def get_paths(self) -> List[FilePath]:
|
||||||
path = FilePath(
|
path = FilePath(
|
||||||
project_root=self.project.project_root,
|
project_root=self.project.project_root,
|
||||||
searched_path='.',
|
searched_path=".",
|
||||||
relative_path='dbt_project.yml',
|
relative_path="dbt_project.yml",
|
||||||
)
|
)
|
||||||
return [path]
|
return [path]
|
||||||
|
|
||||||
@@ -98,9 +97,13 @@ class HookParser(SimpleParser[HookBlock, ParsedHookNode]):
|
|||||||
) -> ParsedHookNode:
|
) -> ParsedHookNode:
|
||||||
|
|
||||||
return super()._create_parsetime_node(
|
return super()._create_parsetime_node(
|
||||||
block=block, path=path, config=config, fqn=fqn,
|
block=block,
|
||||||
index=block.index, name=name,
|
path=path,
|
||||||
tags=[str(block.hook_type)]
|
config=config,
|
||||||
|
fqn=fqn,
|
||||||
|
index=block.index,
|
||||||
|
name=name,
|
||||||
|
tags=[str(block.hook_type)],
|
||||||
)
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ class MacroParser(BaseParser[ParsedMacro]):
|
|||||||
return FilesystemSearcher(
|
return FilesystemSearcher(
|
||||||
project=self.project,
|
project=self.project,
|
||||||
relative_dirs=self.project.macro_paths,
|
relative_dirs=self.project.macro_paths,
|
||||||
extension='.sql',
|
extension=".sql",
|
||||||
)
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@@ -45,15 +45,13 @@ class MacroParser(BaseParser[ParsedMacro]):
|
|||||||
unique_id=unique_id,
|
unique_id=unique_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
def parse_unparsed_macros(
|
def parse_unparsed_macros(self, base_node: UnparsedMacro) -> Iterable[ParsedMacro]:
|
||||||
self, base_node: UnparsedMacro
|
|
||||||
) -> Iterable[ParsedMacro]:
|
|
||||||
try:
|
try:
|
||||||
blocks: List[jinja.BlockTag] = [
|
blocks: List[jinja.BlockTag] = [
|
||||||
t for t in
|
t
|
||||||
jinja.extract_toplevel_blocks(
|
for t in jinja.extract_toplevel_blocks(
|
||||||
base_node.raw_sql,
|
base_node.raw_sql,
|
||||||
allowed_blocks={'macro', 'materialization'},
|
allowed_blocks={"macro", "materialization"},
|
||||||
collect_raw_data=False,
|
collect_raw_data=False,
|
||||||
)
|
)
|
||||||
if isinstance(t, jinja.BlockTag)
|
if isinstance(t, jinja.BlockTag)
|
||||||
@@ -75,8 +73,8 @@ class MacroParser(BaseParser[ParsedMacro]):
|
|||||||
# things have gone disastrously wrong, we thought we only
|
# things have gone disastrously wrong, we thought we only
|
||||||
# parsed one block!
|
# parsed one block!
|
||||||
raise CompilationException(
|
raise CompilationException(
|
||||||
f'Found multiple macros in {block.full_block}, expected 1',
|
f"Found multiple macros in {block.full_block}, expected 1",
|
||||||
node=base_node
|
node=base_node,
|
||||||
)
|
)
|
||||||
|
|
||||||
macro_name = macro_nodes[0].name
|
macro_name = macro_nodes[0].name
|
||||||
@@ -84,7 +82,7 @@ class MacroParser(BaseParser[ParsedMacro]):
|
|||||||
if not macro_name.startswith(MACRO_PREFIX):
|
if not macro_name.startswith(MACRO_PREFIX):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
name: str = macro_name.replace(MACRO_PREFIX, '')
|
name: str = macro_name.replace(MACRO_PREFIX, "")
|
||||||
node = self.parse_macro(block, base_node, name)
|
node = self.parse_macro(block, base_node, name)
|
||||||
yield node
|
yield node
|
||||||
|
|
||||||
|
|||||||
@@ -3,7 +3,15 @@ from dataclasses import field
|
|||||||
import os
|
import os
|
||||||
import pickle
|
import pickle
|
||||||
from typing import (
|
from typing import (
|
||||||
Dict, Optional, Mapping, Callable, Any, List, Type, Union, MutableMapping
|
Dict,
|
||||||
|
Optional,
|
||||||
|
Mapping,
|
||||||
|
Callable,
|
||||||
|
Any,
|
||||||
|
List,
|
||||||
|
Type,
|
||||||
|
Union,
|
||||||
|
MutableMapping,
|
||||||
)
|
)
|
||||||
import time
|
import time
|
||||||
|
|
||||||
@@ -23,11 +31,13 @@ from dbt.config import Project, RuntimeConfig
|
|||||||
from dbt.context.docs import generate_runtime_docs
|
from dbt.context.docs import generate_runtime_docs
|
||||||
from dbt.contracts.files import FilePath, FileHash
|
from dbt.contracts.files import FilePath, FileHash
|
||||||
from dbt.contracts.graph.compiled import ManifestNode
|
from dbt.contracts.graph.compiled import ManifestNode
|
||||||
from dbt.contracts.graph.manifest import (
|
from dbt.contracts.graph.manifest import Manifest, MacroManifest, AnyManifest, Disabled
|
||||||
Manifest, MacroManifest, AnyManifest, Disabled
|
|
||||||
)
|
|
||||||
from dbt.contracts.graph.parsed import (
|
from dbt.contracts.graph.parsed import (
|
||||||
ParsedSourceDefinition, ParsedNode, ParsedMacro, ColumnInfo, ParsedExposure
|
ParsedSourceDefinition,
|
||||||
|
ParsedNode,
|
||||||
|
ParsedMacro,
|
||||||
|
ColumnInfo,
|
||||||
|
ParsedExposure,
|
||||||
)
|
)
|
||||||
from dbt.contracts.util import Writable
|
from dbt.contracts.util import Writable
|
||||||
from dbt.exceptions import (
|
from dbt.exceptions import (
|
||||||
@@ -55,8 +65,8 @@ from dbt.version import __version__
|
|||||||
|
|
||||||
from dbt.dataclass_schema import dbtClassMixin
|
from dbt.dataclass_schema import dbtClassMixin
|
||||||
|
|
||||||
PARTIAL_PARSE_FILE_NAME = 'partial_parse.pickle'
|
PARTIAL_PARSE_FILE_NAME = "partial_parse.pickle"
|
||||||
PARSING_STATE = DbtProcessState('parsing')
|
PARSING_STATE = DbtProcessState("parsing")
|
||||||
DEFAULT_PARTIAL_PARSE = False
|
DEFAULT_PARTIAL_PARSE = False
|
||||||
|
|
||||||
|
|
||||||
@@ -110,20 +120,22 @@ def make_parse_result(
|
|||||||
"""Make a ParseResult from the project configuration and the profile."""
|
"""Make a ParseResult from the project configuration and the profile."""
|
||||||
# if any of these change, we need to reject the parser
|
# if any of these change, we need to reject the parser
|
||||||
vars_hash = FileHash.from_contents(
|
vars_hash = FileHash.from_contents(
|
||||||
'\x00'.join([
|
"\x00".join(
|
||||||
getattr(config.args, 'vars', '{}') or '{}',
|
[
|
||||||
getattr(config.args, 'profile', '') or '',
|
getattr(config.args, "vars", "{}") or "{}",
|
||||||
getattr(config.args, 'target', '') or '',
|
getattr(config.args, "profile", "") or "",
|
||||||
__version__
|
getattr(config.args, "target", "") or "",
|
||||||
])
|
__version__,
|
||||||
|
]
|
||||||
)
|
)
|
||||||
profile_path = os.path.join(config.args.profiles_dir, 'profiles.yml')
|
)
|
||||||
|
profile_path = os.path.join(config.args.profiles_dir, "profiles.yml")
|
||||||
with open(profile_path) as fp:
|
with open(profile_path) as fp:
|
||||||
profile_hash = FileHash.from_contents(fp.read())
|
profile_hash = FileHash.from_contents(fp.read())
|
||||||
|
|
||||||
project_hashes = {}
|
project_hashes = {}
|
||||||
for name, project in all_projects.items():
|
for name, project in all_projects.items():
|
||||||
path = os.path.join(project.project_root, 'dbt_project.yml')
|
path = os.path.join(project.project_root, "dbt_project.yml")
|
||||||
with open(path) as fp:
|
with open(path) as fp:
|
||||||
project_hashes[name] = FileHash.from_contents(fp.read())
|
project_hashes[name] = FileHash.from_contents(fp.read())
|
||||||
|
|
||||||
@@ -153,7 +165,8 @@ class ManifestLoader:
|
|||||||
# in dictionaries: nodes, sources, docs, macros, exposures,
|
# in dictionaries: nodes, sources, docs, macros, exposures,
|
||||||
# macro_patches, patches, source_patches, files, etc
|
# macro_patches, patches, source_patches, files, etc
|
||||||
self.results: ParseResult = make_parse_result(
|
self.results: ParseResult = make_parse_result(
|
||||||
root_project, all_projects,
|
root_project,
|
||||||
|
all_projects,
|
||||||
)
|
)
|
||||||
self._loaded_file_cache: Dict[str, FileBlock] = {}
|
self._loaded_file_cache: Dict[str, FileBlock] = {}
|
||||||
self._perf_info = ManifestLoaderInfo(
|
self._perf_info = ManifestLoaderInfo(
|
||||||
@@ -162,20 +175,18 @@ class ManifestLoader:
|
|||||||
|
|
||||||
def track_project_load(self):
|
def track_project_load(self):
|
||||||
invocation_id = dbt.tracking.active_user.invocation_id
|
invocation_id = dbt.tracking.active_user.invocation_id
|
||||||
dbt.tracking.track_project_load({
|
dbt.tracking.track_project_load(
|
||||||
|
{
|
||||||
"invocation_id": invocation_id,
|
"invocation_id": invocation_id,
|
||||||
"project_id": self.root_project.hashed_name(),
|
"project_id": self.root_project.hashed_name(),
|
||||||
"path_count": self._perf_info.path_count,
|
"path_count": self._perf_info.path_count,
|
||||||
"parse_project_elapsed": self._perf_info.parse_project_elapsed,
|
"parse_project_elapsed": self._perf_info.parse_project_elapsed,
|
||||||
"patch_sources_elapsed": self._perf_info.patch_sources_elapsed,
|
"patch_sources_elapsed": self._perf_info.patch_sources_elapsed,
|
||||||
"process_manifest_elapsed": (
|
"process_manifest_elapsed": (self._perf_info.process_manifest_elapsed),
|
||||||
self._perf_info.process_manifest_elapsed
|
|
||||||
),
|
|
||||||
"load_all_elapsed": self._perf_info.load_all_elapsed,
|
"load_all_elapsed": self._perf_info.load_all_elapsed,
|
||||||
"is_partial_parse_enabled": (
|
"is_partial_parse_enabled": (self._perf_info.is_partial_parse_enabled),
|
||||||
self._perf_info.is_partial_parse_enabled
|
}
|
||||||
),
|
)
|
||||||
})
|
|
||||||
|
|
||||||
def parse_with_cache(
|
def parse_with_cache(
|
||||||
self,
|
self,
|
||||||
@@ -220,8 +231,7 @@ class ManifestLoader:
|
|||||||
) -> None:
|
) -> None:
|
||||||
parsers: List[Parser] = []
|
parsers: List[Parser] = []
|
||||||
for cls in _parser_types:
|
for cls in _parser_types:
|
||||||
parser = cls(self.results, project, self.root_project,
|
parser = cls(self.results, project, self.root_project, macro_manifest)
|
||||||
macro_manifest)
|
|
||||||
parsers.append(parser)
|
parsers.append(parser)
|
||||||
|
|
||||||
# per-project cache.
|
# per-project cache.
|
||||||
@@ -238,11 +248,13 @@ class ManifestLoader:
|
|||||||
parser_path_count = parser_path_count + 1
|
parser_path_count = parser_path_count + 1
|
||||||
|
|
||||||
if parser_path_count > 0:
|
if parser_path_count > 0:
|
||||||
project_parser_info.append(ParserInfo(
|
project_parser_info.append(
|
||||||
|
ParserInfo(
|
||||||
parser=parser.resource_type,
|
parser=parser.resource_type,
|
||||||
path_count=parser_path_count,
|
path_count=parser_path_count,
|
||||||
elapsed=time.perf_counter() - parser_start_timer
|
elapsed=time.perf_counter() - parser_start_timer,
|
||||||
))
|
)
|
||||||
|
)
|
||||||
total_path_count = total_path_count + parser_path_count
|
total_path_count = total_path_count + parser_path_count
|
||||||
|
|
||||||
elapsed = time.perf_counter() - start_timer
|
elapsed = time.perf_counter() - start_timer
|
||||||
@@ -250,12 +262,10 @@ class ManifestLoader:
|
|||||||
project_name=project.project_name,
|
project_name=project.project_name,
|
||||||
path_count=total_path_count,
|
path_count=total_path_count,
|
||||||
elapsed=elapsed,
|
elapsed=elapsed,
|
||||||
parsers=project_parser_info
|
parsers=project_parser_info,
|
||||||
)
|
)
|
||||||
self._perf_info.projects.append(project_info)
|
self._perf_info.projects.append(project_info)
|
||||||
self._perf_info.path_count = (
|
self._perf_info.path_count = self._perf_info.path_count + total_path_count
|
||||||
self._perf_info.path_count + total_path_count
|
|
||||||
)
|
|
||||||
|
|
||||||
def load_only_macros(self) -> MacroManifest:
|
def load_only_macros(self) -> MacroManifest:
|
||||||
old_results = self.read_parse_results()
|
old_results = self.read_parse_results()
|
||||||
@@ -267,8 +277,7 @@ class ManifestLoader:
|
|||||||
|
|
||||||
# make a manifest with just the macros to get the context
|
# make a manifest with just the macros to get the context
|
||||||
macro_manifest = MacroManifest(
|
macro_manifest = MacroManifest(
|
||||||
macros=self.results.macros,
|
macros=self.results.macros, files=self.results.files
|
||||||
files=self.results.files
|
|
||||||
)
|
)
|
||||||
self.macro_hook(macro_manifest)
|
self.macro_hook(macro_manifest)
|
||||||
return macro_manifest
|
return macro_manifest
|
||||||
@@ -278,7 +287,7 @@ class ManifestLoader:
|
|||||||
# if partial parse is enabled, load old results
|
# if partial parse is enabled, load old results
|
||||||
old_results = self.read_parse_results()
|
old_results = self.read_parse_results()
|
||||||
if old_results is not None:
|
if old_results is not None:
|
||||||
logger.debug('Got an acceptable cached parse result')
|
logger.debug("Got an acceptable cached parse result")
|
||||||
# store the macros & files from the adapter macro manifest
|
# store the macros & files from the adapter macro manifest
|
||||||
self.results.macros.update(macro_manifest.macros)
|
self.results.macros.update(macro_manifest.macros)
|
||||||
self.results.files.update(macro_manifest.files)
|
self.results.files.update(macro_manifest.files)
|
||||||
@@ -289,15 +298,12 @@ class ManifestLoader:
|
|||||||
# parse a single project
|
# parse a single project
|
||||||
self.parse_project(project, macro_manifest, old_results)
|
self.parse_project(project, macro_manifest, old_results)
|
||||||
|
|
||||||
self._perf_info.parse_project_elapsed = (
|
self._perf_info.parse_project_elapsed = time.perf_counter() - start_timer
|
||||||
time.perf_counter() - start_timer
|
|
||||||
)
|
|
||||||
|
|
||||||
def write_parse_results(self):
|
def write_parse_results(self):
|
||||||
path = os.path.join(self.root_project.target_path,
|
path = os.path.join(self.root_project.target_path, PARTIAL_PARSE_FILE_NAME)
|
||||||
PARTIAL_PARSE_FILE_NAME)
|
|
||||||
make_directory(self.root_project.target_path)
|
make_directory(self.root_project.target_path)
|
||||||
with open(path, 'wb') as fp:
|
with open(path, "wb") as fp:
|
||||||
pickle.dump(self.results, fp)
|
pickle.dump(self.results, fp)
|
||||||
|
|
||||||
def matching_parse_results(self, result: ParseResult) -> bool:
|
def matching_parse_results(self, result: ParseResult) -> bool:
|
||||||
@@ -307,31 +313,32 @@ class ManifestLoader:
|
|||||||
try:
|
try:
|
||||||
if result.dbt_version != __version__:
|
if result.dbt_version != __version__:
|
||||||
logger.debug(
|
logger.debug(
|
||||||
'dbt version mismatch: {} != {}, cache invalidated'
|
"dbt version mismatch: {} != {}, cache invalidated".format(
|
||||||
.format(result.dbt_version, __version__)
|
result.dbt_version, __version__
|
||||||
|
)
|
||||||
)
|
)
|
||||||
return False
|
return False
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
logger.debug('malformed result file, cache invalidated')
|
logger.debug("malformed result file, cache invalidated")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
valid = True
|
valid = True
|
||||||
|
|
||||||
if self.results.vars_hash != result.vars_hash:
|
if self.results.vars_hash != result.vars_hash:
|
||||||
logger.debug('vars hash mismatch, cache invalidated')
|
logger.debug("vars hash mismatch, cache invalidated")
|
||||||
valid = False
|
valid = False
|
||||||
if self.results.profile_hash != result.profile_hash:
|
if self.results.profile_hash != result.profile_hash:
|
||||||
logger.debug('profile hash mismatch, cache invalidated')
|
logger.debug("profile hash mismatch, cache invalidated")
|
||||||
valid = False
|
valid = False
|
||||||
|
|
||||||
missing_keys = {
|
missing_keys = {
|
||||||
k for k in self.results.project_hashes
|
k for k in self.results.project_hashes if k not in result.project_hashes
|
||||||
if k not in result.project_hashes
|
|
||||||
}
|
}
|
||||||
if missing_keys:
|
if missing_keys:
|
||||||
logger.debug(
|
logger.debug(
|
||||||
'project hash mismatch: values missing, cache invalidated: {}'
|
"project hash mismatch: values missing, cache invalidated: {}".format(
|
||||||
.format(missing_keys)
|
missing_keys
|
||||||
|
)
|
||||||
)
|
)
|
||||||
valid = False
|
valid = False
|
||||||
|
|
||||||
@@ -340,9 +347,8 @@ class ManifestLoader:
|
|||||||
old_value = result.project_hashes[key]
|
old_value = result.project_hashes[key]
|
||||||
if new_value != old_value:
|
if new_value != old_value:
|
||||||
logger.debug(
|
logger.debug(
|
||||||
'For key {}, hash mismatch ({} -> {}), cache '
|
"For key {}, hash mismatch ({} -> {}), cache "
|
||||||
'invalidated'
|
"invalidated".format(key, old_value, new_value)
|
||||||
.format(key, old_value, new_value)
|
|
||||||
)
|
)
|
||||||
valid = False
|
valid = False
|
||||||
return valid
|
return valid
|
||||||
@@ -359,14 +365,13 @@ class ManifestLoader:
|
|||||||
|
|
||||||
def read_parse_results(self) -> Optional[ParseResult]:
|
def read_parse_results(self) -> Optional[ParseResult]:
|
||||||
if not self._partial_parse_enabled():
|
if not self._partial_parse_enabled():
|
||||||
logger.debug('Partial parsing not enabled')
|
logger.debug("Partial parsing not enabled")
|
||||||
return None
|
return None
|
||||||
path = os.path.join(self.root_project.target_path,
|
path = os.path.join(self.root_project.target_path, PARTIAL_PARSE_FILE_NAME)
|
||||||
PARTIAL_PARSE_FILE_NAME)
|
|
||||||
|
|
||||||
if os.path.exists(path):
|
if os.path.exists(path):
|
||||||
try:
|
try:
|
||||||
with open(path, 'rb') as fp:
|
with open(path, "rb") as fp:
|
||||||
result: ParseResult = pickle.load(fp)
|
result: ParseResult = pickle.load(fp)
|
||||||
# keep this check inside the try/except in case something about
|
# keep this check inside the try/except in case something about
|
||||||
# the file has changed in weird ways, perhaps due to being a
|
# the file has changed in weird ways, perhaps due to being a
|
||||||
@@ -375,9 +380,8 @@ class ManifestLoader:
|
|||||||
return result
|
return result
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.debug(
|
logger.debug(
|
||||||
'Failed to load parsed file from disk at {}: {}'
|
"Failed to load parsed file from disk at {}: {}".format(path, exc),
|
||||||
.format(path, exc),
|
exc_info=True,
|
||||||
exc_info=True
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return None
|
return None
|
||||||
@@ -394,9 +398,7 @@ class ManifestLoader:
|
|||||||
# list is created
|
# list is created
|
||||||
start_patch = time.perf_counter()
|
start_patch = time.perf_counter()
|
||||||
sources = patch_sources(self.results, self.root_project)
|
sources = patch_sources(self.results, self.root_project)
|
||||||
self._perf_info.patch_sources_elapsed = (
|
self._perf_info.patch_sources_elapsed = time.perf_counter() - start_patch
|
||||||
time.perf_counter() - start_patch
|
|
||||||
)
|
|
||||||
disabled = []
|
disabled = []
|
||||||
for value in self.results.disabled.values():
|
for value in self.results.disabled.values():
|
||||||
disabled.extend(value)
|
disabled.extend(value)
|
||||||
@@ -421,9 +423,7 @@ class ManifestLoader:
|
|||||||
start_process = time.perf_counter()
|
start_process = time.perf_counter()
|
||||||
self.process_manifest(manifest)
|
self.process_manifest(manifest)
|
||||||
|
|
||||||
self._perf_info.process_manifest_elapsed = (
|
self._perf_info.process_manifest_elapsed = time.perf_counter() - start_process
|
||||||
time.perf_counter() - start_process
|
|
||||||
)
|
|
||||||
|
|
||||||
return manifest
|
return manifest
|
||||||
|
|
||||||
@@ -445,9 +445,7 @@ class ManifestLoader:
|
|||||||
_check_manifest(manifest, root_config)
|
_check_manifest(manifest, root_config)
|
||||||
manifest.build_flat_graph()
|
manifest.build_flat_graph()
|
||||||
|
|
||||||
loader._perf_info.load_all_elapsed = (
|
loader._perf_info.load_all_elapsed = time.perf_counter() - start_load_all
|
||||||
time.perf_counter() - start_load_all
|
|
||||||
)
|
|
||||||
|
|
||||||
loader.track_project_load()
|
loader.track_project_load()
|
||||||
|
|
||||||
@@ -465,8 +463,9 @@ class ManifestLoader:
|
|||||||
return loader.load_only_macros()
|
return loader.load_only_macros()
|
||||||
|
|
||||||
|
|
||||||
def invalid_ref_fail_unless_test(node, target_model_name,
|
def invalid_ref_fail_unless_test(
|
||||||
target_model_package, disabled):
|
node, target_model_name, target_model_package, disabled
|
||||||
|
):
|
||||||
|
|
||||||
if node.resource_type == NodeType.Test:
|
if node.resource_type == NodeType.Test:
|
||||||
msg = get_target_not_found_or_disabled_msg(
|
msg = get_target_not_found_or_disabled_msg(
|
||||||
@@ -475,10 +474,7 @@ def invalid_ref_fail_unless_test(node, target_model_name,
|
|||||||
if disabled:
|
if disabled:
|
||||||
logger.debug(warning_tag(msg))
|
logger.debug(warning_tag(msg))
|
||||||
else:
|
else:
|
||||||
warn_or_error(
|
warn_or_error(msg, log_fmt=warning_tag("{}"))
|
||||||
msg,
|
|
||||||
log_fmt=warning_tag('{}')
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
ref_target_not_found(
|
ref_target_not_found(
|
||||||
node,
|
node,
|
||||||
@@ -488,9 +484,7 @@ def invalid_ref_fail_unless_test(node, target_model_name,
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def invalid_source_fail_unless_test(
|
def invalid_source_fail_unless_test(node, target_name, target_table_name, disabled):
|
||||||
node, target_name, target_table_name, disabled
|
|
||||||
):
|
|
||||||
if node.resource_type == NodeType.Test:
|
if node.resource_type == NodeType.Test:
|
||||||
msg = get_source_not_found_or_disabled_msg(
|
msg = get_source_not_found_or_disabled_msg(
|
||||||
node, target_name, target_table_name, disabled
|
node, target_name, target_table_name, disabled
|
||||||
@@ -498,17 +492,9 @@ def invalid_source_fail_unless_test(
|
|||||||
if disabled:
|
if disabled:
|
||||||
logger.debug(warning_tag(msg))
|
logger.debug(warning_tag(msg))
|
||||||
else:
|
else:
|
||||||
warn_or_error(
|
warn_or_error(msg, log_fmt=warning_tag("{}"))
|
||||||
msg,
|
|
||||||
log_fmt=warning_tag('{}')
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
source_target_not_found(
|
source_target_not_found(node, target_name, target_table_name, disabled=disabled)
|
||||||
node,
|
|
||||||
target_name,
|
|
||||||
target_table_name,
|
|
||||||
disabled=disabled
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _check_resource_uniqueness(
|
def _check_resource_uniqueness(
|
||||||
@@ -532,15 +518,11 @@ def _check_resource_uniqueness(
|
|||||||
|
|
||||||
existing_node = names_resources.get(name)
|
existing_node = names_resources.get(name)
|
||||||
if existing_node is not None:
|
if existing_node is not None:
|
||||||
dbt.exceptions.raise_duplicate_resource_name(
|
dbt.exceptions.raise_duplicate_resource_name(existing_node, node)
|
||||||
existing_node, node
|
|
||||||
)
|
|
||||||
|
|
||||||
existing_alias = alias_resources.get(full_node_name)
|
existing_alias = alias_resources.get(full_node_name)
|
||||||
if existing_alias is not None:
|
if existing_alias is not None:
|
||||||
dbt.exceptions.raise_ambiguous_alias(
|
dbt.exceptions.raise_ambiguous_alias(existing_alias, node, full_node_name)
|
||||||
existing_alias, node, full_node_name
|
|
||||||
)
|
|
||||||
|
|
||||||
names_resources[name] = node
|
names_resources[name] = node
|
||||||
alias_resources[full_node_name] = node
|
alias_resources[full_node_name] = node
|
||||||
@@ -565,8 +547,7 @@ def _load_projects(config, paths):
|
|||||||
project = config.new_project(path)
|
project = config.new_project(path)
|
||||||
except dbt.exceptions.DbtProjectError as e:
|
except dbt.exceptions.DbtProjectError as e:
|
||||||
raise dbt.exceptions.DbtProjectError(
|
raise dbt.exceptions.DbtProjectError(
|
||||||
'Failed to read package at {}: {}'
|
"Failed to read package at {}: {}".format(path, e)
|
||||||
.format(path, e)
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
yield project.project_name, project
|
yield project.project_name, project
|
||||||
@@ -587,8 +568,7 @@ def _get_node_column(node, column_name):
|
|||||||
|
|
||||||
|
|
||||||
DocsContextCallback = Callable[
|
DocsContextCallback = Callable[
|
||||||
[Union[ParsedNode, ParsedSourceDefinition]],
|
[Union[ParsedNode, ParsedSourceDefinition]], Dict[str, Any]
|
||||||
Dict[str, Any]
|
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@@ -618,9 +598,7 @@ def _process_docs_for_source(
|
|||||||
column.description = column_desc
|
column.description = column_desc
|
||||||
|
|
||||||
|
|
||||||
def _process_docs_for_macro(
|
def _process_docs_for_macro(context: Dict[str, Any], macro: ParsedMacro) -> None:
|
||||||
context: Dict[str, Any], macro: ParsedMacro
|
|
||||||
) -> None:
|
|
||||||
macro.description = get_rendered(macro.description, context)
|
macro.description = get_rendered(macro.description, context)
|
||||||
for arg in macro.arguments:
|
for arg in macro.arguments:
|
||||||
arg.description = get_rendered(arg.description, context)
|
arg.description = get_rendered(arg.description, context)
|
||||||
@@ -682,7 +660,7 @@ def _process_refs_for_exposure(
|
|||||||
target_model_package, target_model_name = ref
|
target_model_package, target_model_name = ref
|
||||||
else:
|
else:
|
||||||
raise dbt.exceptions.InternalException(
|
raise dbt.exceptions.InternalException(
|
||||||
f'Refs should always be 1 or 2 arguments - got {len(ref)}'
|
f"Refs should always be 1 or 2 arguments - got {len(ref)}"
|
||||||
)
|
)
|
||||||
|
|
||||||
target_model = manifest.resolve_ref(
|
target_model = manifest.resolve_ref(
|
||||||
@@ -696,8 +674,10 @@ def _process_refs_for_exposure(
|
|||||||
# This may raise. Even if it doesn't, we don't want to add
|
# This may raise. Even if it doesn't, we don't want to add
|
||||||
# this exposure to the graph b/c there is no destination exposure
|
# this exposure to the graph b/c there is no destination exposure
|
||||||
invalid_ref_fail_unless_test(
|
invalid_ref_fail_unless_test(
|
||||||
exposure, target_model_name, target_model_package,
|
exposure,
|
||||||
disabled=(isinstance(target_model, Disabled))
|
target_model_name,
|
||||||
|
target_model_package,
|
||||||
|
disabled=(isinstance(target_model, Disabled)),
|
||||||
)
|
)
|
||||||
|
|
||||||
continue
|
continue
|
||||||
@@ -723,7 +703,7 @@ def _process_refs_for_node(
|
|||||||
target_model_package, target_model_name = ref
|
target_model_package, target_model_name = ref
|
||||||
else:
|
else:
|
||||||
raise dbt.exceptions.InternalException(
|
raise dbt.exceptions.InternalException(
|
||||||
f'Refs should always be 1 or 2 arguments - got {len(ref)}'
|
f"Refs should always be 1 or 2 arguments - got {len(ref)}"
|
||||||
)
|
)
|
||||||
|
|
||||||
target_model = manifest.resolve_ref(
|
target_model = manifest.resolve_ref(
|
||||||
@@ -738,8 +718,10 @@ def _process_refs_for_node(
|
|||||||
# this node to the graph b/c there is no destination node
|
# this node to the graph b/c there is no destination node
|
||||||
node.config.enabled = False
|
node.config.enabled = False
|
||||||
invalid_ref_fail_unless_test(
|
invalid_ref_fail_unless_test(
|
||||||
node, target_model_name, target_model_package,
|
node,
|
||||||
disabled=(isinstance(target_model, Disabled))
|
target_model_name,
|
||||||
|
target_model_package,
|
||||||
|
disabled=(isinstance(target_model, Disabled)),
|
||||||
)
|
)
|
||||||
|
|
||||||
continue
|
continue
|
||||||
@@ -777,7 +759,7 @@ def _process_sources_for_exposure(
|
|||||||
exposure,
|
exposure,
|
||||||
source_name,
|
source_name,
|
||||||
table_name,
|
table_name,
|
||||||
disabled=(isinstance(target_source, Disabled))
|
disabled=(isinstance(target_source, Disabled)),
|
||||||
)
|
)
|
||||||
continue
|
continue
|
||||||
target_source_id = target_source.unique_id
|
target_source_id = target_source.unique_id
|
||||||
@@ -804,7 +786,7 @@ def _process_sources_for_node(
|
|||||||
node,
|
node,
|
||||||
source_name,
|
source_name,
|
||||||
table_name,
|
table_name,
|
||||||
disabled=(isinstance(target_source, Disabled))
|
disabled=(isinstance(target_source, Disabled)),
|
||||||
)
|
)
|
||||||
continue
|
continue
|
||||||
target_source_id = target_source.unique_id
|
target_source_id = target_source.unique_id
|
||||||
@@ -835,13 +817,9 @@ def process_macro(
|
|||||||
_process_docs_for_macro(ctx, macro)
|
_process_docs_for_macro(ctx, macro)
|
||||||
|
|
||||||
|
|
||||||
def process_node(
|
def process_node(config: RuntimeConfig, manifest: Manifest, node: ManifestNode):
|
||||||
config: RuntimeConfig, manifest: Manifest, node: ManifestNode
|
|
||||||
):
|
|
||||||
|
|
||||||
_process_sources_for_node(
|
_process_sources_for_node(manifest, config.project_name, node)
|
||||||
manifest, config.project_name, node
|
|
||||||
)
|
|
||||||
_process_refs_for_node(manifest, config.project_name, node)
|
_process_refs_for_node(manifest, config.project_name, node)
|
||||||
ctx = generate_runtime_docs(config, node, manifest, config.project_name)
|
ctx = generate_runtime_docs(config, node, manifest, config.project_name)
|
||||||
_process_docs_for_node(ctx, node)
|
_process_docs_for_node(ctx, node)
|
||||||
|
|||||||
@@ -6,9 +6,7 @@ from dbt.parser.search import FilesystemSearcher, FileBlock
|
|||||||
|
|
||||||
class ModelParser(SimpleSQLParser[ParsedModelNode]):
|
class ModelParser(SimpleSQLParser[ParsedModelNode]):
|
||||||
def get_paths(self):
|
def get_paths(self):
|
||||||
return FilesystemSearcher(
|
return FilesystemSearcher(self.project, self.project.source_paths, ".sql")
|
||||||
self.project, self.project.source_paths, '.sql'
|
|
||||||
)
|
|
||||||
|
|
||||||
def parse_from_dict(self, dct, validate=True) -> ParsedModelNode:
|
def parse_from_dict(self, dct, validate=True) -> ParsedModelNode:
|
||||||
if validate:
|
if validate:
|
||||||
|
|||||||
@@ -25,9 +25,13 @@ from dbt.contracts.graph.parsed import (
|
|||||||
from dbt.contracts.graph.unparsed import SourcePatch
|
from dbt.contracts.graph.unparsed import SourcePatch
|
||||||
from dbt.contracts.util import Writable, Replaceable, MacroKey, SourceKey
|
from dbt.contracts.util import Writable, Replaceable, MacroKey, SourceKey
|
||||||
from dbt.exceptions import (
|
from dbt.exceptions import (
|
||||||
raise_duplicate_resource_name, raise_duplicate_patch_name,
|
raise_duplicate_resource_name,
|
||||||
raise_duplicate_macro_patch_name, CompilationException, InternalException,
|
raise_duplicate_patch_name,
|
||||||
raise_compiler_error, raise_duplicate_source_patch_name
|
raise_duplicate_macro_patch_name,
|
||||||
|
CompilationException,
|
||||||
|
InternalException,
|
||||||
|
raise_compiler_error,
|
||||||
|
raise_duplicate_source_patch_name,
|
||||||
)
|
)
|
||||||
from dbt.node_types import NodeType
|
from dbt.node_types import NodeType
|
||||||
from dbt.ui import line_wrap_message
|
from dbt.ui import line_wrap_message
|
||||||
@@ -35,12 +39,10 @@ from dbt.version import __version__
|
|||||||
|
|
||||||
|
|
||||||
# Parsers can return anything as long as it's a unique ID
|
# Parsers can return anything as long as it's a unique ID
|
||||||
ParsedValueType = TypeVar('ParsedValueType', bound=HasUniqueID)
|
ParsedValueType = TypeVar("ParsedValueType", bound=HasUniqueID)
|
||||||
|
|
||||||
|
|
||||||
def _check_duplicates(
|
def _check_duplicates(value: HasUniqueID, src: Mapping[str, HasUniqueID]):
|
||||||
value: HasUniqueID, src: Mapping[str, HasUniqueID]
|
|
||||||
):
|
|
||||||
if value.unique_id in src:
|
if value.unique_id in src:
|
||||||
raise_duplicate_resource_name(value, src[value.unique_id])
|
raise_duplicate_resource_name(value, src[value.unique_id])
|
||||||
|
|
||||||
@@ -86,9 +88,7 @@ class ParseResult(dbtClassMixin, Writable, Replaceable):
|
|||||||
self.files[key] = source_file
|
self.files[key] = source_file
|
||||||
return self.files[key]
|
return self.files[key]
|
||||||
|
|
||||||
def add_source(
|
def add_source(self, source_file: SourceFile, source: UnpatchedSourceDefinition):
|
||||||
self, source_file: SourceFile, source: UnpatchedSourceDefinition
|
|
||||||
):
|
|
||||||
# sources can't be overwritten!
|
# sources can't be overwritten!
|
||||||
_check_duplicates(source, self.sources)
|
_check_duplicates(source, self.sources)
|
||||||
self.sources[source.unique_id] = source
|
self.sources[source.unique_id] = source
|
||||||
@@ -126,7 +126,7 @@ class ParseResult(dbtClassMixin, Writable, Replaceable):
|
|||||||
# note that the line wrap eats newlines, so if you want newlines,
|
# note that the line wrap eats newlines, so if you want newlines,
|
||||||
# this is the result :(
|
# this is the result :(
|
||||||
msg = line_wrap_message(
|
msg = line_wrap_message(
|
||||||
f'''\
|
f"""\
|
||||||
dbt found two macros named "{macro.name}" in the project
|
dbt found two macros named "{macro.name}" in the project
|
||||||
"{macro.package_name}".
|
"{macro.package_name}".
|
||||||
|
|
||||||
@@ -137,8 +137,8 @@ class ParseResult(dbtClassMixin, Writable, Replaceable):
|
|||||||
- {macro.original_file_path}
|
- {macro.original_file_path}
|
||||||
|
|
||||||
- {other_path}
|
- {other_path}
|
||||||
''',
|
""",
|
||||||
subtract=2
|
subtract=2,
|
||||||
)
|
)
|
||||||
raise_compiler_error(msg)
|
raise_compiler_error(msg)
|
||||||
|
|
||||||
@@ -150,18 +150,14 @@ class ParseResult(dbtClassMixin, Writable, Replaceable):
|
|||||||
self.docs[doc.unique_id] = doc
|
self.docs[doc.unique_id] = doc
|
||||||
self.get_file(source_file).docs.append(doc.unique_id)
|
self.get_file(source_file).docs.append(doc.unique_id)
|
||||||
|
|
||||||
def add_patch(
|
def add_patch(self, source_file: SourceFile, patch: ParsedNodePatch) -> None:
|
||||||
self, source_file: SourceFile, patch: ParsedNodePatch
|
|
||||||
) -> None:
|
|
||||||
# patches can't be overwritten
|
# patches can't be overwritten
|
||||||
if patch.name in self.patches:
|
if patch.name in self.patches:
|
||||||
raise_duplicate_patch_name(patch, self.patches[patch.name])
|
raise_duplicate_patch_name(patch, self.patches[patch.name])
|
||||||
self.patches[patch.name] = patch
|
self.patches[patch.name] = patch
|
||||||
self.get_file(source_file).patches.append(patch.name)
|
self.get_file(source_file).patches.append(patch.name)
|
||||||
|
|
||||||
def add_macro_patch(
|
def add_macro_patch(self, source_file: SourceFile, patch: ParsedMacroPatch) -> None:
|
||||||
self, source_file: SourceFile, patch: ParsedMacroPatch
|
|
||||||
) -> None:
|
|
||||||
# macros are fully namespaced
|
# macros are fully namespaced
|
||||||
key = (patch.package_name, patch.name)
|
key = (patch.package_name, patch.name)
|
||||||
if key in self.macro_patches:
|
if key in self.macro_patches:
|
||||||
@@ -169,9 +165,7 @@ class ParseResult(dbtClassMixin, Writable, Replaceable):
|
|||||||
self.macro_patches[key] = patch
|
self.macro_patches[key] = patch
|
||||||
self.get_file(source_file).macro_patches.append(key)
|
self.get_file(source_file).macro_patches.append(key)
|
||||||
|
|
||||||
def add_source_patch(
|
def add_source_patch(self, source_file: SourceFile, patch: SourcePatch) -> None:
|
||||||
self, source_file: SourceFile, patch: SourcePatch
|
|
||||||
) -> None:
|
|
||||||
# source patches must be unique
|
# source patches must be unique
|
||||||
key = (patch.overrides, patch.name)
|
key = (patch.overrides, patch.name)
|
||||||
if key in self.source_patches:
|
if key in self.source_patches:
|
||||||
@@ -186,11 +180,13 @@ class ParseResult(dbtClassMixin, Writable, Replaceable):
|
|||||||
) -> List[CompileResultNode]:
|
) -> List[CompileResultNode]:
|
||||||
if unique_id not in self.disabled:
|
if unique_id not in self.disabled:
|
||||||
raise InternalException(
|
raise InternalException(
|
||||||
'called _get_disabled with id={}, but it does not exist'
|
"called _get_disabled with id={}, but it does not exist".format(
|
||||||
.format(unique_id)
|
unique_id
|
||||||
|
)
|
||||||
)
|
)
|
||||||
return [
|
return [
|
||||||
n for n in self.disabled[unique_id]
|
n
|
||||||
|
for n in self.disabled[unique_id]
|
||||||
if n.original_file_path == match_file.path.original_file_path
|
if n.original_file_path == match_file.path.original_file_path
|
||||||
]
|
]
|
||||||
|
|
||||||
@@ -199,7 +195,7 @@ class ParseResult(dbtClassMixin, Writable, Replaceable):
|
|||||||
node_id: str,
|
node_id: str,
|
||||||
source_file: SourceFile,
|
source_file: SourceFile,
|
||||||
old_file: SourceFile,
|
old_file: SourceFile,
|
||||||
old_result: 'ParseResult',
|
old_result: "ParseResult",
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Nodes are a special kind of complicated - there can be multiple
|
"""Nodes are a special kind of complicated - there can be multiple
|
||||||
with the same name, as long as all but one are disabled.
|
with the same name, as long as all but one are disabled.
|
||||||
@@ -224,14 +220,15 @@ class ParseResult(dbtClassMixin, Writable, Replaceable):
|
|||||||
if not found:
|
if not found:
|
||||||
raise CompilationException(
|
raise CompilationException(
|
||||||
'Expected to find "{}" in cached "manifest.nodes" or '
|
'Expected to find "{}" in cached "manifest.nodes" or '
|
||||||
'"manifest.disabled" based on cached file information: {}!'
|
'"manifest.disabled" based on cached file information: {}!'.format(
|
||||||
.format(node_id, old_file)
|
node_id, old_file
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
def sanitized_update(
|
def sanitized_update(
|
||||||
self,
|
self,
|
||||||
source_file: SourceFile,
|
source_file: SourceFile,
|
||||||
old_result: 'ParseResult',
|
old_result: "ParseResult",
|
||||||
resource_type: NodeType,
|
resource_type: NodeType,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""Perform a santized update. If the file can't be updated, invalidate
|
"""Perform a santized update. If the file can't be updated, invalidate
|
||||||
@@ -246,15 +243,11 @@ class ParseResult(dbtClassMixin, Writable, Replaceable):
|
|||||||
self.add_doc(source_file, doc)
|
self.add_doc(source_file, doc)
|
||||||
|
|
||||||
for macro_id in old_file.macros:
|
for macro_id in old_file.macros:
|
||||||
macro = _expect_value(
|
macro = _expect_value(macro_id, old_result.macros, old_file, "macros")
|
||||||
macro_id, old_result.macros, old_file, "macros"
|
|
||||||
)
|
|
||||||
self.add_macro(source_file, macro)
|
self.add_macro(source_file, macro)
|
||||||
|
|
||||||
for source_id in old_file.sources:
|
for source_id in old_file.sources:
|
||||||
source = _expect_value(
|
source = _expect_value(source_id, old_result.sources, old_file, "sources")
|
||||||
source_id, old_result.sources, old_file, "sources"
|
|
||||||
)
|
|
||||||
self.add_source(source_file, source)
|
self.add_source(source_file, source)
|
||||||
|
|
||||||
# because we know this is how we _parsed_ the node, we can safely
|
# because we know this is how we _parsed_ the node, we can safely
|
||||||
@@ -265,7 +258,7 @@ class ParseResult(dbtClassMixin, Writable, Replaceable):
|
|||||||
for node_id in old_file.nodes:
|
for node_id in old_file.nodes:
|
||||||
# cheat: look at the first part of the node ID and compare it to
|
# cheat: look at the first part of the node ID and compare it to
|
||||||
# the parser resource type. On a mismatch, bail out.
|
# the parser resource type. On a mismatch, bail out.
|
||||||
if resource_type != node_id.split('.')[0]:
|
if resource_type != node_id.split(".")[0]:
|
||||||
continue
|
continue
|
||||||
self._process_node(node_id, source_file, old_file, old_result)
|
self._process_node(node_id, source_file, old_file, old_result)
|
||||||
|
|
||||||
@@ -277,9 +270,7 @@ class ParseResult(dbtClassMixin, Writable, Replaceable):
|
|||||||
|
|
||||||
patched = False
|
patched = False
|
||||||
for name in old_file.patches:
|
for name in old_file.patches:
|
||||||
patch = _expect_value(
|
patch = _expect_value(name, old_result.patches, old_file, "patches")
|
||||||
name, old_result.patches, old_file, "patches"
|
|
||||||
)
|
|
||||||
self.add_patch(source_file, patch)
|
self.add_patch(source_file, patch)
|
||||||
patched = True
|
patched = True
|
||||||
if patched:
|
if patched:
|
||||||
@@ -312,8 +303,8 @@ class ParseResult(dbtClassMixin, Writable, Replaceable):
|
|||||||
return cls(FileHash.empty(), FileHash.empty(), {})
|
return cls(FileHash.empty(), FileHash.empty(), {})
|
||||||
|
|
||||||
|
|
||||||
K_T = TypeVar('K_T')
|
K_T = TypeVar("K_T")
|
||||||
V_T = TypeVar('V_T')
|
V_T = TypeVar("V_T")
|
||||||
|
|
||||||
|
|
||||||
def _expect_value(
|
def _expect_value(
|
||||||
@@ -322,7 +313,6 @@ def _expect_value(
|
|||||||
if key not in src:
|
if key not in src:
|
||||||
raise CompilationException(
|
raise CompilationException(
|
||||||
'Expected to find "{}" in cached "result.{}" based '
|
'Expected to find "{}" in cached "result.{}" based '
|
||||||
'on cached file information: {}!'
|
"on cached file information: {}!".format(key, name, old_file)
|
||||||
.format(key, name, old_file)
|
|
||||||
)
|
)
|
||||||
return src[key]
|
return src[key]
|
||||||
|
|||||||
@@ -38,11 +38,11 @@ class RPCCallParser(SimpleSQLParser[ParsedRPCNode]):
|
|||||||
# we do it this way to make mypy happy
|
# we do it this way to make mypy happy
|
||||||
if not isinstance(block, RPCBlock):
|
if not isinstance(block, RPCBlock):
|
||||||
raise InternalException(
|
raise InternalException(
|
||||||
'While parsing RPC calls, got an actual file block instead of '
|
"While parsing RPC calls, got an actual file block instead of "
|
||||||
'an RPC block: {}'.format(block)
|
"an RPC block: {}".format(block)
|
||||||
)
|
)
|
||||||
|
|
||||||
return os.path.join('rpc', block.name)
|
return os.path.join("rpc", block.name)
|
||||||
|
|
||||||
def parse_remote(self, sql: str, name: str) -> ParsedRPCNode:
|
def parse_remote(self, sql: str, name: str) -> ParsedRPCNode:
|
||||||
source_file = SourceFile.remote(contents=sql)
|
source_file = SourceFile.remote(contents=sql)
|
||||||
@@ -53,8 +53,8 @@ class RPCCallParser(SimpleSQLParser[ParsedRPCNode]):
|
|||||||
class RPCMacroParser(MacroParser):
|
class RPCMacroParser(MacroParser):
|
||||||
def parse_remote(self, contents) -> Iterable[ParsedMacro]:
|
def parse_remote(self, contents) -> Iterable[ParsedMacro]:
|
||||||
base = UnparsedMacro(
|
base = UnparsedMacro(
|
||||||
path='from remote system',
|
path="from remote system",
|
||||||
original_file_path='from remote system',
|
original_file_path="from remote system",
|
||||||
package_name=self.project.project_name,
|
package_name=self.project.project_name,
|
||||||
raw_sql=contents,
|
raw_sql=contents,
|
||||||
root_path=self.project.project_root,
|
root_path=self.project.project_root,
|
||||||
|
|||||||
@@ -3,7 +3,13 @@ import re
|
|||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import (
|
from typing import (
|
||||||
Generic, TypeVar, Dict, Any, Tuple, Optional, List,
|
Generic,
|
||||||
|
TypeVar,
|
||||||
|
Dict,
|
||||||
|
Any,
|
||||||
|
Tuple,
|
||||||
|
Optional,
|
||||||
|
List,
|
||||||
)
|
)
|
||||||
|
|
||||||
from dbt.clients.jinja import get_rendered, SCHEMA_TEST_KWARGS_NAME
|
from dbt.clients.jinja import get_rendered, SCHEMA_TEST_KWARGS_NAME
|
||||||
@@ -25,7 +31,7 @@ def get_nice_schema_test_name(
|
|||||||
flat_args = []
|
flat_args = []
|
||||||
for arg_name in sorted(args):
|
for arg_name in sorted(args):
|
||||||
# the model is already embedded in the name, so skip it
|
# the model is already embedded in the name, so skip it
|
||||||
if arg_name == 'model':
|
if arg_name == "model":
|
||||||
continue
|
continue
|
||||||
arg_val = args[arg_name]
|
arg_val = args[arg_name]
|
||||||
|
|
||||||
@@ -38,17 +44,17 @@ def get_nice_schema_test_name(
|
|||||||
|
|
||||||
flat_args.extend([str(part) for part in parts])
|
flat_args.extend([str(part) for part in parts])
|
||||||
|
|
||||||
clean_flat_args = [re.sub('[^0-9a-zA-Z_]+', '_', arg) for arg in flat_args]
|
clean_flat_args = [re.sub("[^0-9a-zA-Z_]+", "_", arg) for arg in flat_args]
|
||||||
unique = "__".join(clean_flat_args)
|
unique = "__".join(clean_flat_args)
|
||||||
|
|
||||||
cutoff = 32
|
cutoff = 32
|
||||||
if len(unique) <= cutoff:
|
if len(unique) <= cutoff:
|
||||||
label = unique
|
label = unique
|
||||||
else:
|
else:
|
||||||
label = hashlib.md5(unique.encode('utf-8')).hexdigest()
|
label = hashlib.md5(unique.encode("utf-8")).hexdigest()
|
||||||
|
|
||||||
filename = '{}_{}_{}'.format(test_type, test_name, label)
|
filename = "{}_{}_{}".format(test_type, test_name, label)
|
||||||
name = '{}_{}_{}'.format(test_type, test_name, unique)
|
name = "{}_{}_{}".format(test_type, test_name, unique)
|
||||||
|
|
||||||
return filename, name
|
return filename, name
|
||||||
|
|
||||||
@@ -65,19 +71,17 @@ class YamlBlock(FileBlock):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
Testable = TypeVar(
|
Testable = TypeVar("Testable", UnparsedNodeUpdate, UnpatchedSourceDefinition)
|
||||||
'Testable', UnparsedNodeUpdate, UnpatchedSourceDefinition
|
|
||||||
)
|
|
||||||
|
|
||||||
ColumnTarget = TypeVar(
|
ColumnTarget = TypeVar(
|
||||||
'ColumnTarget',
|
"ColumnTarget",
|
||||||
UnparsedNodeUpdate,
|
UnparsedNodeUpdate,
|
||||||
UnparsedAnalysisUpdate,
|
UnparsedAnalysisUpdate,
|
||||||
UnpatchedSourceDefinition,
|
UnpatchedSourceDefinition,
|
||||||
)
|
)
|
||||||
|
|
||||||
Target = TypeVar(
|
Target = TypeVar(
|
||||||
'Target',
|
"Target",
|
||||||
UnparsedNodeUpdate,
|
UnparsedNodeUpdate,
|
||||||
UnparsedMacroUpdate,
|
UnparsedMacroUpdate,
|
||||||
UnparsedAnalysisUpdate,
|
UnparsedAnalysisUpdate,
|
||||||
@@ -103,9 +107,7 @@ class TargetBlock(YamlBlock, Generic[Target]):
|
|||||||
return []
|
return []
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_yaml_block(
|
def from_yaml_block(cls, src: YamlBlock, target: Target) -> "TargetBlock[Target]":
|
||||||
cls, src: YamlBlock, target: Target
|
|
||||||
) -> 'TargetBlock[Target]':
|
|
||||||
return cls(
|
return cls(
|
||||||
file=src.file,
|
file=src.file,
|
||||||
data=src.data,
|
data=src.data,
|
||||||
@@ -137,9 +139,7 @@ class TestBlock(TargetColumnsBlock[Testable], Generic[Testable]):
|
|||||||
return self.target.quote_columns
|
return self.target.quote_columns
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_yaml_block(
|
def from_yaml_block(cls, src: YamlBlock, target: Testable) -> "TestBlock[Testable]":
|
||||||
cls, src: YamlBlock, target: Testable
|
|
||||||
) -> 'TestBlock[Testable]':
|
|
||||||
return cls(
|
return cls(
|
||||||
file=src.file,
|
file=src.file,
|
||||||
data=src.data,
|
data=src.data,
|
||||||
@@ -160,7 +160,7 @@ class SchemaTestBlock(TestBlock[Testable], Generic[Testable]):
|
|||||||
test: Dict[str, Any],
|
test: Dict[str, Any],
|
||||||
column_name: Optional[str],
|
column_name: Optional[str],
|
||||||
tags: List[str],
|
tags: List[str],
|
||||||
) -> 'SchemaTestBlock':
|
) -> "SchemaTestBlock":
|
||||||
return cls(
|
return cls(
|
||||||
file=src.file,
|
file=src.file,
|
||||||
data=src.data,
|
data=src.data,
|
||||||
@@ -179,13 +179,14 @@ class TestBuilder(Generic[Testable]):
|
|||||||
- or it may not be namespaced (test)
|
- or it may not be namespaced (test)
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# The 'test_name' is used to find the 'macro' that implements the test
|
# The 'test_name' is used to find the 'macro' that implements the test
|
||||||
TEST_NAME_PATTERN = re.compile(
|
TEST_NAME_PATTERN = re.compile(
|
||||||
r'((?P<test_namespace>([a-zA-Z_][0-9a-zA-Z_]*))\.)?'
|
r"((?P<test_namespace>([a-zA-Z_][0-9a-zA-Z_]*))\.)?"
|
||||||
r'(?P<test_name>([a-zA-Z_][0-9a-zA-Z_]*))'
|
r"(?P<test_name>([a-zA-Z_][0-9a-zA-Z_]*))"
|
||||||
)
|
)
|
||||||
# map magic keys to default values
|
# map magic keys to default values
|
||||||
MODIFIER_ARGS = {'severity': 'ERROR', 'tags': []}
|
MODIFIER_ARGS = {"severity": "ERROR", "tags": []}
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -197,25 +198,24 @@ class TestBuilder(Generic[Testable]):
|
|||||||
) -> None:
|
) -> None:
|
||||||
test_name, test_args = self.extract_test_args(test, column_name)
|
test_name, test_args = self.extract_test_args(test, column_name)
|
||||||
self.args: Dict[str, Any] = test_args
|
self.args: Dict[str, Any] = test_args
|
||||||
if 'model' in self.args:
|
if "model" in self.args:
|
||||||
raise_compiler_error(
|
raise_compiler_error(
|
||||||
'Test arguments include "model", which is a reserved argument',
|
'Test arguments include "model", which is a reserved argument',
|
||||||
)
|
)
|
||||||
self.package_name: str = package_name
|
self.package_name: str = package_name
|
||||||
self.target: Testable = target
|
self.target: Testable = target
|
||||||
|
|
||||||
self.args['model'] = self.build_model_str()
|
self.args["model"] = self.build_model_str()
|
||||||
|
|
||||||
match = self.TEST_NAME_PATTERN.match(test_name)
|
match = self.TEST_NAME_PATTERN.match(test_name)
|
||||||
if match is None:
|
if match is None:
|
||||||
raise_compiler_error(
|
raise_compiler_error(
|
||||||
'Test name string did not match expected pattern: {}'
|
"Test name string did not match expected pattern: {}".format(test_name)
|
||||||
.format(test_name)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
groups = match.groupdict()
|
groups = match.groupdict()
|
||||||
self.name: str = groups['test_name']
|
self.name: str = groups["test_name"]
|
||||||
self.namespace: str = groups['test_namespace']
|
self.namespace: str = groups["test_namespace"]
|
||||||
self.modifiers: Dict[str, Any] = {}
|
self.modifiers: Dict[str, Any] = {}
|
||||||
for key, default in self.MODIFIER_ARGS.items():
|
for key, default in self.MODIFIER_ARGS.items():
|
||||||
value = self.args.pop(key, default)
|
value = self.args.pop(key, default)
|
||||||
@@ -237,57 +237,52 @@ class TestBuilder(Generic[Testable]):
def extract_test_args(test, name=None) -> Tuple[str, Dict[str, Any]]:
if not isinstance(test, dict):
raise_compiler_error(
-'test must be dict or str, got {} (value {})'.format(
-type(test), test
-)
+"test must be dict or str, got {} (value {})".format(type(test), test)
)

test = list(test.items())
if len(test) != 1:
raise_compiler_error(
-'test definition dictionary must have exactly one key, got'
+"test definition dictionary must have exactly one key, got"
-' {} instead ({} keys)'.format(test, len(test))
+" {} instead ({} keys)".format(test, len(test))
)
test_name, test_args = test[0]

if not isinstance(test_args, dict):
raise_compiler_error(
-'test arguments must be dict, got {} (value {})'.format(
+"test arguments must be dict, got {} (value {})".format(
type(test_args), test_args
)
)
if not isinstance(test_name, str):
raise_compiler_error(
-'test name must be a str, got {} (value {})'.format(
+"test name must be a str, got {} (value {})".format(
type(test_name), test_name
)
)
test_args = deepcopy(test_args)
if name is not None:
-test_args['column_name'] = name
+test_args["column_name"] = name
return test_name, test_args

def severity(self) -> str:
-return self.modifiers.get('severity', 'ERROR').upper()
+return self.modifiers.get("severity", "ERROR").upper()

def tags(self) -> List[str]:
-tags = self.modifiers.get('tags', [])
+tags = self.modifiers.get("tags", [])
if isinstance(tags, str):
tags = [tags]
if not isinstance(tags, list):
raise_compiler_error(
-f'got {tags} ({type(tags)}) for tags, expected a list of '
-f'strings'
+f"got {tags} ({type(tags)}) for tags, expected a list of " f"strings"
)
for tag in tags:
if not isinstance(tag, str):
-raise_compiler_error(
-f'got {tag} ({type(tag)}) for tag, expected a str'
-)
+raise_compiler_error(f"got {tag} ({type(tag)}) for tag, expected a str")
return tags[:]

def macro_name(self) -> str:
-macro_name = 'test_{}'.format(self.name)
+macro_name = "test_{}".format(self.name)
if self.namespace is not None:
macro_name = "{}.{}".format(self.namespace, macro_name)
return macro_name
@@ -296,11 +291,11 @@ class TestBuilder(Generic[Testable]):
if isinstance(self.target, UnparsedNodeUpdate):
name = self.name
elif isinstance(self.target, UnpatchedSourceDefinition):
-name = 'source_' + self.name
+name = "source_" + self.name
else:
raise self._bad_type()
if self.namespace is not None:
-name = '{}_{}'.format(self.namespace, name)
+name = "{}_{}".format(self.namespace, name)
return get_nice_schema_test_name(name, self.target.name, self.args)

# this is the 'raw_sql' that's used in 'render_update' and execution
@@ -2,9 +2,7 @@ import itertools
import os

from abc import ABCMeta, abstractmethod
-from typing import (
-Iterable, Dict, Any, Union, List, Optional, Generic, TypeVar, Type
-)
+from typing import Iterable, Dict, Any, Union, List, Optional, Generic, TypeVar, Type

from dbt.dataclass_schema import ValidationError, dbtClassMixin

@@ -20,9 +18,7 @@ from dbt.context.context_config import (
)
from dbt.context.configured import generate_schema_yml
from dbt.context.target import generate_target_context
-from dbt.context.providers import (
-generate_parse_exposure, generate_test_context
-)
+from dbt.context.providers import generate_parse_exposure, generate_test_context
from dbt.context.macro_resolver import MacroResolver
from dbt.contracts.files import FileHash
from dbt.contracts.graph.manifest import SourceFile
@@ -50,20 +46,26 @@ from dbt.contracts.graph.unparsed import (
UnparsedSourceDefinition,
)
from dbt.exceptions import (
-validator_error_message, JSONValidationException,
-raise_invalid_schema_yml_version, ValidationException,
-CompilationException, warn_or_error, InternalException
+validator_error_message,
+JSONValidationException,
+raise_invalid_schema_yml_version,
+ValidationException,
+CompilationException,
+warn_or_error,
+InternalException,
)
from dbt.node_types import NodeType
from dbt.parser.base import SimpleParser
from dbt.parser.search import FileBlock, FilesystemSearcher
from dbt.parser.schema_test_builders import (
-TestBuilder, SchemaTestBlock, TargetBlock, YamlBlock,
-TestBlock, Testable
-)
-from dbt.utils import (
-get_pseudo_test_path, coerce_dict_str
+TestBuilder,
+SchemaTestBlock,
+TargetBlock,
+YamlBlock,
+TestBlock,
+Testable,
)
+from dbt.utils import get_pseudo_test_path, coerce_dict_str


UnparsedSchemaYaml = Union[
@@ -80,19 +82,17 @@ def error_context(
path: str,
key: str,
data: Any,
-cause: Union[str, ValidationException, JSONValidationException]
+cause: Union[str, ValidationException, JSONValidationException],
) -> str:
-"""Provide contextual information about an error while parsing
-"""
+"""Provide contextual information about an error while parsing"""
if isinstance(cause, str):
reason = cause
elif isinstance(cause, ValidationError):
reason = validator_error_message(cause)
else:
reason = cause.msg
-return (
-'Invalid {key} config given in {path} @ {key}: {data} - {reason}'
-.format(key=key, path=path, data=data, reason=reason)
+return "Invalid {key} config given in {path} @ {key}: {data} - {reason}".format(
+key=key, path=path, data=data, reason=reason
)


@@ -110,7 +110,7 @@ class ParserRef:
meta: Dict[str, Any],
):
tags: List[str] = []
-tags.extend(getattr(column, 'tags', ()))
+tags.extend(getattr(column, "tags", ()))
quote: Optional[bool]
if isinstance(column, UnparsedColumn):
quote = column.quote
@@ -123,13 +123,11 @@ class ParserRef:
meta=meta,
tags=tags,
quote=quote,
-_extra=column.extra
+_extra=column.extra,
)

@classmethod
-def from_target(
-cls, target: Union[HasColumnDocs, HasColumnTests]
-) -> 'ParserRef':
+def from_target(cls, target: Union[HasColumnDocs, HasColumnTests]) -> "ParserRef":
refs = cls()
for column in target.columns:
description = column.description
@@ -142,7 +140,7 @@ class ParserRef:
def _trimmed(inp: str) -> str:
if len(inp) < 50:
return inp
-return inp[:44] + '...' + inp[-3:]
+return inp[:44] + "..." + inp[-3:]


def merge_freshness(
@@ -158,21 +156,20 @@ def merge_freshness(

class SchemaParser(SimpleParser[SchemaTestBlock, ParsedSchemaTestNode]):
def __init__(
-self, results, project, root_project, macro_manifest,
+self,
+results,
+project,
+root_project,
+macro_manifest,
) -> None:
super().__init__(results, project, root_project, macro_manifest)
all_v_2 = (
-self.root_project.config_version == 2 and
-self.project.config_version == 2
+self.root_project.config_version == 2 and self.project.config_version == 2
)
if all_v_2:
-ctx = generate_schema_yml(
-self.root_project, self.project.project_name
-)
+ctx = generate_schema_yml(self.root_project, self.project.project_name)
else:
-ctx = generate_target_context(
-self.root_project, self.root_project.cli_vars
-)
+ctx = generate_target_context(self.root_project, self.root_project.cli_vars)

self.raw_renderer = SchemaYamlRenderer(ctx)

@@ -182,7 +179,7 @@ class SchemaParser(SimpleParser[SchemaTestBlock, ParsedSchemaTestNode]):
self.macro_resolver = MacroResolver(
self.macro_manifest.macros,
self.root_project.project_name,
-internal_package_names
+internal_package_names,
)

@classmethod
@@ -197,65 +194,55 @@ class SchemaParser(SimpleParser[SchemaTestBlock, ParsedSchemaTestNode]):
def get_paths(self):
# TODO: In order to support this, make FilesystemSearcher accept a list
# of file patterns. eg: ['.yml', '.yaml']
-yaml_files = list(FilesystemSearcher(
-self.project, self.project.all_source_paths, '.yaml'
-))
+yaml_files = list(
+FilesystemSearcher(self.project, self.project.all_source_paths, ".yaml")
+)
if yaml_files:
warn_or_error(
-'A future version of dbt will parse files with both'
+"A future version of dbt will parse files with both"
-' .yml and .yaml file extensions. dbt found'
+" .yml and .yaml file extensions. dbt found"
-f' {len(yaml_files)} files with .yaml extensions in'
+f" {len(yaml_files)} files with .yaml extensions in"
-' your dbt project. To avoid errors when upgrading'
+" your dbt project. To avoid errors when upgrading"
-' to a future release, either remove these files from'
+" to a future release, either remove these files from"
-' your dbt project, or change their extensions.'
+" your dbt project, or change their extensions."
-)
-return FilesystemSearcher(
-self.project, self.project.all_source_paths, '.yml'
)
+return FilesystemSearcher(self.project, self.project.all_source_paths, ".yml")

def parse_from_dict(self, dct, validate=True) -> ParsedSchemaTestNode:
if validate:
ParsedSchemaTestNode.validate(dct)
return ParsedSchemaTestNode.from_dict(dct)

-def _check_format_version(
-self, yaml: YamlBlock
-) -> None:
+def _check_format_version(self, yaml: YamlBlock) -> None:
path = yaml.path.relative_path
-if 'version' not in yaml.data:
+if "version" not in yaml.data:
-raise_invalid_schema_yml_version(path, 'no version is specified')
+raise_invalid_schema_yml_version(path, "no version is specified")

-version = yaml.data['version']
+version = yaml.data["version"]
# if it's not an integer, the version is malformed, or not
# set. Either way, only 'version: 2' is supported.
if not isinstance(version, int):
-raise_invalid_schema_yml_version(
-path, 'the version is not an integer'
-)
+raise_invalid_schema_yml_version(path, "the version is not an integer")
if version != 2:
raise_invalid_schema_yml_version(
-path, 'version {} is not supported'.format(version)
+path, "version {} is not supported".format(version)
)

-def _yaml_from_file(
-self, source_file: SourceFile
-) -> Optional[Dict[str, Any]]:
-"""If loading the yaml fails, raise an exception.
-"""
+def _yaml_from_file(self, source_file: SourceFile) -> Optional[Dict[str, Any]]:
+"""If loading the yaml fails, raise an exception."""
path: str = source_file.path.relative_path
try:
return load_yaml_text(source_file.contents)
except ValidationException as e:
reason = validator_error_message(e)
raise CompilationException(
-'Error reading {}: {} - {}'
-.format(self.project.project_name, path, reason)
+"Error reading {}: {} - {}".format(
+self.project.project_name, path, reason
+)
)
return None

-def parse_column_tests(
-self, block: TestBlock, column: UnparsedColumn
-) -> None:
+def parse_column_tests(self, block: TestBlock, column: UnparsedColumn) -> None:
if not column.tests:
return

@@ -267,9 +254,7 @@ class SchemaParser(SimpleParser[SchemaTestBlock, ParsedSchemaTestNode]):
if rendered:
generator = ContextConfigGenerator(self.root_project)
else:
-generator = UnrenderedConfigGenerator(
-self.root_project
-)
+generator = UnrenderedConfigGenerator(self.root_project)

return generator.calculate_node_config(
config_calls=[],
@@ -284,16 +269,14 @@ class SchemaParser(SimpleParser[SchemaTestBlock, ParsedSchemaTestNode]):
relation_cls = adapter.Relation
return str(relation_cls.create_from(self.root_project, node))

-def parse_source(
-self, target: UnpatchedSourceDefinition
-) -> ParsedSourceDefinition:
+def parse_source(self, target: UnpatchedSourceDefinition) -> ParsedSourceDefinition:
source = target.source
table = target.table
refs = ParserRef.from_target(table)
unique_id = target.unique_id
-description = table.description or ''
+description = table.description or ""
meta = table.meta or {}
-source_description = source.description or ''
+source_description = source.description or ""
loaded_at_field = table.loaded_at_field or source.loaded_at_field

freshness = merge_freshness(source.freshness, table.freshness)
@@ -316,8 +299,8 @@ class SchemaParser(SimpleParser[SchemaTestBlock, ParsedSchemaTestNode]):

if not isinstance(config, SourceConfig):
raise InternalException(
-f'Calculated a {type(config)} for a source, but expected '
+f"Calculated a {type(config)} for a source, but expected "
-f'a SourceConfig'
+f"a SourceConfig"
)

default_database = self.root_project.credentials.database
@@ -369,23 +352,23 @@ class SchemaParser(SimpleParser[SchemaTestBlock, ParsedSchemaTestNode]):
) -> ParsedSchemaTestNode:

dct = {
-'alias': name,
+"alias": name,
-'schema': self.default_schema,
+"schema": self.default_schema,
-'database': self.default_database,
+"database": self.default_database,
-'fqn': fqn,
+"fqn": fqn,
-'name': name,
+"name": name,
-'root_path': self.project.project_root,
+"root_path": self.project.project_root,
-'resource_type': self.resource_type,
+"resource_type": self.resource_type,
-'tags': tags,
+"tags": tags,
-'path': path,
+"path": path,
-'original_file_path': target.original_file_path,
+"original_file_path": target.original_file_path,
-'package_name': self.project.project_name,
+"package_name": self.project.project_name,
-'raw_sql': raw_sql,
+"raw_sql": raw_sql,
-'unique_id': self.generate_unique_id(name),
+"unique_id": self.generate_unique_id(name),
-'config': self.config_dict(config),
+"config": self.config_dict(config),
-'test_metadata': test_metadata,
+"test_metadata": test_metadata,
-'column_name': column_name,
+"column_name": column_name,
-'checksum': FileHash.empty().to_dict(omit_none=True),
+"checksum": FileHash.empty().to_dict(omit_none=True),
}
try:
ParsedSchemaTestNode.validate(dct)
@@ -424,18 +407,20 @@ class SchemaParser(SimpleParser[SchemaTestBlock, ParsedSchemaTestNode]):
)
except CompilationException as exc:
context = _trimmed(str(target))
-msg = (
-'Invalid test config given in {}:'
-'\n\t{}\n\t@: {}'
-.format(target.original_file_path, exc.msg, context)
+msg = "Invalid test config given in {}:" "\n\t{}\n\t@: {}".format(
+target.original_file_path, exc.msg, context
)
raise CompilationException(msg) from exc
original_name = os.path.basename(target.original_file_path)
compiled_path = get_pseudo_test_path(
-builder.compiled_name, original_name, 'schema_test',
+builder.compiled_name,
+original_name,
+"schema_test",
)
fqn_path = get_pseudo_test_path(
-builder.fqn_name, original_name, 'schema_test',
+builder.fqn_name,
+original_name,
+"schema_test",
)
# the fqn for tests actually happens in the test target's name, which
# is not necessarily this package's name
@@ -445,13 +430,13 @@ class SchemaParser(SimpleParser[SchemaTestBlock, ParsedSchemaTestNode]):
config = self.initial_config(fqn)

metadata = {
-'namespace': builder.namespace,
+"namespace": builder.namespace,
-'name': builder.name,
+"name": builder.name,
-'kwargs': builder.args,
+"kwargs": builder.args,
}
tags = sorted(set(itertools.chain(tags, builder.tags())))
-if 'schema' not in tags:
+if "schema" not in tags:
-tags.append('schema')
+tags.append("schema")

node = self.create_test_node(
target=target,
@@ -477,15 +462,15 @@ class SchemaParser(SimpleParser[SchemaTestBlock, ParsedSchemaTestNode]):
# parsing to avoid jinja overhead.
def render_test_update(self, node, config, builder):
macro_unique_id = self.macro_resolver.get_macro_id(
-node.package_name, 'test_' + builder.name)
+node.package_name, "test_" + builder.name
+)
# Add the depends_on here so we can limit the macros added
# to the context in rendering processing
node.depends_on.add_macro(macro_unique_id)
-if (macro_unique_id in
-['macro.dbt.test_not_null', 'macro.dbt.test_unique']):
+if macro_unique_id in ["macro.dbt.test_not_null", "macro.dbt.test_unique"]:
self.update_parsed_node(node, config)
-node.unrendered_config['severity'] = builder.severity()
+node.unrendered_config["severity"] = builder.severity()
-node.config['severity'] = builder.severity()
+node.config["severity"] = builder.severity()
# source node tests are processed at patch_source time
if isinstance(builder.target, UnpatchedSourceDefinition):
sources = [builder.target.fqn[-2], builder.target.fqn[-1]]
@@ -496,15 +481,16 @@ class SchemaParser(SimpleParser[SchemaTestBlock, ParsedSchemaTestNode]):
try:
# make a base context that doesn't have the magic kwargs field
context = generate_test_context(
-node, self.root_project, self.macro_manifest, config,
+node,
+self.root_project,
+self.macro_manifest,
+config,
self.macro_resolver,
)
# update with rendered test kwargs (which collects any refs)
add_rendered_test_kwargs(context, node, capture_macros=True)
# the parsed node is not rendered in the native context.
-get_rendered(
-node.raw_sql, context, node, capture_macros=True
-)
+get_rendered(node.raw_sql, context, node, capture_macros=True)
self.update_parsed_node(node, config)
except ValidationError as exc:
# we got a ValidationError - probably bad types in config()
@@ -522,9 +508,8 @@ class SchemaParser(SimpleParser[SchemaTestBlock, ParsedSchemaTestNode]):
column_name = None
else:
column_name = column.name
-should_quote = (
-column.quote or
-(column.quote is None and target.quote_columns)
+should_quote = column.quote or (
+column.quote is None and target.quote_columns
)
if should_quote:
column_name = get_adapter(self.root_project).quote(column_name)
@@ -535,10 +520,7 @@ class SchemaParser(SimpleParser[SchemaTestBlock, ParsedSchemaTestNode]):
tags = list(itertools.chain.from_iterable(tags_sources))

node = self._parse_generic_test(
-target=target,
-test=test,
-tags=tags,
-column_name=column_name
+target=target, test=test, tags=tags, column_name=column_name
)
# we can't go through result.add_node - no file... instead!
if node.config.enabled:
@@ -562,7 +544,9 @@ class SchemaParser(SimpleParser[SchemaTestBlock, ParsedSchemaTestNode]):
return node

def render_with_context(
-self, node: ParsedSchemaTestNode, config: ContextConfig,
+self,
+node: ParsedSchemaTestNode,
+config: ContextConfig,
) -> None:
"""Given the parsed node and a ContextConfig to use during
parsing, collect all the refs that might be squirreled away in the test
@@ -574,9 +558,7 @@ class SchemaParser(SimpleParser[SchemaTestBlock, ParsedSchemaTestNode]):
add_rendered_test_kwargs(context, node, capture_macros=True)

# the parsed node is not rendered in the native context.
-get_rendered(
-node.raw_sql, context, node, capture_macros=True
-)
+get_rendered(node.raw_sql, context, node, capture_macros=True)

def parse_test(
self,
@@ -592,9 +574,8 @@ class SchemaParser(SimpleParser[SchemaTestBlock, ParsedSchemaTestNode]):
column_tags: List[str] = []
else:
column_name = column.name
-should_quote = (
-column.quote or
-(column.quote is None and target_block.quote_columns)
+should_quote = column.quote or (
+column.quote is None and target_block.quote_columns
)
if should_quote:
column_name = get_adapter(self.root_project).quote(column_name)
@@ -632,8 +613,8 @@ class SchemaParser(SimpleParser[SchemaTestBlock, ParsedSchemaTestNode]):
dct = self.raw_renderer.render_data(dct)
except CompilationException as exc:
raise CompilationException(
-f'Failed to render {block.path.original_file_path} from '
+f"Failed to render {block.path.original_file_path} from "
-f'project {self.project.project_name}: {exc}'
+f"project {self.project.project_name}: {exc}"
) from exc

# contains the FileBlock and the data (dictionary)
@@ -649,66 +630,57 @@ class SchemaParser(SimpleParser[SchemaTestBlock, ParsedSchemaTestNode]):

# NonSourceParser.parse(), TestablePatchParser is a variety of
# NodePatchParser
-if 'models' in dct:
+if "models" in dct:
-parser = TestablePatchParser(self, yaml_block, 'models')
+parser = TestablePatchParser(self, yaml_block, "models")
for test_block in parser.parse():
self.parse_tests(test_block)

# NonSourceParser.parse()
-if 'seeds' in dct:
+if "seeds" in dct:
-parser = TestablePatchParser(self, yaml_block, 'seeds')
+parser = TestablePatchParser(self, yaml_block, "seeds")
for test_block in parser.parse():
self.parse_tests(test_block)

# NonSourceParser.parse()
-if 'snapshots' in dct:
+if "snapshots" in dct:
-parser = TestablePatchParser(self, yaml_block, 'snapshots')
+parser = TestablePatchParser(self, yaml_block, "snapshots")
for test_block in parser.parse():
self.parse_tests(test_block)

# This parser uses SourceParser.parse() which doesn't return
# any test blocks. Source tests are handled at a later point
# in the process.
-if 'sources' in dct:
+if "sources" in dct:
-parser = SourceParser(self, yaml_block, 'sources')
+parser = SourceParser(self, yaml_block, "sources")
parser.parse()

# NonSourceParser.parse()
-if 'macros' in dct:
+if "macros" in dct:
-parser = MacroPatchParser(self, yaml_block, 'macros')
+parser = MacroPatchParser(self, yaml_block, "macros")
for test_block in parser.parse():
self.parse_tests(test_block)

# NonSourceParser.parse()
-if 'analyses' in dct:
+if "analyses" in dct:
-parser = AnalysisPatchParser(self, yaml_block, 'analyses')
+parser = AnalysisPatchParser(self, yaml_block, "analyses")
for test_block in parser.parse():
self.parse_tests(test_block)

# parse exposures
-if 'exposures' in dct:
+if "exposures" in dct:
self.parse_exposures(yaml_block)


-Parsed = TypeVar(
-'Parsed',
-UnpatchedSourceDefinition, ParsedNodePatch, ParsedMacroPatch
-)
-NodeTarget = TypeVar(
-'NodeTarget',
-UnparsedNodeUpdate, UnparsedAnalysisUpdate
-)
+Parsed = TypeVar("Parsed", UnpatchedSourceDefinition, ParsedNodePatch, ParsedMacroPatch)
+NodeTarget = TypeVar("NodeTarget", UnparsedNodeUpdate, UnparsedAnalysisUpdate)
NonSourceTarget = TypeVar(
-'NonSourceTarget',
-UnparsedNodeUpdate, UnparsedAnalysisUpdate, UnparsedMacroUpdate
+"NonSourceTarget", UnparsedNodeUpdate, UnparsedAnalysisUpdate, UnparsedMacroUpdate
)


# abstract base class (ABCMeta)
class YamlReader(metaclass=ABCMeta):
-def __init__(
-self, schema_parser: SchemaParser, yaml: YamlBlock, key: str
-) -> None:
+def __init__(self, schema_parser: SchemaParser, yaml: YamlBlock, key: str) -> None:
self.schema_parser = schema_parser
# key: models, seeds, snapshots, sources, macros,
# analyses, exposures
@@ -738,8 +710,9 @@ class YamlReader(metaclass=ABCMeta):
data = self.yaml.data.get(self.key, [])
if not isinstance(data, list):
raise CompilationException(
-'{} must be a list, got {} instead: ({})'
-.format(self.key, type(data), _trimmed(str(data)))
+"{} must be a list, got {} instead: ({})".format(
+self.key, type(data), _trimmed(str(data))
+)
)
path = self.yaml.path.original_file_path

@@ -751,7 +724,7 @@ class YamlReader(metaclass=ABCMeta):
yield entry
else:
msg = error_context(
-path, self.key, data, 'expected a dict with string keys'
+path, self.key, data, "expected a dict with string keys"
)
raise CompilationException(msg)

@@ -759,10 +732,10 @@ class YamlReader(metaclass=ABCMeta):
class YamlDocsReader(YamlReader):
@abstractmethod
def parse(self) -> List[TestBlock]:
-raise NotImplementedError('parse is abstract')
+raise NotImplementedError("parse is abstract")


-T = TypeVar('T', bound=dbtClassMixin)
+T = TypeVar("T", bound=dbtClassMixin)


class SourceParser(YamlDocsReader):
@@ -779,13 +752,11 @@ class SourceParser(YamlDocsReader):
def parse(self) -> List[TestBlock]:
# get a verified list of dicts for the key handled by this parser
for data in self.get_key_dicts():
-data = self.project.credentials.translate_aliases(
-data, recurse=True
-)
+data = self.project.credentials.translate_aliases(data, recurse=True)

-is_override = 'overrides' in data
+is_override = "overrides" in data
if is_override:
-data['path'] = self.yaml.path.original_file_path
+data["path"] = self.yaml.path.original_file_path
patch = self._target_from_dict(SourcePatch, data)
self.results.add_source_patch(self.yaml.file, patch)
else:
@@ -797,10 +768,9 @@ class SourceParser(YamlDocsReader):
original_file_path = self.yaml.path.original_file_path
fqn_path = self.yaml.path.relative_path
for table in source.tables:
-unique_id = '.'.join([
-NodeType.Source, self.project.project_name,
-source.name, table.name
-])
+unique_id = ".".join(
+[NodeType.Source, self.project.project_name, source.name, table.name]
+)

# the FQN is project name / path elements /source_name /table_name
fqn = self.schema_parser.get_fqn_prefix(fqn_path)
@@ -825,17 +795,15 @@ class SourceParser(YamlDocsReader):
class NonSourceParser(YamlDocsReader, Generic[NonSourceTarget, Parsed]):
@abstractmethod
def _target_type(self) -> Type[NonSourceTarget]:
-raise NotImplementedError('_target_type not implemented')
+raise NotImplementedError("_target_type not implemented")

@abstractmethod
def get_block(self, node: NonSourceTarget) -> TargetBlock:
-raise NotImplementedError('get_block is abstract')
+raise NotImplementedError("get_block is abstract")

@abstractmethod
-def parse_patch(
-self, block: TargetBlock[NonSourceTarget], refs: ParserRef
-) -> None:
-raise NotImplementedError('parse_patch is abstract')
+def parse_patch(self, block: TargetBlock[NonSourceTarget], refs: ParserRef) -> None:
+raise NotImplementedError("parse_patch is abstract")

def parse(self) -> List[TestBlock]:
node: NonSourceTarget
@@ -874,11 +842,13 @@ class NonSourceParser(YamlDocsReader, Generic[NonSourceTarget, Parsed]):
for data in key_dicts:
# add extra data to each dict. This updates the dicts
# in the parser yaml
-data.update({
-'original_file_path': path,
-'yaml_key': self.key,
-'package_name': self.project.project_name,
-})
+data.update(
+{
+"original_file_path": path,
+"yaml_key": self.key,
+"package_name": self.project.project_name,
+}
+)
try:
# target_type: UnparsedNodeUpdate, UnparsedAnalysisUpdate,
# or UnparsedMacroUpdate
@@ -892,12 +862,9 @@ class NonSourceParser(YamlDocsReader, Generic[NonSourceTarget, Parsed]):


class NodePatchParser(
-NonSourceParser[NodeTarget, ParsedNodePatch],
-Generic[NodeTarget]
+NonSourceParser[NodeTarget, ParsedNodePatch], Generic[NodeTarget]
):
-def parse_patch(
-self, block: TargetBlock[NodeTarget], refs: ParserRef
-) -> None:
+def parse_patch(self, block: TargetBlock[NodeTarget], refs: ParserRef) -> None:
result = ParsedNodePatch(
name=block.target.name,
original_file_path=block.target.original_file_path,
@@ -958,7 +925,7 @@ class ExposureParser(YamlReader):

def parse_exposure(self, unparsed: UnparsedExposure) -> ParsedExposure:
package_name = self.project.project_name
-unique_id = f'{NodeType.Exposure}.{package_name}.{unparsed.name}'
+unique_id = f"{NodeType.Exposure}.{package_name}.{unparsed.name}"
path = self.yaml.path.relative_path

fqn = self.schema_parser.get_fqn_prefix(path)
@@ -984,12 +951,10 @@ class ExposureParser(YamlReader):
self.schema_parser.macro_manifest,
package_name,
)
-depends_on_jinja = '\n'.join(
+depends_on_jinja = "\n".join(
-'{{ ' + line + '}}' for line in unparsed.depends_on
+"{{ " + line + "}}" for line in unparsed.depends_on
-)
-get_rendered(
-depends_on_jinja, ctx, parsed, capture_macros=True
)
+get_rendered(depends_on_jinja, ctx, parsed, capture_macros=True)
# parsed now has a populated refs/sources
return parsed

@@ -1,8 +1,6 @@
import os
from dataclasses import dataclass
-from typing import (
-List, Callable, Iterable, Set, Union, Iterator, TypeVar, Generic
-)
+from typing import List, Callable, Iterable, Set, Union, Iterator, TypeVar, Generic

from dbt.clients.jinja import extract_toplevel_blocks, BlockTag
from dbt.clients.system import find_matching
@@ -72,13 +70,13 @@ class FilesystemSearcher(Iterable[FilePath]):
root = self.project.project_root

for result in find_matching(root, self.relative_dirs, ext):
-if 'searched_path' not in result or 'relative_path' not in result:
+if "searched_path" not in result or "relative_path" not in result:
raise InternalException(
-'Invalid result from find_matching: {}'.format(result)
+"Invalid result from find_matching: {}".format(result)
)
file_match = FilePath(
-searched_path=result['searched_path'],
+searched_path=result["searched_path"],
-relative_path=result['relative_path'],
+relative_path=result["relative_path"],
project_root=root,
)
yield file_match
@@ -86,7 +84,7 @@ class FilesystemSearcher(Iterable[FilePath]):

Block = Union[BlockContents, FullBlock]

-BlockSearchResult = TypeVar('BlockSearchResult', BlockContents, FullBlock)
+BlockSearchResult = TypeVar("BlockSearchResult", BlockContents, FullBlock)

BlockSearchResultFactory = Callable[[SourceFile, BlockTag], BlockSearchResult]

@@ -96,7 +94,7 @@ class BlockSearcher(Generic[BlockSearchResult], Iterable[BlockSearchResult]):
self,
source: List[FileBlock],
allowed_blocks: Set[str],
-source_tag_factory: BlockSearchResultFactory
+source_tag_factory: BlockSearchResultFactory,
) -> None:
self.source = source
self.allowed_blocks = allowed_blocks
@@ -107,7 +105,7 @@ class BlockSearcher(Generic[BlockSearchResult], Iterable[BlockSearchResult]):
blocks = extract_toplevel_blocks(
source_file.contents,
allowed_blocks=self.allowed_blocks,
-collect_raw_data=False
+collect_raw_data=False,
)
# this makes mypy happy, and this is an invariant we really need
for block in blocks:
@@ -8,9 +8,7 @@ from dbt.parser.search import FileBlock, FilesystemSearcher

class SeedParser(SimpleSQLParser[ParsedSeedNode]):
def get_paths(self):
-return FilesystemSearcher(
-self.project, self.project.data_paths, '.csv'
-)
+return FilesystemSearcher(self.project, self.project.data_paths, ".csv")

def parse_from_dict(self, dct, validate=True) -> ParsedSeedNode:
if validate:
@@ -30,9 +28,7 @@ class SeedParser(SimpleSQLParser[ParsedSeedNode]):
) -> None:
"""Seeds don't need to do any rendering."""

-def load_file(
-self, match: FilePath, *, set_contents: bool = False
-) -> SourceFile:
+def load_file(self, match: FilePath, *, set_contents: bool = False) -> SourceFile:
if match.seed_too_large():
# We don't want to calculate a hash of this file. Use the path.
return SourceFile.big_seed(match)
@@ -3,27 +3,22 @@ from typing import List

from dbt.dataclass_schema import ValidationError

-from dbt.contracts.graph.parsed import (
-IntermediateSnapshotNode, ParsedSnapshotNode
-)
-from dbt.exceptions import (
-CompilationException, validator_error_message
-)
+from dbt.contracts.graph.parsed import IntermediateSnapshotNode, ParsedSnapshotNode
+from dbt.exceptions import CompilationException, validator_error_message
from dbt.node_types import NodeType
from dbt.parser.base import SQLParser
from dbt.parser.search import (
-FilesystemSearcher, BlockContents, BlockSearcher, FileBlock
+FilesystemSearcher,
+BlockContents,
+BlockSearcher,
+FileBlock,
)
from dbt.utils import split_path


-class SnapshotParser(
-SQLParser[IntermediateSnapshotNode, ParsedSnapshotNode]
-):
+class SnapshotParser(SQLParser[IntermediateSnapshotNode, ParsedSnapshotNode]):
def get_paths(self):
-return FilesystemSearcher(
-self.project, self.project.snapshot_paths, '.sql'
-)
+return FilesystemSearcher(self.project, self.project.snapshot_paths, ".sql")

def parse_from_dict(self, dct, validate=True) -> IntermediateSnapshotNode:
if validate:
@@ -78,7 +73,7 @@ class SnapshotParser,
def parse_file(self, file_block: FileBlock) -> None:
blocks = BlockSearcher(
source=[file_block],
-allowed_blocks={'snapshot'},
+allowed_blocks={"snapshot"},
source_tag_factory=BlockContents,
)
for block in blocks:
@@ -34,8 +34,7 @@ class SourcePatcher:
self.results = results
self.root_project = root_project
self.macro_manifest = MacroManifest(
-macros=self.results.macros,
-files=self.results.files
+macros=self.results.macros, files=self.results.files
)
self.schema_parsers: Dict[str, SchemaParser] = {}
self.patches_used: Dict[SourceKey, Set[str]] = {}
@@ -65,9 +64,7 @@ class SourcePatcher:

source = UnparsedSourceDefinition.from_dict(source_dct)
table = UnparsedSourceTableDefinition.from_dict(table_dct)
-return unpatched.replace(
-source=source, table=table, patch_path=patch_path
-)
+return unpatched.replace(source=source, table=table, patch_path=patch_path)

def parse_source_docs(self, block: UnpatchedSourceDefinition) -> ParserRef:
refs = ParserRef()
@@ -78,7 +75,7 @@ class SourcePatcher:
refs.add(column, description, data_type, meta)
return refs

-def get_schema_parser_for(self, package_name: str) -> 'SchemaParser':
+def get_schema_parser_for(self, package_name: str) -> "SchemaParser":
if package_name in self.schema_parsers:
schema_parser = self.schema_parsers[package_name]
else:
@@ -157,31 +154,28 @@ class SourcePatcher:

if unused_tables:
msg = self.get_unused_msg(unused_tables)
-warn_or_error(msg, log_fmt=ui.warning_tag('{}'))
+warn_or_error(msg, log_fmt=ui.warning_tag("{}"))

def get_unused_msg(
self,
unused_tables: Dict[SourceKey, Optional[Set[str]]],
) -> str:
msg = [
-'During parsing, dbt encountered source overrides that had no '
-'target:',
+"During parsing, dbt encountered source overrides that had no " "target:",
]
for key, table_names in unused_tables.items():
patch = self.results.source_patches[key]
-patch_name = f'{patch.overrides}.{patch.name}'
+patch_name = f"{patch.overrides}.{patch.name}"
if table_names is None:
-msg.append(
-f' - Source {patch_name} (in {patch.path})'
-)
+msg.append(f" - Source {patch_name} (in {patch.path})")
else:
for table_name in sorted(table_names):
msg.append(
-f' - Source table {patch_name}.{table_name} '
+f" - Source table {patch_name}.{table_name} "
-f'(in {patch.path})'
+f"(in {patch.path})"
)
-msg.append('')
+msg.append("")
-return '\n'.join(msg)
+return "\n".join(msg)


def patch_sources(
@@ -15,5 +15,5 @@ def profiler(enable, outfile):
if enable:
profiler.disable()
stats = Stats(profiler)
-stats.sort_stats('tottime')
+stats.sort_stats("tottime")
stats.dump_stats(outfile)
@@ -46,14 +46,14 @@ from dbt.rpc.task_handler import RequestTaskHandler


class GC(RemoteBuiltinMethod[GCParameters, GCResult]):
-METHOD_NAME = 'gc'
+METHOD_NAME = "gc"

def set_args(self, params: GCParameters):
super().set_args(params)

def handle_request(self) -> GCResult:
if self.params is None:
-raise dbt.exceptions.InternalException('GC: params not set')
+raise dbt.exceptions.InternalException("GC: params not set")
return self.task_manager.gc_safe(
task_ids=self.params.task_ids,
before=self.params.before,
@@ -62,14 +62,14 @@ class GC(RemoteBuiltinMethod[GCParameters, GCResult]):


class Kill(RemoteBuiltinMethod[KillParameters, KillResult]):
-METHOD_NAME = 'kill'
+METHOD_NAME = "kill"

def set_args(self, params: KillParameters):
super().set_args(params)

def handle_request(self) -> KillResult:
if self.params is None:
-raise dbt.exceptions.InternalException('Kill: params not set')
+raise dbt.exceptions.InternalException("Kill: params not set")
result = KillResult()
task: RequestTaskHandler
try:
@@ -99,7 +99,7 @@ class Kill(RemoteBuiltinMethod[KillParameters, KillResult]):


class Status(RemoteBuiltinMethod[StatusParameters, LastParse]):
-METHOD_NAME = 'status'
+METHOD_NAME = "status"

def set_args(self, params: StatusParameters):
super().set_args(params)
@@ -109,14 +109,14 @@ class Status(RemoteBuiltinMethod[StatusParameters, LastParse]):


class PS(RemoteBuiltinMethod[PSParameters, PSResult]):
-METHOD_NAME = 'ps'
+METHOD_NAME = "ps"

def set_args(self, params: PSParameters):
super().set_args(params)

def keep(self, row: TaskRow):
if self.params is None:
-raise dbt.exceptions.InternalException('PS: params not set')
+raise dbt.exceptions.InternalException("PS: params not set")
if row.state.finished and self.params.completed:
return True
elif not row.state.finished and self.params.active:
@@ -125,9 +125,7 @@ class PS(RemoteBuiltinMethod[PSParameters, PSResult]):
return False

def handle_request(self) -> PSResult:
-rows = [
-row for row in self.task_manager.task_table() if self.keep(row)
-]
+rows = [row for row in self.task_manager.task_table() if self.keep(row)]
rows.sort(key=lambda r: (r.state, r.start, r.method))
result = PSResult(rows=rows, logs=[])
return result
@@ -138,10 +136,11 @@ def poll_complete(
) -> PollResult:
if timing.state not in (TaskHandlerState.Success, TaskHandlerState.Failed):
raise dbt.exceptions.InternalException(
-f'got invalid result state in poll_complete: {timing.state}'
+f"got invalid result state in poll_complete: {timing.state}"
)

-cls: Type[Union[
+cls: Type[
+Union[
PollExecuteCompleteResult,
PollRunCompleteResult,
PollCompileCompleteResult,
@@ -150,7 +149,8 @@ def poll_complete(
PollRunOperationCompleteResult,
PollGetManifestResult,
PollFreshnessResult,
-]]
+]
+]

if isinstance(result, RemoteExecutionResult):
cls = PollExecuteCompleteResult
@@ -171,7 +171,7 @@ def poll_complete(
|
|||||||
cls = PollFreshnessResult
|
cls = PollFreshnessResult
|
||||||
else:
|
else:
|
||||||
raise dbt.exceptions.InternalException(
|
raise dbt.exceptions.InternalException(
|
||||||
'got invalid result in poll_complete: {}'.format(result)
|
"got invalid result in poll_complete: {}".format(result)
|
||||||
)
|
)
|
||||||
return cls.from_result(result, tags, timing, logs)
|
return cls.from_result(result, tags, timing, logs)
|
||||||
|
|
||||||
@@ -181,14 +181,14 @@ def _dict_logs(logs: List[LogMessage]) -> List[Dict[str, Any]]:
|
|||||||
|
|
||||||
|
|
||||||
class Poll(RemoteBuiltinMethod[PollParameters, PollResult]):
|
class Poll(RemoteBuiltinMethod[PollParameters, PollResult]):
|
||||||
METHOD_NAME = 'poll'
|
METHOD_NAME = "poll"
|
||||||
|
|
||||||
def set_args(self, params: PollParameters):
|
def set_args(self, params: PollParameters):
|
||||||
super().set_args(params)
|
super().set_args(params)
|
||||||
|
|
||||||
def handle_request(self) -> PollResult:
|
def handle_request(self) -> PollResult:
|
||||||
if self.params is None:
|
if self.params is None:
|
||||||
raise dbt.exceptions.InternalException('Poll: params not set')
|
raise dbt.exceptions.InternalException("Poll: params not set")
|
||||||
task_id = self.params.request_token
|
task_id = self.params.request_token
|
||||||
task: RequestTaskHandler = self.task_manager.get_request(task_id)
|
task: RequestTaskHandler = self.task_manager.get_request(task_id)
|
||||||
|
|
||||||
@@ -216,7 +216,7 @@ class Poll(RemoteBuiltinMethod[PollParameters, PollResult]):
|
|||||||
err = task.error
|
err = task.error
|
||||||
if err is None:
|
if err is None:
|
||||||
exc = dbt.exceptions.InternalException(
|
exc = dbt.exceptions.InternalException(
|
||||||
f'At end of task {task_id}, error state but error is None'
|
f"At end of task {task_id}, error state but error is None"
|
||||||
)
|
)
|
||||||
raise RPCException.from_error(
|
raise RPCException.from_error(
|
||||||
dbt_error(exc, logs=_dict_logs(task_logs))
|
dbt_error(exc, logs=_dict_logs(task_logs))
|
||||||
@@ -228,17 +228,13 @@ class Poll(RemoteBuiltinMethod[PollParameters, PollResult]):
|
|||||||
|
|
||||||
if task.result is None:
|
if task.result is None:
|
||||||
exc = dbt.exceptions.InternalException(
|
exc = dbt.exceptions.InternalException(
|
||||||
f'At end of task {task_id}, state={state} but result is '
|
f"At end of task {task_id}, state={state} but result is " "None"
|
||||||
'None'
|
|
||||||
)
|
)
|
||||||
raise RPCException.from_error(
|
raise RPCException.from_error(
|
||||||
dbt_error(exc, logs=_dict_logs(task_logs))
|
dbt_error(exc, logs=_dict_logs(task_logs))
|
||||||
)
|
)
|
||||||
return poll_complete(
|
return poll_complete(
|
||||||
timing=timing,
|
timing=timing, result=task.result, tags=task.tags, logs=task_logs
|
||||||
result=task.result,
|
|
||||||
tags=task.tags,
|
|
||||||
logs=task_logs
|
|
||||||
)
|
)
|
||||||
elif state == TaskHandlerState.Killed:
|
elif state == TaskHandlerState.Killed:
|
||||||
return PollKilledResult(
|
return PollKilledResult(
|
||||||
@@ -251,8 +247,6 @@ class Poll(RemoteBuiltinMethod[PollParameters, PollResult]):
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
exc = dbt.exceptions.InternalException(
|
exc = dbt.exceptions.InternalException(
|
||||||
f'Got unknown value state={state} for task {task_id}'
|
f"Got unknown value state={state} for task {task_id}"
|
||||||
)
|
|
||||||
raise RPCException.from_error(
|
|
||||||
dbt_error(exc, logs=_dict_logs(task_logs))
|
|
||||||
)
|
)
|
||||||
|
raise RPCException.from_error(dbt_error(exc, logs=_dict_logs(task_logs)))
|
||||||
|
|||||||
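One reformatted line in the hunk above is easy to misread: `f"At end of task {task_id}, state={state} but result is " "None"` is two adjacent string literals, which Python concatenates at compile time, so the message still ends in `... but result is None`. A minimal standalone sketch of that behavior (the values below are made up for illustration, not taken from the diff):

```python
# Adjacent string literals are concatenated by the Python compiler, so an
# f-string followed by a plain literal still produces one message.
task_id = "abc123"   # illustrative value only
state = "success"    # illustrative value only

message = f"At end of task {task_id}, state={state} but result is " "None"
print(message)
# -> At end of task abc123, state=success but result is None
```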
```diff
@@ -12,45 +12,44 @@ class RPCException(JSONRPCDispatchException):
         message: Optional[str] = None,
         data: Optional[Dict[str, Any]] = None,
         logs: Optional[List[Dict[str, Any]]] = None,
-        tags: Optional[Dict[str, Any]] = None
+        tags: Optional[Dict[str, Any]] = None,
     ) -> None:
         if code is None:
             code = -32000
         if message is None:
-            message = 'Server error'
+            message = "Server error"
         if data is None:
             data = {}
 
         super().__init__(code=code, message=message, data=data)
         if logs is not None:
             self.logs = logs
-        self.error.data['tags'] = tags
+        self.error.data["tags"] = tags
 
     def __str__(self):
-        return (
-            'RPCException({0.code}, {0.message}, {0.data}, {1.logs})'
-            .format(self.error, self)
+        return "RPCException({0.code}, {0.message}, {0.data}, {1.logs})".format(
+            self.error, self
         )
 
     @property
     def logs(self) -> List[Dict[str, Any]]:
-        return self.error.data.get('logs')
+        return self.error.data.get("logs")
 
     @logs.setter
     def logs(self, value):
         if value is None:
             return
-        self.error.data['logs'] = value
+        self.error.data["logs"] = value
 
     @property
     def tags(self):
-        return self.error.data.get('tags')
+        return self.error.data.get("tags")
 
     @tags.setter
     def tags(self, value):
         if value is None:
             return
-        self.error.data['tags'] = value
+        self.error.data["tags"] = value
 
     @classmethod
     def from_error(cls, err):
@@ -58,16 +57,14 @@ class RPCException(JSONRPCDispatchException):
             code=err.code,
             message=err.message,
             data=err.data,
-            logs=err.data.get('logs'),
-            tags=err.data.get('tags'),
+            logs=err.data.get("logs"),
+            tags=err.data.get("tags"),
         )
 
 
 def invalid_params(data):
     return RPCException(
-        code=JSONRPCInvalidParams.CODE,
-        message=JSONRPCInvalidParams.MESSAGE,
-        data=data
+        code=JSONRPCInvalidParams.CODE, message=JSONRPCInvalidParams.MESSAGE, data=data
     )
 
 
@@ -82,6 +79,7 @@ def timeout_error(timeout_value, logs=None, tags=None):
 
 
 def dbt_error(exc, logs=None, tags=None):
-    exc = RPCException(code=exc.CODE, message=exc.MESSAGE, data=exc.data(),
-                       logs=logs, tags=tags)
+    exc = RPCException(
+        code=exc.CODE, message=exc.MESSAGE, data=exc.data(), logs=logs, tags=tags
+    )
     return exc
```
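The hunks above also show the pattern `RPCException` uses for carrying extra context: logs and tags are stashed inside the JSON-RPC error's `data` dict, exposed through properties, and rebuilt on the receiving side via `from_error`. Below is a minimal, self-contained sketch of that round-trip under simplified assumptions; `SketchRPCError` and `from_payload` are illustrative stand-ins, not dbt's actual classes or API.

```python
from typing import Any, Dict, List, Optional


class SketchRPCError(Exception):
    """Illustrative stand-in for the pattern above; not dbt's RPCException."""

    def __init__(
        self,
        code: int = -32000,
        message: str = "Server error",
        data: Optional[Dict[str, Any]] = None,
        logs: Optional[List[Dict[str, Any]]] = None,
        tags: Optional[Dict[str, Any]] = None,
    ) -> None:
        super().__init__(message)
        self.code = code
        self.message = message
        # Extra context rides along inside the error's data payload, so it
        # survives JSON serialization to the client and back.
        self.data: Dict[str, Any] = {} if data is None else dict(data)
        if logs is not None:
            self.data["logs"] = logs
        self.data["tags"] = tags

    @classmethod
    def from_payload(cls, payload: Dict[str, Any]) -> "SketchRPCError":
        # Everything needed to rebuild the exception lives in the payload,
        # mirroring the shape of RPCException.from_error in the hunk above.
        return cls(
            code=payload["code"],
            message=payload["message"],
            data=payload["data"],
            logs=payload["data"].get("logs"),
            tags=payload["data"].get("tags"),
        )


original = SketchRPCError(message="compilation failed", logs=[{"message": "oops"}])
wire = {"code": original.code, "message": original.message, "data": original.data}
rebuilt = SketchRPCError.from_payload(wire)
print(rebuilt.data.get("logs"))  # [{'message': 'oops'}]
```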
Some files were not shown because too many files have changed in this diff.