Compare commits


1 Commit

Author: Kyle Wigley · SHA1: 42058de028 · Message: --wip-- · Date: 2021-03-22 09:19:26 -04:00
423 changed files with 23484 additions and 24492 deletions


@@ -4,7 +4,7 @@ parse = (?P<major>\d+)
\.(?P<minor>\d+)
\.(?P<patch>\d+)
((?P<prerelease>[a-z]+)(?P<num>\d+))?
serialize =
serialize =
{major}.{minor}.{patch}{prerelease}{num}
{major}.{minor}.{patch}
commit = False
@@ -12,7 +12,7 @@ tag = False
[bumpversion:part:prerelease]
first_value = a
values =
values =
a
b
rc
@@ -41,4 +41,3 @@ first_value = 1
[bumpversion:file:plugins/snowflake/dbt/adapters/snowflake/__version__.py]
[bumpversion:file:plugins/bigquery/dbt/adapters/bigquery/__version__.py]

.pre-commit-config.yaml (new file, 20 lines)

@@ -0,0 +1,20 @@
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v2.3.0
hooks:
- id: check-yaml
- id: end-of-file-fixer
- id: trailing-whitespace
- repo: https://github.com/psf/black
rev: 20.8b1
hooks:
- id: black
- repo: https://gitlab.com/PyCQA/flake8
rev: 3.9.0
hooks:
- id: flake8
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v0.812
hooks:
- id: mypy
files: ^core/dbt/


@@ -1,4 +1,4 @@
The core function of dbt is SQL compilation and execution. Users create projects of dbt resources (models, tests, seeds, snapshots, ...), defined in SQL and YAML files, and they invoke dbt to create, update, or query associated views and tables. Today, dbt makes heavy use of Jinja2 to enable the templating of SQL, and to construct a DAG (Directed Acyclic Graph) from all of the resources in a project. Users can also extend their projects by installing resources (including Jinja macros) from other projects, called "packages."
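To make the templating idea concrete, here is a minimal sketch that renders a Jinja-templated model with plain Jinja2. It is not dbt's actual compilation pipeline: `fake_ref` and the model text are hypothetical stand-ins for the `ref()` resolution dbt performs against its project manifest.

```
# A minimal sketch (not dbt's real compiler): render a Jinja-templated model,
# resolving a hypothetical ref() to a concrete relation name.
from jinja2 import Environment

MODEL_SQL = """
select customer_id, count(*) as n_orders
from {{ ref('stg_orders') }}
group by 1
"""

def fake_ref(name: str) -> str:
    # dbt resolves ref() against the project manifest; here we just map
    # a model name to a schema-qualified table for illustration.
    return f'"analytics"."{name}"'

rendered = Environment().from_string(MODEL_SQL).render(ref=fake_ref)
print(rendered)
```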
## dbt-core
@@ -28,7 +28,7 @@ This is the docs website code. It comes from the dbt-docs repository, and is gen
dbt uses an adapter-plugin pattern to extend support to different databases, warehouses, query engines, etc. The four core adapters that are in the main repository, contained within the [`plugins`](plugins) subdirectory, are: Postgres, Redshift, Snowflake, and BigQuery. Other warehouses use adapter plugins defined in separate repositories (e.g. [dbt-spark](https://github.com/fishtown-analytics/dbt-spark), [dbt-presto](https://github.com/fishtown-analytics/dbt-presto)).
Each adapter is a mix of python, Jinja2, and SQL. The adapter code also makes heavy use of Jinja2 to wrap modular chunks of SQL functionality, define default implementations, and allow plugins to override them.
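As a self-contained sketch of that default-and-override pattern (these classes are illustrative only, not dbt's real adapter classes, though the `TYPE_LABELS`/`translate_type` names mirror the `Column` base class that appears later in this comparison):

```
# Illustrative only: a base class supplies default behavior, and a
# warehouse-specific "plugin" overrides just the pieces that differ.
from typing import ClassVar, Dict

class BaseColumn:
    # generic label -> database type mapping; adapters override as needed
    TYPE_LABELS: ClassVar[Dict[str, str]] = {"STRING": "TEXT", "INTEGER": "INT"}

    @classmethod
    def translate_type(cls, dtype: str) -> str:
        return cls.TYPE_LABELS.get(dtype.upper(), dtype)

class BigQueryishColumn(BaseColumn):
    # hypothetical override: this warehouse spells the types differently
    TYPE_LABELS = {"STRING": "STRING", "INTEGER": "INT64"}

print(BaseColumn.translate_type("string"))         # TEXT
print(BigQueryishColumn.translate_type("string"))  # STRING
```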
Each adapter plugin is a standalone python package that includes:


@@ -88,7 +88,7 @@ If you are not a member of the `fishtown-analytics` GitHub organization, you can
### Core contributors
If you are a member of the `fishtown-analytics` GitHub organization, you will have push access to the dbt repo. Rather than
forking dbt to make your changes, just clone the repository, check out a new branch, and push directly to that branch.
## Setting up an environment
@@ -139,7 +139,7 @@ brew install postgresql
First, make sure that you have set up your `virtualenv` as described in the section _Setting up an environment_. Next, install dbt (and its dependencies) with:
```
pip install -r editable_requirements.txt
pip install -r requirements-editable.txt
```
When dbt is installed from source in this way, any changes you make to the dbt source code will be reflected immediately in your next `dbt` run.
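As a quick sanity check that the editable install is active (a hedged sketch; the exact path will depend on where you cloned the repository), you can confirm that Python resolves `dbt` from your working tree rather than from a site-packages copy:

```
# Hypothetical check: the printed path should point into your local clone
# (e.g. .../your-clone/core/dbt), not into site-packages.
import dbt
print(dbt.__path__)
```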


@@ -1,10 +1,5 @@
.PHONY: install test test-unit test-integration
changed_tests := `git status --porcelain | grep '^\(M\| M\|A\| A\)' | awk '{ print $$2 }' | grep '\/test_[a-zA-Z_\-\.]\+.py'`
install:
pip install -e .
test: .env
@echo "Full test run starting..."
@time docker-compose run --rm test tox
@@ -18,7 +13,7 @@ test-integration: .env
@time docker-compose run --rm test tox -e integration-postgres-py36,integration-redshift-py36,integration-snowflake-py36,integration-bigquery-py36
test-quick: .env
@echo "Integration test run starting..."
@echo "Integration test run starting, will exit on first failure..."
@time docker-compose run --rm test tox -e integration-postgres-py36 -- -x
# This rule creates a file named .env that is used by docker-compose for passing


@@ -139,7 +139,7 @@ jobs:
inputs:
versionSpec: '3.7'
architecture: 'x64'
- script: python -m pip install --upgrade pip setuptools && python -m pip install -r requirements.txt && python -m pip install -r dev_requirements.txt
- script: python -m pip install --upgrade pip setuptools && python -m pip install -r requirements.txt && python -m pip install -r requirements-dev.txt
displayName: Install dependencies
- task: ShellScript@2
inputs:


@@ -63,11 +63,13 @@ def main():
packages = registry.packages()
project_json = init_project_in_packages(args, packages)
if args.project["version"] in project_json["versions"]:
raise Exception("Version {} already in packages JSON"
.format(args.project["version"]),
file=sys.stderr)
raise Exception(
"Version {} already in packages JSON".format(args.project["version"]),
file=sys.stderr,
)
add_version_to_package(args, project_json)
print(json.dumps(packages, indent=2))
if __name__ == "__main__":
main()


@@ -8,10 +8,10 @@ from dbt.exceptions import RuntimeException
@dataclass
class Column:
TYPE_LABELS: ClassVar[Dict[str, str]] = {
'STRING': 'TEXT',
'TIMESTAMP': 'TIMESTAMP',
'FLOAT': 'FLOAT',
'INTEGER': 'INT'
"STRING": "TEXT",
"TIMESTAMP": "TIMESTAMP",
"FLOAT": "FLOAT",
"INTEGER": "INT",
}
column: str
dtype: str
@@ -24,7 +24,7 @@ class Column:
return cls.TYPE_LABELS.get(dtype.upper(), dtype)
@classmethod
def create(cls, name, label_or_dtype: str) -> 'Column':
def create(cls, name, label_or_dtype: str) -> "Column":
column_type = cls.translate_type(label_or_dtype)
return cls(name, column_type)
@@ -41,14 +41,19 @@ class Column:
if self.is_string():
return Column.string_type(self.string_size())
elif self.is_numeric():
return Column.numeric_type(self.dtype, self.numeric_precision,
self.numeric_scale)
return Column.numeric_type(
self.dtype, self.numeric_precision, self.numeric_scale
)
else:
return self.dtype
def is_string(self) -> bool:
return self.dtype.lower() in ['text', 'character varying', 'character',
'varchar']
return self.dtype.lower() in [
"text",
"character varying",
"character",
"varchar",
]
def is_number(self):
return any([self.is_integer(), self.is_numeric(), self.is_float()])
@@ -56,33 +61,45 @@ class Column:
def is_float(self):
return self.dtype.lower() in [
# floats
'real', 'float4', 'float', 'double precision', 'float8'
"real",
"float4",
"float",
"double precision",
"float8",
]
def is_integer(self) -> bool:
return self.dtype.lower() in [
# real types
'smallint', 'integer', 'bigint',
'smallserial', 'serial', 'bigserial',
"smallint",
"integer",
"bigint",
"smallserial",
"serial",
"bigserial",
# aliases
'int2', 'int4', 'int8',
'serial2', 'serial4', 'serial8',
"int2",
"int4",
"int8",
"serial2",
"serial4",
"serial8",
]
def is_numeric(self) -> bool:
return self.dtype.lower() in ['numeric', 'decimal']
return self.dtype.lower() in ["numeric", "decimal"]
def string_size(self) -> int:
if not self.is_string():
raise RuntimeException("Called string_size() on non-string field!")
if self.dtype == 'text' or self.char_size is None:
if self.dtype == "text" or self.char_size is None:
# char_size should never be None. Handle it reasonably just in case
return 256
else:
return int(self.char_size)
def can_expand_to(self, other_column: 'Column') -> bool:
def can_expand_to(self, other_column: "Column") -> bool:
"""returns True if this column can be expanded to the size of the
other column"""
if not self.is_string() or not other_column.is_string():
@@ -110,12 +127,10 @@ class Column:
return "<Column {} ({})>".format(self.name, self.data_type)
@classmethod
def from_description(cls, name: str, raw_data_type: str) -> 'Column':
match = re.match(r'([^(]+)(\([^)]+\))?', raw_data_type)
def from_description(cls, name: str, raw_data_type: str) -> "Column":
match = re.match(r"([^(]+)(\([^)]+\))?", raw_data_type)
if match is None:
raise RuntimeException(
f'Could not interpret data type "{raw_data_type}"'
)
raise RuntimeException(f'Could not interpret data type "{raw_data_type}"')
data_type, size_info = match.groups()
char_size = None
numeric_precision = None
@@ -123,7 +138,7 @@ class Column:
if size_info is not None:
# strip out the parentheses
size_info = size_info[1:-1]
parts = size_info.split(',')
parts = size_info.split(",")
if len(parts) == 1:
try:
char_size = int(parts[0])
@@ -148,6 +163,4 @@ class Column:
f'could not convert "{parts[1]}" to an integer'
)
return cls(
name, data_type, char_size, numeric_precision, numeric_scale
)
return cls(name, data_type, char_size, numeric_precision, numeric_scale)


@@ -1,18 +1,21 @@
import abc
import os
# multiprocessing.RLock is a function returning this type
from multiprocessing.synchronize import RLock
from threading import get_ident
from typing import (
Dict, Tuple, Hashable, Optional, ContextManager, List, Union
)
from typing import Dict, Tuple, Hashable, Optional, ContextManager, List, Union
import agate
import dbt.exceptions
from dbt.contracts.connection import (
Connection, Identifier, ConnectionState,
AdapterRequiredConfig, LazyHandle, AdapterResponse
Connection,
Identifier,
ConnectionState,
AdapterRequiredConfig,
LazyHandle,
AdapterResponse,
)
from dbt.contracts.graph.manifest import Manifest
from dbt.adapters.base.query_headers import (
@@ -35,6 +38,7 @@ class BaseConnectionManager(metaclass=abc.ABCMeta):
You must also set the 'TYPE' class attribute with a class-unique constant
string.
"""
TYPE: str = NotImplemented
def __init__(self, profile: AdapterRequiredConfig):
@@ -65,7 +69,7 @@ class BaseConnectionManager(metaclass=abc.ABCMeta):
key = self.get_thread_identifier()
if key in self.thread_connections:
raise dbt.exceptions.InternalException(
'In set_thread_connection, existing connection exists for {}'
"In set_thread_connection, existing connection exists for {}"
)
self.thread_connections[key] = conn
@@ -105,18 +109,19 @@ class BaseConnectionManager(metaclass=abc.ABCMeta):
underlying database.
"""
raise dbt.exceptions.NotImplementedException(
'`exception_handler` is not implemented for this adapter!')
"`exception_handler` is not implemented for this adapter!"
)
def set_connection_name(self, name: Optional[str] = None) -> Connection:
conn_name: str
if name is None:
# if a name isn't specified, we'll re-use a single handle
# named 'master'
conn_name = 'master'
conn_name = "master"
else:
if not isinstance(name, str):
raise dbt.exceptions.CompilerException(
f'For connection name, got {name} - not a string!'
f"For connection name, got {name} - not a string!"
)
assert isinstance(name, str)
conn_name = name
@@ -129,20 +134,20 @@ class BaseConnectionManager(metaclass=abc.ABCMeta):
state=ConnectionState.INIT,
transaction_open=False,
handle=None,
credentials=self.profile.credentials
credentials=self.profile.credentials,
)
self.set_thread_connection(conn)
if conn.name == conn_name and conn.state == 'open':
if conn.name == conn_name and conn.state == "open":
return conn
logger.debug(
'Acquiring new {} connection "{}".'.format(self.TYPE, conn_name))
logger.debug('Acquiring new {} connection "{}".'.format(self.TYPE, conn_name))
if conn.state == 'open':
if conn.state == "open":
logger.debug(
'Re-using an available connection from the pool (formerly {}).'
.format(conn.name)
"Re-using an available connection from the pool (formerly {}).".format(
conn.name
)
)
else:
conn.handle = LazyHandle(self.open)
@@ -154,7 +159,7 @@ class BaseConnectionManager(metaclass=abc.ABCMeta):
def cancel_open(self) -> Optional[List[str]]:
"""Cancel all open connections on the adapter. (passable)"""
raise dbt.exceptions.NotImplementedException(
'`cancel_open` is not implemented for this adapter!'
"`cancel_open` is not implemented for this adapter!"
)
@abc.abstractclassmethod
@@ -168,7 +173,7 @@ class BaseConnectionManager(metaclass=abc.ABCMeta):
connection should not be in either in_use or available.
"""
raise dbt.exceptions.NotImplementedException(
'`open` is not implemented for this adapter!'
"`open` is not implemented for this adapter!"
)
def release(self) -> None:
@@ -189,12 +194,14 @@ class BaseConnectionManager(metaclass=abc.ABCMeta):
def cleanup_all(self) -> None:
with self.lock:
for connection in self.thread_connections.values():
if connection.state not in {'closed', 'init'}:
logger.debug("Connection '{}' was left open."
.format(connection.name))
if connection.state not in {"closed", "init"}:
logger.debug(
"Connection '{}' was left open.".format(connection.name)
)
else:
logger.debug("Connection '{}' was properly closed."
.format(connection.name))
logger.debug(
"Connection '{}' was properly closed.".format(connection.name)
)
self.close(connection)
# garbage collect these connections
@@ -204,14 +211,14 @@ class BaseConnectionManager(metaclass=abc.ABCMeta):
def begin(self) -> None:
"""Begin a transaction. (passable)"""
raise dbt.exceptions.NotImplementedException(
'`begin` is not implemented for this adapter!'
"`begin` is not implemented for this adapter!"
)
@abc.abstractmethod
def commit(self) -> None:
"""Commit a transaction. (passable)"""
raise dbt.exceptions.NotImplementedException(
'`commit` is not implemented for this adapter!'
"`commit` is not implemented for this adapter!"
)
@classmethod
@@ -220,20 +227,17 @@ class BaseConnectionManager(metaclass=abc.ABCMeta):
try:
connection.handle.rollback()
except Exception:
logger.debug(
'Failed to rollback {}'.format(connection.name),
exc_info=True
)
logger.debug("Failed to rollback {}".format(connection.name), exc_info=True)
@classmethod
def _close_handle(cls, connection: Connection) -> None:
"""Perform the actual close operation."""
# On windows, sometimes connection handles don't have a close() attr.
if hasattr(connection.handle, 'close'):
logger.debug(f'On {connection.name}: Close')
if hasattr(connection.handle, "close"):
logger.debug(f"On {connection.name}: Close")
connection.handle.close()
else:
logger.debug(f'On {connection.name}: No close available on handle')
logger.debug(f"On {connection.name}: No close available on handle")
@classmethod
def _rollback(cls, connection: Connection) -> None:
@@ -241,16 +245,16 @@ class BaseConnectionManager(metaclass=abc.ABCMeta):
if flags.STRICT_MODE:
if not isinstance(connection, Connection):
raise dbt.exceptions.CompilerException(
f'In _rollback, got {connection} - not a Connection!'
f"In _rollback, got {connection} - not a Connection!"
)
if connection.transaction_open is False:
raise dbt.exceptions.InternalException(
f'Tried to rollback transaction on connection '
f"Tried to rollback transaction on connection "
f'"{connection.name}", but it does not have one open!'
)
logger.debug(f'On {connection.name}: ROLLBACK')
logger.debug(f"On {connection.name}: ROLLBACK")
cls._rollback_handle(connection)
connection.transaction_open = False
@@ -260,7 +264,7 @@ class BaseConnectionManager(metaclass=abc.ABCMeta):
if flags.STRICT_MODE:
if not isinstance(connection, Connection):
raise dbt.exceptions.CompilerException(
f'In close, got {connection} - not a Connection!'
f"In close, got {connection} - not a Connection!"
)
# if the connection is in closed or init, there's nothing to do
@@ -268,7 +272,7 @@ class BaseConnectionManager(metaclass=abc.ABCMeta):
return connection
if connection.transaction_open and connection.handle:
logger.debug('On {}: ROLLBACK'.format(connection.name))
logger.debug("On {}: ROLLBACK".format(connection.name))
cls._rollback_handle(connection)
connection.transaction_open = False
@@ -302,5 +306,5 @@ class BaseConnectionManager(metaclass=abc.ABCMeta):
:rtype: Tuple[Union[str, AdapterResponse], agate.Table]
"""
raise dbt.exceptions.NotImplementedException(
'`execute` is not implemented for this adapter!'
"`execute` is not implemented for this adapter!"
)


@@ -4,17 +4,31 @@ from contextlib import contextmanager
from datetime import datetime
from itertools import chain
from typing import (
Optional, Tuple, Callable, Iterable, Type, Dict, Any, List, Mapping,
Iterator, Union, Set
Optional,
Tuple,
Callable,
Iterable,
Type,
Dict,
Any,
List,
Mapping,
Iterator,
Union,
Set,
)
import agate
import pytz
from dbt.exceptions import (
raise_database_error, raise_compiler_error, invalid_type_error,
raise_database_error,
raise_compiler_error,
invalid_type_error,
get_relation_returned_multiple_results,
InternalException, NotImplementedException, RuntimeException,
InternalException,
NotImplementedException,
RuntimeException,
)
from dbt import flags
@@ -25,9 +39,7 @@ from dbt.adapters.protocol import (
)
from dbt.clients.agate_helper import empty_table, merge_tables, table_from_rows
from dbt.clients.jinja import MacroGenerator
from dbt.contracts.graph.compiled import (
CompileResultNode, CompiledSeedNode
)
from dbt.contracts.graph.compiled import CompileResultNode, CompiledSeedNode
from dbt.contracts.graph.manifest import Manifest, MacroManifest
from dbt.contracts.graph.parsed import ParsedSeedNode
from dbt.exceptions import warn_or_error
@@ -38,7 +50,10 @@ from dbt.utils import filter_null_values, executor
from dbt.adapters.base.connections import Connection, AdapterResponse
from dbt.adapters.base.meta import AdapterMeta, available
from dbt.adapters.base.relation import (
ComponentName, BaseRelation, InformationSchema, SchemaSearchMap
ComponentName,
BaseRelation,
InformationSchema,
SchemaSearchMap,
)
from dbt.adapters.base import Column as BaseColumn
from dbt.adapters.cache import RelationsCache
@@ -47,15 +62,14 @@ from dbt.adapters.cache import RelationsCache
SeedModel = Union[ParsedSeedNode, CompiledSeedNode]
GET_CATALOG_MACRO_NAME = 'get_catalog'
FRESHNESS_MACRO_NAME = 'collect_freshness'
GET_CATALOG_MACRO_NAME = "get_catalog"
FRESHNESS_MACRO_NAME = "collect_freshness"
def _expect_row_value(key: str, row: agate.Row):
if key not in row.keys():
raise InternalException(
'Got a row without "{}" column, columns: {}'
.format(key, row.keys())
'Got a row without "{}" column, columns: {}'.format(key, row.keys())
)
return row[key]
@@ -64,40 +78,37 @@ def _catalog_filter_schemas(manifest: Manifest) -> Callable[[agate.Row], bool]:
"""Return a function that takes a row and decides if the row should be
included in the catalog output.
"""
schemas = frozenset((d.lower(), s.lower())
for d, s in manifest.get_used_schemas())
schemas = frozenset((d.lower(), s.lower()) for d, s in manifest.get_used_schemas())
def test(row: agate.Row) -> bool:
table_database = _expect_row_value('table_database', row)
table_schema = _expect_row_value('table_schema', row)
table_database = _expect_row_value("table_database", row)
table_schema = _expect_row_value("table_schema", row)
# the schema may be present but None, which is not an error and should
# be filtered out
if table_schema is None:
return False
return (table_database.lower(), table_schema.lower()) in schemas
return test
def _utc(
dt: Optional[datetime], source: BaseRelation, field_name: str
) -> datetime:
def _utc(dt: Optional[datetime], source: BaseRelation, field_name: str) -> datetime:
"""If dt has a timezone, return a new datetime that's in UTC. Otherwise,
assume the datetime is already for UTC and add the timezone.
"""
if dt is None:
raise raise_database_error(
"Expected a non-null value when querying field '{}' of table "
" {} but received value 'null' instead".format(
field_name,
source))
" {} but received value 'null' instead".format(field_name, source)
)
elif not hasattr(dt, 'tzinfo'):
elif not hasattr(dt, "tzinfo"):
raise raise_database_error(
"Expected a timestamp value when querying field '{}' of table "
"{} but received value of type '{}' instead".format(
field_name,
source,
type(dt).__name__))
field_name, source, type(dt).__name__
)
)
elif dt.tzinfo:
return dt.astimezone(pytz.UTC)
@@ -107,7 +118,7 @@ def _utc(
def _relation_name(rel: Optional[BaseRelation]) -> str:
if rel is None:
return 'null relation'
return "null relation"
else:
return str(rel)
@@ -148,6 +159,7 @@ class BaseAdapter(metaclass=AdapterMeta):
Macros:
- get_catalog
"""
Relation: Type[BaseRelation] = BaseRelation
Column: Type[BaseColumn] = BaseColumn
ConnectionManager: Type[ConnectionManagerProtocol]
@@ -181,12 +193,12 @@ class BaseAdapter(metaclass=AdapterMeta):
self.connections.commit_if_has_connection()
def debug_query(self) -> None:
self.execute('select 1 as id')
self.execute("select 1 as id")
def nice_connection_name(self) -> str:
conn = self.connections.get_if_exists()
if conn is None or conn.name is None:
return '<None>'
return "<None>"
return conn.name
@contextmanager
@@ -204,13 +216,11 @@ class BaseAdapter(metaclass=AdapterMeta):
self.connections.query_header.reset()
@contextmanager
def connection_for(
self, node: CompileResultNode
) -> Iterator[None]:
def connection_for(self, node: CompileResultNode) -> Iterator[None]:
with self.connection_named(node.unique_id, node):
yield
@available.parse(lambda *a, **k: ('', empty_table()))
@available.parse(lambda *a, **k: ("", empty_table()))
def execute(
self, sql: str, auto_begin: bool = False, fetch: bool = False
) -> Tuple[Union[str, AdapterResponse], agate.Table]:
@@ -224,16 +234,10 @@ class BaseAdapter(metaclass=AdapterMeta):
:return: A tuple of the status and the results (empty if fetch=False).
:rtype: Tuple[Union[str, AdapterResponse], agate.Table]
"""
return self.connections.execute(
sql=sql,
auto_begin=auto_begin,
fetch=fetch
)
return self.connections.execute(sql=sql, auto_begin=auto_begin, fetch=fetch)
@available.parse(lambda *a, **k: ('', empty_table()))
def get_partitions_metadata(
self, table: str
) -> Tuple[agate.Table]:
@available.parse(lambda *a, **k: ("", empty_table()))
def get_partitions_metadata(self, table: str) -> Tuple[agate.Table]:
"""Obtain partitions metadata for a BigQuery partitioned table.
:param str table_id: a partitioned table id, in standard SQL format.
@@ -241,9 +245,7 @@ class BaseAdapter(metaclass=AdapterMeta):
https://cloud.google.com/bigquery/docs/creating-partitioned-tables#getting_partition_metadata_using_meta_tables.
:rtype: agate.Table
"""
return self.connections.get_partitions_metadata(
table=table
)
return self.connections.get_partitions_metadata(table=table)
###
# Methods that should never be overridden
@@ -274,6 +276,7 @@ class BaseAdapter(metaclass=AdapterMeta):
if self._macro_manifest_lazy is None:
# avoid a circular import
from dbt.parser.manifest import load_macro_manifest
manifest = load_macro_manifest(
self.config, self.connections.set_query_header
)
@@ -294,8 +297,9 @@ class BaseAdapter(metaclass=AdapterMeta):
return False
elif (database, schema) not in self.cache:
logger.debug(
'On "{}": cache miss for schema "{}.{}", this is inefficient'
.format(self.nice_connection_name(), database, schema)
'On "{}": cache miss for schema "{}.{}", this is inefficient'.format(
self.nice_connection_name(), database, schema
)
)
return False
else:
@@ -310,8 +314,8 @@ class BaseAdapter(metaclass=AdapterMeta):
self.Relation.create_from(self.config, node).without_identifier()
for node in manifest.nodes.values()
if (
node.resource_type in NodeType.executable() and
not node.is_ephemeral_model
node.resource_type in NodeType.executable()
and not node.is_ephemeral_model
)
}
@@ -351,9 +355,9 @@ class BaseAdapter(metaclass=AdapterMeta):
for cache_schema in cache_schemas:
fut = tpe.submit_connected(
self,
f'list_{cache_schema.database}_{cache_schema.schema}',
f"list_{cache_schema.database}_{cache_schema.schema}",
self.list_relations_without_caching,
cache_schema
cache_schema,
)
futures.append(fut)
@@ -371,9 +375,7 @@ class BaseAdapter(metaclass=AdapterMeta):
cache_update.add((relation.database, relation.schema))
self.cache.update_schemas(cache_update)
def set_relations_cache(
self, manifest: Manifest, clear: bool = False
) -> None:
def set_relations_cache(self, manifest: Manifest, clear: bool = False) -> None:
"""Run a query that gets a populated cache of the relations in the
database and set the cache on this adapter.
"""
@@ -391,12 +393,12 @@ class BaseAdapter(metaclass=AdapterMeta):
if relation is None:
name = self.nice_connection_name()
raise_compiler_error(
'Attempted to cache a null relation for {}'.format(name)
"Attempted to cache a null relation for {}".format(name)
)
if flags.USE_CACHE:
self.cache.add(relation)
# so jinja doesn't render things
return ''
return ""
@available
def cache_dropped(self, relation: Optional[BaseRelation]) -> str:
@@ -406,11 +408,11 @@ class BaseAdapter(metaclass=AdapterMeta):
if relation is None:
name = self.nice_connection_name()
raise_compiler_error(
'Attempted to drop a null relation for {}'.format(name)
"Attempted to drop a null relation for {}".format(name)
)
if flags.USE_CACHE:
self.cache.drop(relation)
return ''
return ""
@available
def cache_renamed(
@@ -426,13 +428,12 @@ class BaseAdapter(metaclass=AdapterMeta):
src_name = _relation_name(from_relation)
dst_name = _relation_name(to_relation)
raise_compiler_error(
'Attempted to rename {} to {} for {}'
.format(src_name, dst_name, name)
"Attempted to rename {} to {} for {}".format(src_name, dst_name, name)
)
if flags.USE_CACHE:
self.cache.rename(from_relation, to_relation)
return ''
return ""
###
# Abstract methods for database-specific values, attributes, and types
@@ -441,12 +442,13 @@ class BaseAdapter(metaclass=AdapterMeta):
def date_function(cls) -> str:
"""Get the date function used by this adapter's database."""
raise NotImplementedException(
'`date_function` is not implemented for this adapter!')
"`date_function` is not implemented for this adapter!"
)
@abc.abstractclassmethod
def is_cancelable(cls) -> bool:
raise NotImplementedException(
'`is_cancelable` is not implemented for this adapter!'
"`is_cancelable` is not implemented for this adapter!"
)
###
@@ -456,7 +458,7 @@ class BaseAdapter(metaclass=AdapterMeta):
def list_schemas(self, database: str) -> List[str]:
"""Get a list of existing schemas in database"""
raise NotImplementedException(
'`list_schemas` is not implemented for this adapter!'
"`list_schemas` is not implemented for this adapter!"
)
@available.parse(lambda *a, **k: False)
@@ -467,10 +469,7 @@ class BaseAdapter(metaclass=AdapterMeta):
and adapters should implement it if there is an optimized path (and
there probably is)
"""
search = (
s.lower() for s in
self.list_schemas(database=database)
)
search = (s.lower() for s in self.list_schemas(database=database))
return schema.lower() in search
###
@@ -484,7 +483,7 @@ class BaseAdapter(metaclass=AdapterMeta):
*Implementors must call self.cache.drop() to preserve cache state!*
"""
raise NotImplementedException(
'`drop_relation` is not implemented for this adapter!'
"`drop_relation` is not implemented for this adapter!"
)
@abc.abstractmethod
@@ -492,7 +491,7 @@ class BaseAdapter(metaclass=AdapterMeta):
def truncate_relation(self, relation: BaseRelation) -> None:
"""Truncate the given relation."""
raise NotImplementedException(
'`truncate_relation` is not implemented for this adapter!'
"`truncate_relation` is not implemented for this adapter!"
)
@abc.abstractmethod
@@ -505,36 +504,30 @@ class BaseAdapter(metaclass=AdapterMeta):
Implementors must call self.cache.rename() to preserve cache state.
"""
raise NotImplementedException(
'`rename_relation` is not implemented for this adapter!'
"`rename_relation` is not implemented for this adapter!"
)
@abc.abstractmethod
@available.parse_list
def get_columns_in_relation(
self, relation: BaseRelation
) -> List[BaseColumn]:
def get_columns_in_relation(self, relation: BaseRelation) -> List[BaseColumn]:
"""Get a list of the columns in the given Relation."""
raise NotImplementedException(
'`get_columns_in_relation` is not implemented for this adapter!'
"`get_columns_in_relation` is not implemented for this adapter!"
)
@available.deprecated('get_columns_in_relation', lambda *a, **k: [])
def get_columns_in_table(
self, schema: str, identifier: str
) -> List[BaseColumn]:
@available.deprecated("get_columns_in_relation", lambda *a, **k: [])
def get_columns_in_table(self, schema: str, identifier: str) -> List[BaseColumn]:
"""DEPRECATED: Get a list of the columns in the given table."""
relation = self.Relation.create(
database=self.config.credentials.database,
schema=schema,
identifier=identifier,
quote_policy=self.config.quoting
quote_policy=self.config.quoting,
)
return self.get_columns_in_relation(relation)
@abc.abstractmethod
def expand_column_types(
self, goal: BaseRelation, current: BaseRelation
) -> None:
def expand_column_types(self, goal: BaseRelation, current: BaseRelation) -> None:
"""Expand the current table's types to match the goal table. (passable)
:param self.Relation goal: A relation that currently exists in the
@@ -543,7 +536,7 @@ class BaseAdapter(metaclass=AdapterMeta):
database with columns of unspecified types.
"""
raise NotImplementedException(
'`expand_target_column_types` is not implemented for this adapter!'
"`expand_target_column_types` is not implemented for this adapter!"
)
@abc.abstractmethod
@@ -560,8 +553,7 @@ class BaseAdapter(metaclass=AdapterMeta):
:rtype: List[self.Relation]
"""
raise NotImplementedException(
'`list_relations_without_caching` is not implemented for this '
'adapter!'
"`list_relations_without_caching` is not implemented for this " "adapter!"
)
###
@@ -576,32 +568,33 @@ class BaseAdapter(metaclass=AdapterMeta):
"""
if not isinstance(from_relation, self.Relation):
invalid_type_error(
method_name='get_missing_columns',
arg_name='from_relation',
method_name="get_missing_columns",
arg_name="from_relation",
got_value=from_relation,
expected_type=self.Relation)
expected_type=self.Relation,
)
if not isinstance(to_relation, self.Relation):
invalid_type_error(
method_name='get_missing_columns',
arg_name='to_relation',
method_name="get_missing_columns",
arg_name="to_relation",
got_value=to_relation,
expected_type=self.Relation)
expected_type=self.Relation,
)
from_columns = {
col.name: col for col in
self.get_columns_in_relation(from_relation)
col.name: col for col in self.get_columns_in_relation(from_relation)
}
to_columns = {
col.name: col for col in
self.get_columns_in_relation(to_relation)
col.name: col for col in self.get_columns_in_relation(to_relation)
}
missing_columns = set(from_columns.keys()) - set(to_columns.keys())
return [
col for (col_name, col) in from_columns.items()
col
for (col_name, col) in from_columns.items()
if col_name in missing_columns
]
@@ -616,18 +609,19 @@ class BaseAdapter(metaclass=AdapterMeta):
"""
if not isinstance(relation, self.Relation):
invalid_type_error(
method_name='valid_snapshot_target',
arg_name='relation',
method_name="valid_snapshot_target",
arg_name="relation",
got_value=relation,
expected_type=self.Relation)
expected_type=self.Relation,
)
columns = self.get_columns_in_relation(relation)
names = set(c.name.lower() for c in columns)
expanded_keys = ('scd_id', 'valid_from', 'valid_to')
expanded_keys = ("scd_id", "valid_from", "valid_to")
extra = []
missing = []
for legacy in expanded_keys:
desired = 'dbt_' + legacy
desired = "dbt_" + legacy
if desired not in names:
missing.append(desired)
if legacy in names:
@@ -637,13 +631,13 @@ class BaseAdapter(metaclass=AdapterMeta):
if extra:
msg = (
'Snapshot target has ("{}") but not ("{}") - is it an '
'unmigrated previous version archive?'
.format('", "'.join(extra), '", "'.join(missing))
"unmigrated previous version archive?".format(
'", "'.join(extra), '", "'.join(missing)
)
)
else:
msg = (
'Snapshot target is not a snapshot table (missing "{}")'
.format('", "'.join(missing))
msg = 'Snapshot target is not a snapshot table (missing "{}")'.format(
'", "'.join(missing)
)
raise_compiler_error(msg)
@@ -653,17 +647,19 @@ class BaseAdapter(metaclass=AdapterMeta):
) -> None:
if not isinstance(from_relation, self.Relation):
invalid_type_error(
method_name='expand_target_column_types',
arg_name='from_relation',
method_name="expand_target_column_types",
arg_name="from_relation",
got_value=from_relation,
expected_type=self.Relation)
expected_type=self.Relation,
)
if not isinstance(to_relation, self.Relation):
invalid_type_error(
method_name='expand_target_column_types',
arg_name='to_relation',
method_name="expand_target_column_types",
arg_name="to_relation",
got_value=to_relation,
expected_type=self.Relation)
expected_type=self.Relation,
)
self.expand_column_types(from_relation, to_relation)
@@ -676,38 +672,41 @@ class BaseAdapter(metaclass=AdapterMeta):
schema_relation = self.Relation.create(
database=database,
schema=schema,
identifier='',
quote_policy=self.config.quoting
identifier="",
quote_policy=self.config.quoting,
).without_identifier()
# we can't build the relations cache because we don't have a
# manifest so we can't run any operations.
relations = self.list_relations_without_caching(
schema_relation
)
relations = self.list_relations_without_caching(schema_relation)
logger.debug('with database={}, schema={}, relations={}'
.format(database, schema, relations))
logger.debug(
"with database={}, schema={}, relations={}".format(
database, schema, relations
)
)
return relations
def _make_match_kwargs(
self, database: str, schema: str, identifier: str
) -> Dict[str, str]:
quoting = self.config.quoting
if identifier is not None and quoting['identifier'] is False:
if identifier is not None and quoting["identifier"] is False:
identifier = identifier.lower()
if schema is not None and quoting['schema'] is False:
if schema is not None and quoting["schema"] is False:
schema = schema.lower()
if database is not None and quoting['database'] is False:
if database is not None and quoting["database"] is False:
database = database.lower()
return filter_null_values({
'database': database,
'identifier': identifier,
'schema': schema,
})
return filter_null_values(
{
"database": database,
"identifier": identifier,
"schema": schema,
}
)
def _make_match(
self,
@@ -733,25 +732,22 @@ class BaseAdapter(metaclass=AdapterMeta):
) -> Optional[BaseRelation]:
relations_list = self.list_relations(database, schema)
matches = self._make_match(relations_list, database, schema,
identifier)
matches = self._make_match(relations_list, database, schema, identifier)
if len(matches) > 1:
kwargs = {
'identifier': identifier,
'schema': schema,
'database': database,
"identifier": identifier,
"schema": schema,
"database": database,
}
get_relation_returned_multiple_results(
kwargs, matches
)
get_relation_returned_multiple_results(kwargs, matches)
elif matches:
return matches[0]
return None
@available.deprecated('get_relation', lambda *a, **k: False)
@available.deprecated("get_relation", lambda *a, **k: False)
def already_exists(self, schema: str, name: str) -> bool:
"""DEPRECATED: Return if a model already exists in the database"""
database = self.config.credentials.database
@@ -767,7 +763,7 @@ class BaseAdapter(metaclass=AdapterMeta):
def create_schema(self, relation: BaseRelation):
"""Create the given schema if it does not exist."""
raise NotImplementedException(
'`create_schema` is not implemented for this adapter!'
"`create_schema` is not implemented for this adapter!"
)
@abc.abstractmethod
@@ -775,16 +771,14 @@ class BaseAdapter(metaclass=AdapterMeta):
def drop_schema(self, relation: BaseRelation):
"""Drop the given schema (and everything in it) if it exists."""
raise NotImplementedException(
'`drop_schema` is not implemented for this adapter!'
"`drop_schema` is not implemented for this adapter!"
)
@available
@abc.abstractclassmethod
def quote(cls, identifier: str) -> str:
"""Quote the given identifier, as appropriate for the database."""
raise NotImplementedException(
'`quote` is not implemented for this adapter!'
)
raise NotImplementedException("`quote` is not implemented for this adapter!")
@available
def quote_as_configured(self, identifier: str, quote_key: str) -> str:
@@ -806,19 +800,17 @@ class BaseAdapter(metaclass=AdapterMeta):
return identifier
@available
def quote_seed_column(
self, column: str, quote_config: Optional[bool]
) -> str:
def quote_seed_column(self, column: str, quote_config: Optional[bool]) -> str:
# this is the default for now
quote_columns: bool = False
if isinstance(quote_config, bool):
quote_columns = quote_config
elif quote_config is None:
deprecations.warn('column-quoting-unset')
deprecations.warn("column-quoting-unset")
else:
raise_compiler_error(
f'The seed configuration value of "quote_columns" has an '
f'invalid type {type(quote_config)}'
f"invalid type {type(quote_config)}"
)
if quote_columns:
@@ -831,9 +823,7 @@ class BaseAdapter(metaclass=AdapterMeta):
# converting agate types into their sql equivalents.
###
@abc.abstractclassmethod
def convert_text_type(
cls, agate_table: agate.Table, col_idx: int
) -> str:
def convert_text_type(cls, agate_table: agate.Table, col_idx: int) -> str:
"""Return the type in the database that best maps to the agate.Text
type for the given agate table and column index.
@@ -842,12 +832,11 @@ class BaseAdapter(metaclass=AdapterMeta):
:return: The name of the type in the database
"""
raise NotImplementedException(
'`convert_text_type` is not implemented for this adapter!')
"`convert_text_type` is not implemented for this adapter!"
)
@abc.abstractclassmethod
def convert_number_type(
cls, agate_table: agate.Table, col_idx: int
) -> str:
def convert_number_type(cls, agate_table: agate.Table, col_idx: int) -> str:
"""Return the type in the database that best maps to the agate.Number
type for the given agate table and column index.
@@ -856,12 +845,11 @@ class BaseAdapter(metaclass=AdapterMeta):
:return: The name of the type in the database
"""
raise NotImplementedException(
'`convert_number_type` is not implemented for this adapter!')
"`convert_number_type` is not implemented for this adapter!"
)
@abc.abstractclassmethod
def convert_boolean_type(
cls, agate_table: agate.Table, col_idx: int
) -> str:
def convert_boolean_type(cls, agate_table: agate.Table, col_idx: int) -> str:
"""Return the type in the database that best maps to the agate.Boolean
type for the given agate table and column index.
@@ -870,12 +858,11 @@ class BaseAdapter(metaclass=AdapterMeta):
:return: The name of the type in the database
"""
raise NotImplementedException(
'`convert_boolean_type` is not implemented for this adapter!')
"`convert_boolean_type` is not implemented for this adapter!"
)
@abc.abstractclassmethod
def convert_datetime_type(
cls, agate_table: agate.Table, col_idx: int
) -> str:
def convert_datetime_type(cls, agate_table: agate.Table, col_idx: int) -> str:
"""Return the type in the database that best maps to the agate.DateTime
type for the given agate table and column index.
@@ -884,7 +871,8 @@ class BaseAdapter(metaclass=AdapterMeta):
:return: The name of the type in the database
"""
raise NotImplementedException(
'`convert_datetime_type` is not implemented for this adapter!')
"`convert_datetime_type` is not implemented for this adapter!"
)
@abc.abstractclassmethod
def convert_date_type(cls, agate_table: agate.Table, col_idx: int) -> str:
@@ -896,7 +884,8 @@ class BaseAdapter(metaclass=AdapterMeta):
:return: The name of the type in the database
"""
raise NotImplementedException(
'`convert_date_type` is not implemented for this adapter!')
"`convert_date_type` is not implemented for this adapter!"
)
@abc.abstractclassmethod
def convert_time_type(cls, agate_table: agate.Table, col_idx: int) -> str:
@@ -908,13 +897,12 @@ class BaseAdapter(metaclass=AdapterMeta):
:return: The name of the type in the database
"""
raise NotImplementedException(
'`convert_time_type` is not implemented for this adapter!')
"`convert_time_type` is not implemented for this adapter!"
)
@available
@classmethod
def convert_type(
cls, agate_table: agate.Table, col_idx: int
) -> Optional[str]:
def convert_type(cls, agate_table: agate.Table, col_idx: int) -> Optional[str]:
return cls.convert_agate_type(agate_table, col_idx)
@classmethod
@@ -963,7 +951,7 @@ class BaseAdapter(metaclass=AdapterMeta):
:param release: Ignored.
"""
if release is not False:
deprecations.warn('execute-macro-release')
deprecations.warn("execute-macro-release")
if kwargs is None:
kwargs = {}
if context_override is None:
@@ -977,28 +965,27 @@ class BaseAdapter(metaclass=AdapterMeta):
)
if macro is None:
if project is None:
package_name = 'any package'
package_name = "any package"
else:
package_name = 'the "{}" package'.format(project)
raise RuntimeException(
'dbt could not find a macro with the name "{}" in {}'
.format(macro_name, package_name)
'dbt could not find a macro with the name "{}" in {}'.format(
macro_name, package_name
)
)
# This causes a reference cycle, as generate_runtime_macro()
# ends up calling get_adapter, so the import has to be here.
from dbt.context.providers import generate_runtime_macro
macro_context = generate_runtime_macro(
macro=macro,
config=self.config,
manifest=manifest,
package_name=project
macro=macro, config=self.config, manifest=manifest, package_name=project
)
macro_context.update(context_override)
macro_function = MacroGenerator(macro, macro_context)
with self.connections.exception_handler(f'macro {macro_name}'):
with self.connections.exception_handler(f"macro {macro_name}"):
result = macro_function(**kwargs)
return result
@@ -1013,7 +1000,7 @@ class BaseAdapter(metaclass=AdapterMeta):
table = table_from_rows(
table.rows,
table.column_names,
text_only_columns=['table_database', 'table_schema', 'table_name']
text_only_columns=["table_database", "table_schema", "table_name"],
)
return table.where(_catalog_filter_schemas(manifest))
@@ -1024,10 +1011,7 @@ class BaseAdapter(metaclass=AdapterMeta):
manifest: Manifest,
) -> agate.Table:
kwargs = {
'information_schema': information_schema,
'schemas': schemas
}
kwargs = {"information_schema": information_schema, "schemas": schemas}
table = self.execute_macro(
GET_CATALOG_MACRO_NAME,
kwargs=kwargs,
@@ -1039,9 +1023,7 @@ class BaseAdapter(metaclass=AdapterMeta):
results = self._catalog_filter_table(table, manifest)
return results
def get_catalog(
self, manifest: Manifest
) -> Tuple[agate.Table, List[Exception]]:
def get_catalog(self, manifest: Manifest) -> Tuple[agate.Table, List[Exception]]:
schema_map = self._get_catalog_schemas(manifest)
with executor(self.config) as tpe:
@@ -1049,14 +1031,10 @@ class BaseAdapter(metaclass=AdapterMeta):
for info, schemas in schema_map.items():
if len(schemas) == 0:
continue
name = '.'.join([
str(info.database),
'information_schema'
])
name = ".".join([str(info.database), "information_schema"])
fut = tpe.submit_connected(
self, name,
self._get_one_catalog, info, schemas, manifest
self, name, self._get_one_catalog, info, schemas, manifest
)
futures.append(fut)
@@ -1073,20 +1051,18 @@ class BaseAdapter(metaclass=AdapterMeta):
source: BaseRelation,
loaded_at_field: str,
filter: Optional[str],
manifest: Optional[Manifest] = None
manifest: Optional[Manifest] = None,
) -> Dict[str, Any]:
"""Calculate the freshness of sources in dbt, and return it"""
kwargs: Dict[str, Any] = {
'source': source,
'loaded_at_field': loaded_at_field,
'filter': filter,
"source": source,
"loaded_at_field": loaded_at_field,
"filter": filter,
}
# run the macro
table = self.execute_macro(
FRESHNESS_MACRO_NAME,
kwargs=kwargs,
manifest=manifest
FRESHNESS_MACRO_NAME, kwargs=kwargs, manifest=manifest
)
# now we have a 1-row table of the maximum `loaded_at_field` value and
# the current time according to the db.
@@ -1106,9 +1082,9 @@ class BaseAdapter(metaclass=AdapterMeta):
snapshotted_at = _utc(table[0][1], source, loaded_at_field)
age = (snapshotted_at - max_loaded_at).total_seconds()
return {
'max_loaded_at': max_loaded_at,
'snapshotted_at': snapshotted_at,
'age': age,
"max_loaded_at": max_loaded_at,
"snapshotted_at": snapshotted_at,
"age": age,
}
def pre_model_hook(self, config: Mapping[str, Any]) -> Any:
@@ -1138,6 +1114,7 @@ class BaseAdapter(metaclass=AdapterMeta):
def get_compiler(self):
from dbt.compilation import Compiler
return Compiler(self.config)
# Methods used in adapter tests
@@ -1148,13 +1125,13 @@ class BaseAdapter(metaclass=AdapterMeta):
clause: str,
where_clause: Optional[str] = None,
) -> str:
clause = f'update {dst_name} set {dst_column} = {clause}'
clause = f"update {dst_name} set {dst_column} = {clause}"
if where_clause is not None:
clause += f' where {where_clause}'
clause += f" where {where_clause}"
return clause
def timestamp_add_sql(
self, add_to: str, number: int = 1, interval: str = 'hour'
self, add_to: str, number: int = 1, interval: str = "hour"
) -> str:
# for backwards compatibility, we're compelled to set some sort of
# default. A lot of searching has led me to believe that the
@@ -1163,23 +1140,24 @@ class BaseAdapter(metaclass=AdapterMeta):
return f"{add_to} + interval '{number} {interval}'"
def string_add_sql(
self, add_to: str, value: str, location='append',
self,
add_to: str,
value: str,
location="append",
) -> str:
if location == 'append':
if location == "append":
return f"{add_to} || '{value}'"
elif location == 'prepend':
elif location == "prepend":
return f"'{value}' || {add_to}"
else:
raise RuntimeException(
f'Got an unexpected location value of "{location}"'
)
raise RuntimeException(f'Got an unexpected location value of "{location}"')
def get_rows_different_sql(
self,
relation_a: BaseRelation,
relation_b: BaseRelation,
column_names: Optional[List[str]] = None,
except_operator: str = 'EXCEPT',
except_operator: str = "EXCEPT",
) -> str:
"""Generate SQL for a query that returns a single row with a two
columns: the number of rows that are different between the two
@@ -1192,7 +1170,7 @@ class BaseAdapter(metaclass=AdapterMeta):
names = sorted((self.quote(c.name) for c in columns))
else:
names = sorted((self.quote(n) for n in column_names))
columns_csv = ', '.join(names)
columns_csv = ", ".join(names)
sql = COLUMNS_EQUAL_SQL.format(
columns=columns_csv,
@@ -1204,7 +1182,7 @@ class BaseAdapter(metaclass=AdapterMeta):
return sql
COLUMNS_EQUAL_SQL = '''
COLUMNS_EQUAL_SQL = """
with diff_count as (
SELECT
1 as id,
@@ -1230,11 +1208,11 @@ select
diff_count.num_missing as num_mismatched
from row_count_diff
join diff_count using (id)
'''.strip()
""".strip()
def catch_as_completed(
futures # typing: List[Future[agate.Table]]
futures, # typing: List[Future[agate.Table]]
) -> Tuple[agate.Table, List[Exception]]:
# catalogs: agate.Table = agate.Table(rows=[])
@@ -1247,15 +1225,10 @@ def catch_as_completed(
if exc is None:
catalog = future.result()
tables.append(catalog)
elif (
isinstance(exc, KeyboardInterrupt) or
not isinstance(exc, Exception)
):
elif isinstance(exc, KeyboardInterrupt) or not isinstance(exc, Exception):
raise exc
else:
warn_or_error(
f'Encountered an error while generating catalog: {str(exc)}'
)
warn_or_error(f"Encountered an error while generating catalog: {str(exc)}")
# exc is not None, derives from Exception, and isn't ctrl+c
exceptions.append(exc)
return merge_tables(tables), exceptions


@@ -30,9 +30,11 @@ class _Available:
x.update(big_expensive_db_query())
return x
"""
def inner(func):
func._parse_replacement_ = parse_replacement
return self(func)
return inner
def deprecated(
@@ -57,13 +59,14 @@ class _Available:
The optional parse_replacement, if provided, will provide a parse-time
replacement for the actual method (see `available.parse`).
"""
def wrapper(func):
func_name = func.__name__
renamed_method(func_name, supported_name)
@wraps(func)
def inner(*args, **kwargs):
warn('adapter:{}'.format(func_name))
warn("adapter:{}".format(func_name))
return func(*args, **kwargs)
if parse_replacement:
@@ -71,6 +74,7 @@ class _Available:
else:
available_function = self
return available_function(inner)
return wrapper
def parse_none(self, func: Callable) -> Callable:
@@ -109,14 +113,14 @@ class AdapterMeta(abc.ABCMeta):
# collect base class data first
for base in bases:
available.update(getattr(base, '_available_', set()))
replacements.update(getattr(base, '_parse_replacements_', set()))
available.update(getattr(base, "_available_", set()))
replacements.update(getattr(base, "_parse_replacements_", set()))
# override with local data if it exists
for name, value in namespace.items():
if getattr(value, '_is_available_', False):
if getattr(value, "_is_available_", False):
available.add(name)
parse_replacement = getattr(value, '_parse_replacement_', None)
parse_replacement = getattr(value, "_parse_replacement_", None)
if parse_replacement is not None:
replacements[name] = parse_replacement


@@ -8,11 +8,10 @@ from dbt.adapters.protocol import AdapterProtocol
def project_name_from_path(include_path: str) -> str:
# avoid an import cycle
from dbt.config.project import Project
partial = Project.partial_load(include_path)
if partial.project_name is None:
raise CompilationException(
f'Invalid project at {include_path}: name not set!'
)
raise CompilationException(f"Invalid project at {include_path}: name not set!")
return partial.project_name
@@ -23,12 +22,13 @@ class AdapterPlugin:
:param dependencies: A list of adapter names that this adapter depends
upon.
"""
def __init__(
self,
adapter: Type[AdapterProtocol],
credentials: Type[Credentials],
include_path: str,
dependencies: Optional[List[str]] = None
dependencies: Optional[List[str]] = None,
):
self.adapter: Type[AdapterProtocol] = adapter


@@ -15,7 +15,7 @@ class NodeWrapper:
self._inner_node = node
def __getattr__(self, name):
return getattr(self._inner_node, name, '')
return getattr(self._inner_node, name, "")
class _QueryComment(local):
@@ -24,6 +24,7 @@ class _QueryComment(local):
- the current thread's query comment.
- a source_name indicating what set the current thread's query comment
"""
def __init__(self, initial):
self.query_comment: Optional[str] = initial
self.append = False
@@ -35,16 +36,16 @@ class _QueryComment(local):
if self.append:
# replace last ';' with '<comment>;'
sql = sql.rstrip()
if sql[-1] == ';':
if sql[-1] == ";":
sql = sql[:-1]
return '{}\n/* {} */;'.format(sql, self.query_comment.strip())
return "{}\n/* {} */;".format(sql, self.query_comment.strip())
return '{}\n/* {} */'.format(sql, self.query_comment.strip())
return "{}\n/* {} */".format(sql, self.query_comment.strip())
return '/* {} */\n{}'.format(self.query_comment.strip(), sql)
return "/* {} */\n{}".format(self.query_comment.strip(), sql)
def set(self, comment: Optional[str], append: bool):
if isinstance(comment, str) and '*/' in comment:
if isinstance(comment, str) and "*/" in comment:
# tell the user "no" so they don't hurt themselves by writing
# garbage
raise RuntimeException(
@@ -63,15 +64,17 @@ class MacroQueryStringSetter:
self.config = config
comment_macro = self._get_comment_macro()
self.generator: QueryStringFunc = lambda name, model: ''
self.generator: QueryStringFunc = lambda name, model: ""
# if the comment value was None or the empty string, just skip it
if comment_macro:
assert isinstance(comment_macro, str)
macro = '\n'.join((
'{%- macro query_comment_macro(connection_name, node) -%}',
comment_macro,
'{% endmacro %}'
))
macro = "\n".join(
(
"{%- macro query_comment_macro(connection_name, node) -%}",
comment_macro,
"{% endmacro %}",
)
)
ctx = self._get_context()
self.generator = QueryStringGenerator(macro, ctx)
self.comment = _QueryComment(None)
@@ -87,7 +90,7 @@ class MacroQueryStringSetter:
return self.comment.add(sql)
def reset(self):
self.set('master', None)
self.set("master", None)
def set(self, name: str, node: Optional[CompileResultNode]):
wrapped: Optional[NodeWrapper] = None


@@ -1,13 +1,16 @@
from collections.abc import Hashable
from dataclasses import dataclass
from typing import (
Optional, TypeVar, Any, Type, Dict, Union, Iterator, Tuple, Set
)
from typing import Optional, TypeVar, Any, Type, Dict, Union, Iterator, Tuple, Set
from dbt.contracts.graph.compiled import CompiledNode
from dbt.contracts.graph.parsed import ParsedSourceDefinition, ParsedNode
from dbt.contracts.relation import (
RelationType, ComponentName, HasQuoting, FakeAPIObject, Policy, Path
RelationType,
ComponentName,
HasQuoting,
FakeAPIObject,
Policy,
Path,
)
from dbt.exceptions import InternalException
from dbt.node_types import NodeType
@@ -16,7 +19,7 @@ from dbt.utils import filter_null_values, deep_merge, classproperty
import dbt.exceptions
Self = TypeVar('Self', bound='BaseRelation')
Self = TypeVar("Self", bound="BaseRelation")
@dataclass(frozen=True, eq=False, repr=False)
@@ -40,7 +43,7 @@ class BaseRelation(FakeAPIObject, Hashable):
if field.name == field_name:
return field
# this should be unreachable
raise ValueError(f'BaseRelation has no {field_name} field!')
raise ValueError(f"BaseRelation has no {field_name} field!")
def __eq__(self, other):
if not isinstance(other, self.__class__):
@@ -49,20 +52,18 @@ class BaseRelation(FakeAPIObject, Hashable):
@classmethod
def get_default_quote_policy(cls) -> Policy:
return cls._get_field_named('quote_policy').default
return cls._get_field_named("quote_policy").default
@classmethod
def get_default_include_policy(cls) -> Policy:
return cls._get_field_named('include_policy').default
return cls._get_field_named("include_policy").default
def get(self, key, default=None):
"""Override `.get` to return a metadata object so we don't break
dbt_utils.
"""
if key == 'metadata':
return {
'type': self.__class__.__name__
}
if key == "metadata":
return {"type": self.__class__.__name__}
return super().get(key, default)
def matches(
@@ -71,16 +72,19 @@ class BaseRelation(FakeAPIObject, Hashable):
schema: Optional[str] = None,
identifier: Optional[str] = None,
) -> bool:
search = filter_null_values({
ComponentName.Database: database,
ComponentName.Schema: schema,
ComponentName.Identifier: identifier
})
search = filter_null_values(
{
ComponentName.Database: database,
ComponentName.Schema: schema,
ComponentName.Identifier: identifier,
}
)
if not search:
# nothing was passed in
raise dbt.exceptions.RuntimeException(
"Tried to match relation, but no search path was passed!")
"Tried to match relation, but no search path was passed!"
)
exact_match = True
approximate_match = True
@@ -109,11 +113,13 @@ class BaseRelation(FakeAPIObject, Hashable):
schema: Optional[bool] = None,
identifier: Optional[bool] = None,
) -> Self:
policy = filter_null_values({
ComponentName.Database: database,
ComponentName.Schema: schema,
ComponentName.Identifier: identifier
})
policy = filter_null_values(
{
ComponentName.Database: database,
ComponentName.Schema: schema,
ComponentName.Identifier: identifier,
}
)
new_quote_policy = self.quote_policy.replace_dict(policy)
return self.replace(quote_policy=new_quote_policy)
@@ -124,16 +130,18 @@ class BaseRelation(FakeAPIObject, Hashable):
schema: Optional[bool] = None,
identifier: Optional[bool] = None,
) -> Self:
policy = filter_null_values({
ComponentName.Database: database,
ComponentName.Schema: schema,
ComponentName.Identifier: identifier
})
policy = filter_null_values(
{
ComponentName.Database: database,
ComponentName.Schema: schema,
ComponentName.Identifier: identifier,
}
)
new_include_policy = self.include_policy.replace_dict(policy)
return self.replace(include_policy=new_include_policy)
def information_schema(self, view_name=None) -> 'InformationSchema':
def information_schema(self, view_name=None) -> "InformationSchema":
# some of our data comes from jinja, where things can be `Undefined`.
if not isinstance(view_name, str):
view_name = None
@@ -143,10 +151,10 @@ class BaseRelation(FakeAPIObject, Hashable):
info_schema = InformationSchema.from_relation(self, view_name)
return info_schema.incorporate(path={"schema": None})
def information_schema_only(self) -> 'InformationSchema':
def information_schema_only(self) -> "InformationSchema":
return self.information_schema()
def without_identifier(self) -> 'BaseRelation':
def without_identifier(self) -> "BaseRelation":
"""Return a form of this relation that only has the database and schema
set to included. To get the appropriately-quoted form the schema out of
the result (for use as part of a query), use `.render()`. To get the
@@ -157,7 +165,7 @@ class BaseRelation(FakeAPIObject, Hashable):
return self.include(identifier=False).replace_path(identifier=None)
def _render_iterator(
self
self,
) -> Iterator[Tuple[Optional[ComponentName], Optional[str]]]:
for key in ComponentName:
@@ -170,13 +178,10 @@ class BaseRelation(FakeAPIObject, Hashable):
def render(self) -> str:
# if there is nothing set, this will return the empty string.
return '.'.join(
part for _, part in self._render_iterator()
if part is not None
)
return ".".join(part for _, part in self._render_iterator() if part is not None)
def quoted(self, identifier):
return '{quote_char}{identifier}{quote_char}'.format(
return "{quote_char}{identifier}{quote_char}".format(
quote_char=self.quote_character,
identifier=identifier,
)
@@ -186,11 +191,11 @@ class BaseRelation(FakeAPIObject, Hashable):
cls: Type[Self], source: ParsedSourceDefinition, **kwargs: Any
) -> Self:
source_quoting = source.quoting.to_dict(omit_none=True)
source_quoting.pop('column', None)
source_quoting.pop("column", None)
quote_policy = deep_merge(
cls.get_default_quote_policy().to_dict(omit_none=True),
source_quoting,
kwargs.get('quote_policy', {}),
kwargs.get("quote_policy", {}),
)
return cls.create(
@@ -198,12 +203,12 @@ class BaseRelation(FakeAPIObject, Hashable):
schema=source.schema,
identifier=source.identifier,
quote_policy=quote_policy,
**kwargs
**kwargs,
)
@staticmethod
def add_ephemeral_prefix(name: str):
return f'__dbt__cte__{name}'
return f"__dbt__cte__{name}"
@classmethod
def create_ephemeral_from_node(
@@ -236,7 +241,8 @@ class BaseRelation(FakeAPIObject, Hashable):
schema=node.schema,
identifier=node.alias,
quote_policy=quote_policy,
**kwargs)
**kwargs,
)
@classmethod
def create_from(
@@ -248,15 +254,16 @@ class BaseRelation(FakeAPIObject, Hashable):
if node.resource_type == NodeType.Source:
if not isinstance(node, ParsedSourceDefinition):
raise InternalException(
'type mismatch, expected ParsedSourceDefinition but got {}'
.format(type(node))
"type mismatch, expected ParsedSourceDefinition but got {}".format(
type(node)
)
)
return cls.create_from_source(node, **kwargs)
else:
if not isinstance(node, (ParsedNode, CompiledNode)):
raise InternalException(
'type mismatch, expected ParsedNode or CompiledNode but '
'got {}'.format(type(node))
"type mismatch, expected ParsedNode or CompiledNode but "
"got {}".format(type(node))
)
return cls.create_from_node(config, node, **kwargs)
@@ -269,14 +276,16 @@ class BaseRelation(FakeAPIObject, Hashable):
type: Optional[RelationType] = None,
**kwargs,
) -> Self:
kwargs.update({
'path': {
'database': database,
'schema': schema,
'identifier': identifier,
},
'type': type,
})
kwargs.update(
{
"path": {
"database": database,
"schema": schema,
"identifier": identifier,
},
"type": type,
}
)
return cls.from_dict(kwargs)
def __repr__(self) -> str:
@@ -342,7 +351,7 @@ class BaseRelation(FakeAPIObject, Hashable):
return RelationType
Info = TypeVar('Info', bound='InformationSchema')
Info = TypeVar("Info", bound="InformationSchema")
@dataclass(frozen=True, eq=False, repr=False)
@@ -352,7 +361,7 @@ class InformationSchema(BaseRelation):
def __post_init__(self):
if not isinstance(self.information_schema_view, (type(None), str)):
raise dbt.exceptions.CompilationException(
'Got an invalid name: {}'.format(self.information_schema_view)
"Got an invalid name: {}".format(self.information_schema_view)
)
@classmethod
@@ -362,7 +371,7 @@ class InformationSchema(BaseRelation):
return Path(
database=relation.database,
schema=relation.schema,
identifier='INFORMATION_SCHEMA',
identifier="INFORMATION_SCHEMA",
)
@classmethod
@@ -393,9 +402,7 @@ class InformationSchema(BaseRelation):
relation: BaseRelation,
information_schema_view: Optional[str],
) -> Info:
include_policy = cls.get_include_policy(
relation, information_schema_view
)
include_policy = cls.get_include_policy(relation, information_schema_view)
quote_policy = cls.get_quote_policy(relation, information_schema_view)
path = cls.get_path(relation, information_schema_view)
return cls(
@@ -417,6 +424,7 @@ class SchemaSearchMap(Dict[InformationSchema, Set[Optional[str]]]):
search for what schemas. The schema values are all lowercased to avoid
duplication.
"""
def add(self, relation: BaseRelation):
key = relation.information_schema_only()
if key not in self:
@@ -426,9 +434,7 @@ class SchemaSearchMap(Dict[InformationSchema, Set[Optional[str]]]):
schema = relation.schema.lower()
self[key].add(schema)
def search(
self
) -> Iterator[Tuple[InformationSchema, Optional[str]]]:
def search(self) -> Iterator[Tuple[InformationSchema, Optional[str]]]:
for information_schema_name, schemas in self.items():
for schema in schemas:
yield information_schema_name, schema
@@ -442,14 +448,13 @@ class SchemaSearchMap(Dict[InformationSchema, Set[Optional[str]]]):
dbt.exceptions.raise_compiler_error(str(seen))
for information_schema_name, schema in self.search():
path = {
'database': information_schema_name.database,
'schema': schema
}
new.add(information_schema_name.incorporate(
path=path,
quote_policy={'database': False},
include_policy={'database': False},
))
path = {"database": information_schema_name.database, "schema": schema}
new.add(
information_schema_name.incorporate(
path=path,
quote_policy={"database": False},
include_policy={"database": False},
)
)
return new
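
The relation hunks above mostly reflow the same logic: render() dots together whichever path components are set and included, and quoted() wraps a component in the quote character when the quote policy asks for it. A rough, self-contained sketch of that behaviour — MiniRelation is an illustrative stand-in, not dbt's BaseRelation:

from dataclasses import dataclass
from typing import Optional


@dataclass
class MiniRelation:
    # hypothetical stand-in: three path components, a per-component
    # quote policy, and a quote character
    database: Optional[str]
    schema: Optional[str]
    identifier: Optional[str]
    quote_policy: dict
    quote_character: str = '"'

    def quoted(self, part: str) -> str:
        return f"{self.quote_character}{part}{self.quote_character}"

    def render(self) -> str:
        parts = []
        for name in ("database", "schema", "identifier"):
            value = getattr(self, name)
            if value is None:
                continue  # excluded/unset components simply drop out of the join
            parts.append(self.quoted(value) if self.quote_policy.get(name) else value)
        return ".".join(parts)


rel = MiniRelation("analytics", "dbt_jdoe", "orders", {"identifier": True})
print(rel.render())  # analytics.dbt_jdoe."orders"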

View File

@@ -7,7 +7,7 @@ from dbt.logger import CACHE_LOGGER as logger
from dbt.utils import lowercase
import dbt.exceptions
_ReferenceKey = namedtuple('_ReferenceKey', 'database schema identifier')
_ReferenceKey = namedtuple("_ReferenceKey", "database schema identifier")
def _make_key(relation) -> _ReferenceKey:
@@ -15,9 +15,11 @@ def _make_key(relation) -> _ReferenceKey:
to keep track of quoting
"""
# databases and schemas can both be None
return _ReferenceKey(lowercase(relation.database),
lowercase(relation.schema),
lowercase(relation.identifier))
return _ReferenceKey(
lowercase(relation.database),
lowercase(relation.schema),
lowercase(relation.identifier),
)
def dot_separated(key: _ReferenceKey) -> str:
@@ -25,7 +27,7 @@ def dot_separated(key: _ReferenceKey) -> str:
:param _ReferenceKey key: The key to stringify.
"""
return '.'.join(map(str, key))
return ".".join(map(str, key))
class _CachedRelation:
@@ -37,13 +39,14 @@ class _CachedRelation:
that refer to this relation.
:attr BaseRelation inner: The underlying dbt relation.
"""
def __init__(self, inner):
self.referenced_by = {}
self.inner = inner
def __str__(self) -> str:
return (
'_CachedRelation(database={}, schema={}, identifier={}, inner={})'
"_CachedRelation(database={}, schema={}, identifier={}, inner={})"
).format(self.database, self.schema, self.identifier, self.inner)
@property
@@ -78,7 +81,7 @@ class _CachedRelation:
"""
return _make_key(self)
def add_reference(self, referrer: '_CachedRelation'):
def add_reference(self, referrer: "_CachedRelation"):
"""Add a reference from referrer to self, indicating that if this node
were drop...cascaded, the referrer would be dropped as well.
@@ -122,9 +125,9 @@ class _CachedRelation:
# table_name is ever anything but the identifier (via .create())
self.inner = self.inner.incorporate(
path={
'database': new_relation.inner.database,
'schema': new_relation.inner.schema,
'identifier': new_relation.inner.identifier
"database": new_relation.inner.database,
"schema": new_relation.inner.schema,
"identifier": new_relation.inner.identifier,
},
)
@@ -140,8 +143,9 @@ class _CachedRelation:
"""
if new_key in self.referenced_by:
dbt.exceptions.raise_cache_inconsistent(
'in rename of "{}" -> "{}", new name is in the cache already'
.format(old_key, new_key)
'in rename of "{}" -> "{}", new name is in the cache already'.format(
old_key, new_key
)
)
if old_key not in self.referenced_by:
@@ -172,13 +176,16 @@ class RelationsCache:
The adapters also hold this lock while filling the cache.
:attr Set[str] schemas: The set of known/cached schemas, all lowercased.
"""
def __init__(self) -> None:
self.relations: Dict[_ReferenceKey, _CachedRelation] = {}
self.lock = threading.RLock()
self.schemas: Set[Tuple[Optional[str], Optional[str]]] = set()
def add_schema(
self, database: Optional[str], schema: Optional[str],
self,
database: Optional[str],
schema: Optional[str],
) -> None:
"""Add a schema to the set of known schemas (case-insensitive)
@@ -188,7 +195,9 @@ class RelationsCache:
self.schemas.add((lowercase(database), lowercase(schema)))
def drop_schema(
self, database: Optional[str], schema: Optional[str],
self,
database: Optional[str],
schema: Optional[str],
) -> None:
"""Drop the given schema and remove it from the set of known schemas.
@@ -263,15 +272,15 @@ class RelationsCache:
return
if referenced is None:
dbt.exceptions.raise_cache_inconsistent(
'in add_link, referenced link key {} not in cache!'
.format(referenced_key)
"in add_link, referenced link key {} not in cache!".format(
referenced_key
)
)
dependent = self.relations.get(dependent_key)
if dependent is None:
dbt.exceptions.raise_cache_inconsistent(
'in add_link, dependent link key {} not in cache!'
.format(dependent_key)
"in add_link, dependent link key {} not in cache!".format(dependent_key)
)
assert dependent is not None # we just raised!
@@ -298,28 +307,23 @@ class RelationsCache:
# referring to a table outside our control. There's no need to make
# a link - we will never drop the referenced relation during a run.
logger.debug(
'{dep!s} references {ref!s} but {ref.database}.{ref.schema} '
'is not in the cache, skipping assumed external relation'
.format(dep=dependent, ref=ref_key)
"{dep!s} references {ref!s} but {ref.database}.{ref.schema} "
"is not in the cache, skipping assumed external relation".format(
dep=dependent, ref=ref_key
)
)
return
if ref_key not in self.relations:
# Insert a dummy "external" relation.
referenced = referenced.replace(
type=referenced.External
)
referenced = referenced.replace(type=referenced.External)
self.add(referenced)
dep_key = _make_key(dependent)
if dep_key not in self.relations:
# Insert a dummy "external" relation.
dependent = dependent.replace(
type=referenced.External
)
dependent = dependent.replace(type=referenced.External)
self.add(dependent)
logger.debug(
'adding link, {!s} references {!s}'.format(dep_key, ref_key)
)
logger.debug("adding link, {!s} references {!s}".format(dep_key, ref_key))
with self.lock:
self._add_link(ref_key, dep_key)
@@ -330,14 +334,14 @@ class RelationsCache:
:param BaseRelation relation: The underlying relation.
"""
cached = _CachedRelation(relation)
logger.debug('Adding relation: {!s}'.format(cached))
logger.debug("Adding relation: {!s}".format(cached))
lazy_log('before adding: {!s}', self.dump_graph)
lazy_log("before adding: {!s}", self.dump_graph)
with self.lock:
self._setdefault(cached)
lazy_log('after adding: {!s}', self.dump_graph)
lazy_log("after adding: {!s}", self.dump_graph)
def _remove_refs(self, keys):
"""Removes all references to all entries in keys. This does not
@@ -359,13 +363,10 @@ class RelationsCache:
:param _CachedRelation dropped: An existing _CachedRelation to drop.
"""
if dropped not in self.relations:
logger.debug('dropped a nonexistent relationship: {!s}'
.format(dropped))
logger.debug("dropped a nonexistent relationship: {!s}".format(dropped))
return
consequences = self.relations[dropped].collect_consequences()
logger.debug(
'drop {} is cascading to {}'.format(dropped, consequences)
)
logger.debug("drop {} is cascading to {}".format(dropped, consequences))
self._remove_refs(consequences)
def drop(self, relation):
@@ -380,7 +381,7 @@ class RelationsCache:
:param str identifier: The identifier of the relation to drop.
"""
dropped = _make_key(relation)
logger.debug('Dropping relation: {!s}'.format(dropped))
logger.debug("Dropping relation: {!s}".format(dropped))
with self.lock:
self._drop_cascade_relation(dropped)
@@ -404,8 +405,9 @@ class RelationsCache:
for cached in self.relations.values():
if cached.is_referenced_by(old_key):
logger.debug(
'updated reference from {0} -> {2} to {1} -> {2}'
.format(old_key, new_key, cached.key())
"updated reference from {0} -> {2} to {1} -> {2}".format(
old_key, new_key, cached.key()
)
)
cached.rename_key(old_key, new_key)
@@ -430,14 +432,16 @@ class RelationsCache:
"""
if new_key in self.relations:
dbt.exceptions.raise_cache_inconsistent(
'in rename, new key {} already in cache: {}'
.format(new_key, list(self.relations.keys()))
"in rename, new key {} already in cache: {}".format(
new_key, list(self.relations.keys())
)
)
if old_key not in self.relations:
logger.debug(
'old key {} not found in self.relations, assuming temporary'
.format(old_key)
"old key {} not found in self.relations, assuming temporary".format(
old_key
)
)
return False
return True
@@ -456,11 +460,9 @@ class RelationsCache:
"""
old_key = _make_key(old)
new_key = _make_key(new)
logger.debug('Renaming relation {!s} to {!s}'.format(
old_key, new_key
))
logger.debug("Renaming relation {!s} to {!s}".format(old_key, new_key))
lazy_log('before rename: {!s}', self.dump_graph)
lazy_log("before rename: {!s}", self.dump_graph)
with self.lock:
if self._check_rename_constraints(old_key, new_key):
@@ -468,7 +470,7 @@ class RelationsCache:
else:
self._setdefault(_CachedRelation(new))
lazy_log('after rename: {!s}', self.dump_graph)
lazy_log("after rename: {!s}", self.dump_graph)
def get_relations(
self, database: Optional[str], schema: Optional[str]
@@ -483,14 +485,14 @@ class RelationsCache:
schema = lowercase(schema)
with self.lock:
results = [
r.inner for r in self.relations.values()
if (lowercase(r.schema) == schema and
lowercase(r.database) == database)
r.inner
for r in self.relations.values()
if (lowercase(r.schema) == schema and lowercase(r.database) == database)
]
if None in results:
dbt.exceptions.raise_cache_inconsistent(
'in get_relations, a None relation was found in the cache!'
"in get_relations, a None relation was found in the cache!"
)
return results
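
The cache above keys every relation case-insensitively: _make_key lowercases database, schema and identifier before using the tuple as a dictionary key, so relations that differ only in casing collapse to a single entry. A minimal sketch of that keying (standalone names, not the real RelationsCache):

from collections import namedtuple
from typing import Optional

_ReferenceKey = namedtuple("_ReferenceKey", "database schema identifier")


def lowercase(value: Optional[str]) -> Optional[str]:
    # mirrors dbt.utils.lowercase: None stays None
    return value.lower() if value is not None else None


def make_key(database, schema, identifier) -> _ReferenceKey:
    return _ReferenceKey(lowercase(database), lowercase(schema), lowercase(identifier))


cache = {}
cache[make_key("Analytics", "PUBLIC", "Orders")] = "first"
cache[make_key("analytics", "public", "orders")] = "second"
print(len(cache))  # 1 -- differently-cased keys collide on purpose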

View File

@@ -50,9 +50,7 @@ class AdapterContainer:
adapter = self.get_adapter_class_by_name(name)
return adapter.Relation
def get_config_class_by_name(
self, name: str
) -> Type[AdapterConfig]:
def get_config_class_by_name(self, name: str) -> Type[AdapterConfig]:
adapter = self.get_adapter_class_by_name(name)
return adapter.AdapterSpecificConfigs
@@ -62,24 +60,24 @@ class AdapterContainer:
# singletons
try:
# mypy doesn't think modules have any attributes.
mod: Any = import_module('.' + name, 'dbt.adapters')
mod: Any = import_module("." + name, "dbt.adapters")
except ModuleNotFoundError as exc:
# if we failed to import the target module in particular, inform
# the user about it via a runtime error
if exc.name == 'dbt.adapters.' + name:
raise RuntimeException(f'Could not find adapter type {name}!')
logger.info(f'Error importing adapter: {exc}')
if exc.name == "dbt.adapters." + name:
raise RuntimeException(f"Could not find adapter type {name}!")
logger.info(f"Error importing adapter: {exc}")
# otherwise, the error had to have come from some underlying
# library. Log the stack trace.
logger.debug('', exc_info=True)
logger.debug("", exc_info=True)
raise
plugin: AdapterPlugin = mod.Plugin
plugin_type = plugin.adapter.type()
if plugin_type != name:
raise RuntimeException(
f'Expected to find adapter with type named {name}, got '
f'adapter with type {plugin_type}'
f"Expected to find adapter with type named {name}, got "
f"adapter with type {plugin_type}"
)
with self.lock:
@@ -109,8 +107,7 @@ class AdapterContainer:
return self.adapters[adapter_name]
def reset_adapters(self):
"""Clear the adapters. This is useful for tests, which change configs.
"""
"""Clear the adapters. This is useful for tests, which change configs."""
with self.lock:
for adapter in self.adapters.values():
adapter.cleanup_connections()
@@ -140,9 +137,7 @@ class AdapterContainer:
try:
plugin = self.plugins[plugin_name]
except KeyError:
raise InternalException(
f'No plugin found for {plugin_name}'
) from None
raise InternalException(f"No plugin found for {plugin_name}") from None
plugins.append(plugin)
seen.add(plugin_name)
if plugin.dependencies is None:
@@ -166,7 +161,7 @@ class AdapterContainer:
path = self.packages[package_name]
except KeyError:
raise InternalException(
f'No internal package listing found for {package_name}'
f"No internal package listing found for {package_name}"
)
paths.append(path)
return paths
@@ -187,8 +182,7 @@ def get_adapter(config: AdapterRequiredConfig):
def reset_adapters():
"""Clear the adapters. This is useful for tests, which change configs.
"""
"""Clear the adapters. This is useful for tests, which change configs."""
FACTORY.reset_adapters()
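
The factory resolves an adapter by importing dbt.adapters.<name> at runtime, and it tells "the adapter package itself is missing" apart from "one of the adapter's own imports failed" by inspecting exc.name. A generic sketch of that pattern — the function name and namespace default here are illustrative, not dbt's actual registry:

from importlib import import_module


def load_plugin(name: str, namespace: str = "dbt.adapters"):
    try:
        return import_module(f"{namespace}.{name}")
    except ModuleNotFoundError as exc:
        if exc.name == f"{namespace}.{name}":
            # the plugin package itself is missing -> a user-facing error
            raise RuntimeError(f"Could not find adapter type {name}!") from exc
        # otherwise one of the plugin's own dependencies failed to import;
        # re-raise so the original stack trace is preserved
        raise


# load_plugin("postgres") would import dbt.adapters.postgres if it is installed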

View File

@@ -1,17 +1,27 @@
from dataclasses import dataclass
from typing import (
Type, Hashable, Optional, ContextManager, List, Generic, TypeVar, ClassVar,
Tuple, Union, Dict, Any
Type,
Hashable,
Optional,
ContextManager,
List,
Generic,
TypeVar,
ClassVar,
Tuple,
Union,
Dict,
Any,
)
from typing_extensions import Protocol
import agate
from dbt.contracts.connection import (
Connection, AdapterRequiredConfig, AdapterResponse
)
from dbt.contracts.connection import Connection, AdapterRequiredConfig, AdapterResponse
from dbt.contracts.graph.compiled import (
CompiledNode, ManifestNode, NonSourceCompiledNode
CompiledNode,
ManifestNode,
NonSourceCompiledNode,
)
from dbt.contracts.graph.parsed import ParsedNode, ParsedSourceDefinition
from dbt.contracts.graph.model_config import BaseConfig
@@ -34,7 +44,7 @@ class ColumnProtocol(Protocol):
pass
Self = TypeVar('Self', bound='RelationProtocol')
Self = TypeVar("Self", bound="RelationProtocol")
class RelationProtocol(Protocol):
@@ -64,19 +74,11 @@ class CompilerProtocol(Protocol):
...
AdapterConfig_T = TypeVar(
'AdapterConfig_T', bound=AdapterConfig
)
ConnectionManager_T = TypeVar(
'ConnectionManager_T', bound=ConnectionManagerProtocol
)
Relation_T = TypeVar(
'Relation_T', bound=RelationProtocol
)
Column_T = TypeVar(
'Column_T', bound=ColumnProtocol
)
Compiler_T = TypeVar('Compiler_T', bound=CompilerProtocol)
AdapterConfig_T = TypeVar("AdapterConfig_T", bound=AdapterConfig)
ConnectionManager_T = TypeVar("ConnectionManager_T", bound=ConnectionManagerProtocol)
Relation_T = TypeVar("Relation_T", bound=RelationProtocol)
Column_T = TypeVar("Column_T", bound=ColumnProtocol)
Compiler_T = TypeVar("Compiler_T", bound=CompilerProtocol)
class AdapterProtocol(
@@ -87,7 +89,7 @@ class AdapterProtocol(
Relation_T,
Column_T,
Compiler_T,
]
],
):
AdapterSpecificConfigs: ClassVar[Type[AdapterConfig_T]]
Column: ClassVar[Type[Column_T]]
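
The protocol module relies on typing Protocols plus bounded TypeVars so adapter pieces are checked structurally rather than by inheritance. A toy sketch of the idea (the names below are invented for illustration):

from typing import Protocol, TypeVar  # Python 3.8+; dbt imports Protocol from typing_extensions for older interpreters


class RelationLike(Protocol):
    def render(self) -> str:
        ...


Relation_T = TypeVar("Relation_T", bound=RelationLike)


def refresh(relation: Relation_T) -> Relation_T:
    # any object with a compatible render() satisfies the protocol -- no subclassing needed
    print("refreshing", relation.render())
    return relation


class ToyRelation:  # note: does not inherit from RelationLike
    def render(self) -> str:
        return '"public"."orders"'


refresh(ToyRelation())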

View File

@@ -7,9 +7,7 @@ import agate
import dbt.clients.agate_helper
import dbt.exceptions
from dbt.adapters.base import BaseConnectionManager
from dbt.contracts.connection import (
Connection, ConnectionState, AdapterResponse
)
from dbt.contracts.connection import Connection, ConnectionState, AdapterResponse
from dbt.logger import GLOBAL_LOGGER as logger
from dbt import flags
@@ -23,11 +21,12 @@ class SQLConnectionManager(BaseConnectionManager):
- get_response
- open
"""
@abc.abstractmethod
def cancel(self, connection: Connection):
"""Cancel the given connection."""
raise dbt.exceptions.NotImplementedException(
'`cancel` is not implemented for this adapter!'
"`cancel` is not implemented for this adapter!"
)
def cancel_open(self) -> List[str]:
@@ -41,8 +40,8 @@ class SQLConnectionManager(BaseConnectionManager):
# if the connection failed, the handle will be None so we have
# nothing to cancel.
if (
connection.handle is not None and
connection.state == ConnectionState.OPEN
connection.handle is not None
and connection.state == ConnectionState.OPEN
):
self.cancel(connection)
if connection.name is not None:
@@ -54,23 +53,22 @@ class SQLConnectionManager(BaseConnectionManager):
sql: str,
auto_begin: bool = True,
bindings: Optional[Any] = None,
abridge_sql_log: bool = False
abridge_sql_log: bool = False,
) -> Tuple[Connection, Any]:
connection = self.get_thread_connection()
if auto_begin and connection.transaction_open is False:
self.begin()
logger.debug('Using {} connection "{}".'
.format(self.TYPE, connection.name))
logger.debug('Using {} connection "{}".'.format(self.TYPE, connection.name))
with self.exception_handler(sql):
if abridge_sql_log:
log_sql = '{}...'.format(sql[:512])
log_sql = "{}...".format(sql[:512])
else:
log_sql = sql
logger.debug(
'On {connection_name}: {sql}',
"On {connection_name}: {sql}",
connection_name=connection.name,
sql=log_sql,
)
@@ -81,7 +79,7 @@ class SQLConnectionManager(BaseConnectionManager):
logger.debug(
"SQL status: {status} in {elapsed:0.2f} seconds",
status=self.get_response(cursor),
elapsed=(time.time() - pre)
elapsed=(time.time() - pre),
)
return connection, cursor
@@ -90,14 +88,12 @@ class SQLConnectionManager(BaseConnectionManager):
def get_response(cls, cursor: Any) -> Union[AdapterResponse, str]:
"""Get the status of the cursor."""
raise dbt.exceptions.NotImplementedException(
'`get_response` is not implemented for this adapter!'
"`get_response` is not implemented for this adapter!"
)
@classmethod
def process_results(
cls,
column_names: Iterable[str],
rows: Iterable[Any]
cls, column_names: Iterable[str], rows: Iterable[Any]
) -> List[Dict[str, Any]]:
return [dict(zip(column_names, row)) for row in rows]
@@ -112,10 +108,7 @@ class SQLConnectionManager(BaseConnectionManager):
rows = cursor.fetchall()
data = cls.process_results(column_names, rows)
return dbt.clients.agate_helper.table_from_data_flat(
data,
column_names
)
return dbt.clients.agate_helper.table_from_data_flat(data, column_names)
def execute(
self, sql: str, auto_begin: bool = False, fetch: bool = False
@@ -130,10 +123,10 @@ class SQLConnectionManager(BaseConnectionManager):
return response, table
def add_begin_query(self):
return self.add_query('BEGIN', auto_begin=False)
return self.add_query("BEGIN", auto_begin=False)
def add_commit_query(self):
return self.add_query('COMMIT', auto_begin=False)
return self.add_query("COMMIT", auto_begin=False)
def begin(self):
connection = self.get_thread_connection()
@@ -141,13 +134,14 @@ class SQLConnectionManager(BaseConnectionManager):
if flags.STRICT_MODE:
if not isinstance(connection, Connection):
raise dbt.exceptions.CompilerException(
f'In begin, got {connection} - not a Connection!'
f"In begin, got {connection} - not a Connection!"
)
if connection.transaction_open is True:
raise dbt.exceptions.InternalException(
'Tried to begin a new transaction on connection "{}", but '
'it already had one open!'.format(connection.name))
"it already had one open!".format(connection.name)
)
self.add_begin_query()
@@ -159,15 +153,16 @@ class SQLConnectionManager(BaseConnectionManager):
if flags.STRICT_MODE:
if not isinstance(connection, Connection):
raise dbt.exceptions.CompilerException(
f'In commit, got {connection} - not a Connection!'
f"In commit, got {connection} - not a Connection!"
)
if connection.transaction_open is False:
raise dbt.exceptions.InternalException(
'Tried to commit transaction on connection "{}", but '
'it does not have one open!'.format(connection.name))
"it does not have one open!".format(connection.name)
)
logger.debug('On {}: COMMIT'.format(connection.name))
logger.debug("On {}: COMMIT".format(connection.name))
self.add_commit_query()
connection.transaction_open = False
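
The connection manager tracks transaction_open on each connection and refuses to begin() twice or commit() without an open transaction. A stripped-down sketch of that bookkeeping — MiniConnectionManager is a made-up name and no SQL is actually issued:

class MiniConnectionManager:
    # toy illustration of the begin/commit guards, not dbt's SQLConnectionManager
    def __init__(self):
        self.transaction_open = False

    def begin(self):
        if self.transaction_open:
            raise RuntimeError("Tried to begin a new transaction, but one is already open!")
        # a real adapter would run BEGIN here
        self.transaction_open = True

    def commit(self):
        if not self.transaction_open:
            raise RuntimeError("Tried to commit a transaction, but none is open!")
        # a real adapter would run COMMIT here
        self.transaction_open = False


mgr = MiniConnectionManager()
mgr.begin()
mgr.commit()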

View File

@@ -10,16 +10,16 @@ from dbt.logger import GLOBAL_LOGGER as logger
from dbt.adapters.base.relation import BaseRelation
LIST_RELATIONS_MACRO_NAME = 'list_relations_without_caching'
GET_COLUMNS_IN_RELATION_MACRO_NAME = 'get_columns_in_relation'
LIST_SCHEMAS_MACRO_NAME = 'list_schemas'
CHECK_SCHEMA_EXISTS_MACRO_NAME = 'check_schema_exists'
CREATE_SCHEMA_MACRO_NAME = 'create_schema'
DROP_SCHEMA_MACRO_NAME = 'drop_schema'
RENAME_RELATION_MACRO_NAME = 'rename_relation'
TRUNCATE_RELATION_MACRO_NAME = 'truncate_relation'
DROP_RELATION_MACRO_NAME = 'drop_relation'
ALTER_COLUMN_TYPE_MACRO_NAME = 'alter_column_type'
LIST_RELATIONS_MACRO_NAME = "list_relations_without_caching"
GET_COLUMNS_IN_RELATION_MACRO_NAME = "get_columns_in_relation"
LIST_SCHEMAS_MACRO_NAME = "list_schemas"
CHECK_SCHEMA_EXISTS_MACRO_NAME = "check_schema_exists"
CREATE_SCHEMA_MACRO_NAME = "create_schema"
DROP_SCHEMA_MACRO_NAME = "drop_schema"
RENAME_RELATION_MACRO_NAME = "rename_relation"
TRUNCATE_RELATION_MACRO_NAME = "truncate_relation"
DROP_RELATION_MACRO_NAME = "drop_relation"
ALTER_COLUMN_TYPE_MACRO_NAME = "alter_column_type"
class SQLAdapter(BaseAdapter):
@@ -60,30 +60,23 @@ class SQLAdapter(BaseAdapter):
:param abridge_sql_log: If set, limit the raw sql logged to 512
characters
"""
return self.connections.add_query(sql, auto_begin, bindings,
abridge_sql_log)
return self.connections.add_query(sql, auto_begin, bindings, abridge_sql_log)
@classmethod
def convert_text_type(cls, agate_table: agate.Table, col_idx: int) -> str:
return "text"
@classmethod
def convert_number_type(
cls, agate_table: agate.Table, col_idx: int
) -> str:
decimals = agate_table.aggregate(agate.MaxPrecision(col_idx))
def convert_number_type(cls, agate_table: agate.Table, col_idx: int) -> str:
decimals = agate_table.aggregate(agate.MaxPrecision(col_idx)) # type: ignore
return "float8" if decimals else "integer"
@classmethod
def convert_boolean_type(
cls, agate_table: agate.Table, col_idx: int
) -> str:
def convert_boolean_type(cls, agate_table: agate.Table, col_idx: int) -> str:
return "boolean"
@classmethod
def convert_datetime_type(
cls, agate_table: agate.Table, col_idx: int
) -> str:
def convert_datetime_type(cls, agate_table: agate.Table, col_idx: int) -> str:
return "timestamp without time zone"
@classmethod
@@ -99,31 +92,28 @@ class SQLAdapter(BaseAdapter):
return True
def expand_column_types(self, goal, current):
reference_columns = {
c.name: c for c in
self.get_columns_in_relation(goal)
}
reference_columns = {c.name: c for c in self.get_columns_in_relation(goal)}
target_columns = {
c.name: c for c
in self.get_columns_in_relation(current)
}
target_columns = {c.name: c for c in self.get_columns_in_relation(current)}
for column_name, reference_column in reference_columns.items():
target_column = target_columns.get(column_name)
if target_column is not None and \
target_column.can_expand_to(reference_column):
if target_column is not None and target_column.can_expand_to(
reference_column
):
col_string_size = reference_column.string_size()
new_type = self.Column.string_type(col_string_size)
logger.debug("Changing col type from {} to {} in table {}",
target_column.data_type, new_type, current)
logger.debug(
"Changing col type from {} to {} in table {}",
target_column.data_type,
new_type,
current,
)
self.alter_column_type(current, column_name, new_type)
def alter_column_type(
self, relation, column_name, new_column_type
) -> None:
def alter_column_type(self, relation, column_name, new_column_type) -> None:
"""
1. Create a new column (w/ temp name and correct type)
2. Copy data over to it
@@ -131,53 +121,40 @@ class SQLAdapter(BaseAdapter):
4. Rename the new column to existing column
"""
kwargs = {
'relation': relation,
'column_name': column_name,
'new_column_type': new_column_type,
"relation": relation,
"column_name": column_name,
"new_column_type": new_column_type,
}
self.execute_macro(
ALTER_COLUMN_TYPE_MACRO_NAME,
kwargs=kwargs
)
self.execute_macro(ALTER_COLUMN_TYPE_MACRO_NAME, kwargs=kwargs)
def drop_relation(self, relation):
if relation.type is None:
dbt.exceptions.raise_compiler_error(
'Tried to drop relation {}, but its type is null.'
.format(relation))
"Tried to drop relation {}, but its type is null.".format(relation)
)
self.cache_dropped(relation)
self.execute_macro(
DROP_RELATION_MACRO_NAME,
kwargs={'relation': relation}
)
self.execute_macro(DROP_RELATION_MACRO_NAME, kwargs={"relation": relation})
def truncate_relation(self, relation):
self.execute_macro(
TRUNCATE_RELATION_MACRO_NAME,
kwargs={'relation': relation}
)
self.execute_macro(TRUNCATE_RELATION_MACRO_NAME, kwargs={"relation": relation})
def rename_relation(self, from_relation, to_relation):
self.cache_renamed(from_relation, to_relation)
kwargs = {'from_relation': from_relation, 'to_relation': to_relation}
self.execute_macro(
RENAME_RELATION_MACRO_NAME,
kwargs=kwargs
)
kwargs = {"from_relation": from_relation, "to_relation": to_relation}
self.execute_macro(RENAME_RELATION_MACRO_NAME, kwargs=kwargs)
def get_columns_in_relation(self, relation):
return self.execute_macro(
GET_COLUMNS_IN_RELATION_MACRO_NAME,
kwargs={'relation': relation}
GET_COLUMNS_IN_RELATION_MACRO_NAME, kwargs={"relation": relation}
)
def create_schema(self, relation: BaseRelation) -> None:
relation = relation.without_identifier()
logger.debug('Creating schema "{}"', relation)
kwargs = {
'relation': relation,
"relation": relation,
}
self.execute_macro(CREATE_SCHEMA_MACRO_NAME, kwargs=kwargs)
self.commit_if_has_connection()
@@ -188,39 +165,35 @@ class SQLAdapter(BaseAdapter):
relation = relation.without_identifier()
logger.debug('Dropping schema "{}".', relation)
kwargs = {
'relation': relation,
"relation": relation,
}
self.execute_macro(DROP_SCHEMA_MACRO_NAME, kwargs=kwargs)
# we can update the cache here
self.cache.drop_schema(relation.database, relation.schema)
def list_relations_without_caching(
self, schema_relation: BaseRelation,
self,
schema_relation: BaseRelation,
) -> List[BaseRelation]:
kwargs = {'schema_relation': schema_relation}
results = self.execute_macro(
LIST_RELATIONS_MACRO_NAME,
kwargs=kwargs
)
kwargs = {"schema_relation": schema_relation}
results = self.execute_macro(LIST_RELATIONS_MACRO_NAME, kwargs=kwargs)
relations = []
quote_policy = {
'database': True,
'schema': True,
'identifier': True
}
quote_policy = {"database": True, "schema": True, "identifier": True}
for _database, name, _schema, _type in results:
try:
_type = self.Relation.get_relation_type(_type)
except ValueError:
_type = self.Relation.External
relations.append(self.Relation.create(
database=_database,
schema=_schema,
identifier=name,
quote_policy=quote_policy,
type=_type
))
relations.append(
self.Relation.create(
database=_database,
schema=_schema,
identifier=name,
quote_policy=quote_policy,
type=_type,
)
)
return relations
def quote(self, identifier):
@@ -228,8 +201,7 @@ class SQLAdapter(BaseAdapter):
def list_schemas(self, database: str) -> List[str]:
results = self.execute_macro(
LIST_SCHEMAS_MACRO_NAME,
kwargs={'database': database}
LIST_SCHEMAS_MACRO_NAME, kwargs={"database": database}
)
return [row[0] for row in results]
@@ -238,13 +210,10 @@ class SQLAdapter(BaseAdapter):
information_schema = self.Relation.create(
database=database,
schema=schema,
identifier='INFORMATION_SCHEMA',
quote_policy=self.config.quoting
identifier="INFORMATION_SCHEMA",
quote_policy=self.config.quoting,
).information_schema()
kwargs = {'information_schema': information_schema, 'schema': schema}
results = self.execute_macro(
CHECK_SCHEMA_EXISTS_MACRO_NAME,
kwargs=kwargs
)
kwargs = {"information_schema": information_schema, "schema": schema}
results = self.execute_macro(CHECK_SCHEMA_EXISTS_MACRO_NAME, kwargs=kwargs)
return results[0][0] > 0
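
alter_column_type follows the four-step recipe in its docstring: add a temporary column of the new type, copy the data across, drop the old column, then rename the temporary column into place. A hedged sketch of the SQL those steps imply — the statements and the __dbt_alter suffix are illustrative, since the real SQL lives in the adapter-overridable alter_column_type macro:

def alter_column_type_sql(relation: str, column: str, new_type: str) -> list:
    tmp = f"{column}__dbt_alter"
    return [
        f"alter table {relation} add column {tmp} {new_type};",      # 1. new column, correct type
        f"update {relation} set {tmp} = {column};",                  # 2. copy data over
        f"alter table {relation} drop column {column} cascade;",     # 3. drop the existing column
        f"alter table {relation} rename column {tmp} to {column};",  # 4. rename into place
    ]


for stmt in alter_column_type_sql('"analytics"."orders"', "status", "varchar(64)"):
    print(stmt)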

View File

@@ -10,79 +10,89 @@ def regex(pat):
class BlockData:
"""raw plaintext data from the top level of the file."""
def __init__(self, contents):
self.block_type_name = '__dbt__data'
self.block_type_name = "__dbt__data"
self.contents = contents
self.full_block = contents
class BlockTag:
def __init__(self, block_type_name, block_name, contents=None,
full_block=None, **kw):
def __init__(
self, block_type_name, block_name, contents=None, full_block=None, **kw
):
self.block_type_name = block_type_name
self.block_name = block_name
self.contents = contents
self.full_block = full_block
def __str__(self):
return 'BlockTag({!r}, {!r})'.format(self.block_type_name,
self.block_name)
return "BlockTag({!r}, {!r})".format(self.block_type_name, self.block_name)
def __repr__(self):
return str(self)
@property
def end_block_type_name(self):
return 'end{}'.format(self.block_type_name)
return "end{}".format(self.block_type_name)
def end_pat(self):
# we don't want to use string formatting here because jinja uses most
# of the string formatting operators in its syntax...
pattern = ''.join((
r'(?P<endblock>((?:\s*\{\%\-|\{\%)\s*',
self.end_block_type_name,
r'\s*(?:\-\%\}\s*|\%\})))',
))
pattern = "".join(
(
r"(?P<endblock>((?:\s*\{\%\-|\{\%)\s*",
self.end_block_type_name,
r"\s*(?:\-\%\}\s*|\%\})))",
)
)
return regex(pattern)
Tag = namedtuple('Tag', 'block_type_name block_name start end')
Tag = namedtuple("Tag", "block_type_name block_name start end")
_NAME_PATTERN = r'[A-Za-z_][A-Za-z_0-9]*'
_NAME_PATTERN = r"[A-Za-z_][A-Za-z_0-9]*"
COMMENT_START_PATTERN = regex(r'(?:(?P<comment_start>(\s*\{\#)))')
COMMENT_END_PATTERN = regex(r'(.*?)(\s*\#\})')
COMMENT_START_PATTERN = regex(r"(?:(?P<comment_start>(\s*\{\#)))")
COMMENT_END_PATTERN = regex(r"(.*?)(\s*\#\})")
RAW_START_PATTERN = regex(
r'(?:\s*\{\%\-|\{\%)\s*(?P<raw_start>(raw))\s*(?:\-\%\}\s*|\%\})'
r"(?:\s*\{\%\-|\{\%)\s*(?P<raw_start>(raw))\s*(?:\-\%\}\s*|\%\})"
)
EXPR_START_PATTERN = regex(r'(?P<expr_start>(\{\{\s*))')
EXPR_END_PATTERN = regex(r'(?P<expr_end>(\s*\}\}))')
EXPR_START_PATTERN = regex(r"(?P<expr_start>(\{\{\s*))")
EXPR_END_PATTERN = regex(r"(?P<expr_end>(\s*\}\}))")
BLOCK_START_PATTERN = regex(''.join((
r'(?:\s*\{\%\-|\{\%)\s*',
r'(?P<block_type_name>({}))'.format(_NAME_PATTERN),
# some blocks have a 'block name'.
r'(?:\s+(?P<block_name>({})))?'.format(_NAME_PATTERN),
)))
BLOCK_START_PATTERN = regex(
"".join(
(
r"(?:\s*\{\%\-|\{\%)\s*",
r"(?P<block_type_name>({}))".format(_NAME_PATTERN),
# some blocks have a 'block name'.
r"(?:\s+(?P<block_name>({})))?".format(_NAME_PATTERN),
)
)
)
RAW_BLOCK_PATTERN = regex(''.join((
r'(?:\s*\{\%\-|\{\%)\s*raw\s*(?:\-\%\}\s*|\%\})',
r'(?:.*?)',
r'(?:\s*\{\%\-|\{\%)\s*endraw\s*(?:\-\%\}\s*|\%\})',
)))
RAW_BLOCK_PATTERN = regex(
"".join(
(
r"(?:\s*\{\%\-|\{\%)\s*raw\s*(?:\-\%\}\s*|\%\})",
r"(?:.*?)",
r"(?:\s*\{\%\-|\{\%)\s*endraw\s*(?:\-\%\}\s*|\%\})",
)
)
)
TAG_CLOSE_PATTERN = regex(r'(?:(?P<tag_close>(\-\%\}\s*|\%\})))')
TAG_CLOSE_PATTERN = regex(r"(?:(?P<tag_close>(\-\%\}\s*|\%\})))")
# stolen from jinja's lexer. Note that we've consumed all prefix whitespace by
# the time we want to use this.
STRING_PATTERN = regex(
r"(?P<string>('([^'\\]*(?:\\.[^'\\]*)*)'|"
r'"([^"\\]*(?:\\.[^"\\]*)*)"))'
r"(?P<string>('([^'\\]*(?:\\.[^'\\]*)*)'|" r'"([^"\\]*(?:\\.[^"\\]*)*)"))'
)
QUOTE_START_PATTERN = regex(r'''(?P<quote>(['"]))''')
QUOTE_START_PATTERN = regex(r"""(?P<quote>(['"]))""")
class TagIterator:
@@ -99,10 +109,10 @@ class TagIterator:
end_val: int = self.pos if end is None else end
data = self.data[:end_val]
# if not found, rfind returns -1, and -1+1=0, which is perfect!
last_line_start = data.rfind('\n') + 1
last_line_start = data.rfind("\n") + 1
# it's easy to forget this, but line numbers are 1-indexed
line_number = data.count('\n') + 1
return f'{line_number}:{end_val - last_line_start}'
line_number = data.count("\n") + 1
return f"{line_number}:{end_val - last_line_start}"
def advance(self, new_position):
self.pos = new_position
@@ -120,7 +130,7 @@ class TagIterator:
matches = []
for pattern in patterns:
# default to 'search', but sometimes we want to 'match'.
if kwargs.get('method', 'search') == 'search':
if kwargs.get("method", "search") == "search":
match = self._search(pattern)
else:
match = self._match(pattern)
@@ -136,7 +146,7 @@ class TagIterator:
match = self._first_match(*patterns, **kwargs)
if match is None:
msg = 'unexpected EOF, expected {}, got "{}"'.format(
expected_name, self.data[self.pos:]
expected_name, self.data[self.pos :]
)
dbt.exceptions.raise_compiler_error(msg)
return match
@@ -156,22 +166,20 @@ class TagIterator:
"""
self.advance(match.end())
while True:
match = self._expect_match('}}',
EXPR_END_PATTERN,
QUOTE_START_PATTERN)
if match.groupdict().get('expr_end') is not None:
match = self._expect_match("}}", EXPR_END_PATTERN, QUOTE_START_PATTERN)
if match.groupdict().get("expr_end") is not None:
break
else:
# it's a quote. we haven't advanced for this match yet, so
# just slurp up the whole string, no need to rewind.
match = self._expect_match('string', STRING_PATTERN)
match = self._expect_match("string", STRING_PATTERN)
self.advance(match.end())
self.advance(match.end())
def handle_comment(self, match):
self.advance(match.end())
match = self._expect_match('#}', COMMENT_END_PATTERN)
match = self._expect_match("#}", COMMENT_END_PATTERN)
self.advance(match.end())
def _expect_block_close(self):
@@ -188,22 +196,19 @@ class TagIterator:
"""
while True:
end_match = self._expect_match(
'tag close ("%}")',
QUOTE_START_PATTERN,
TAG_CLOSE_PATTERN
'tag close ("%}")', QUOTE_START_PATTERN, TAG_CLOSE_PATTERN
)
self.advance(end_match.end())
if end_match.groupdict().get('tag_close') is not None:
if end_match.groupdict().get("tag_close") is not None:
return
# must be a string. Rewind to its start and advance past it.
self.rewind()
string_match = self._expect_match('string', STRING_PATTERN)
string_match = self._expect_match("string", STRING_PATTERN)
self.advance(string_match.end())
def handle_raw(self):
# raw blocks are super special, they are a single complete regex
match = self._expect_match('{% raw %}...{% endraw %}',
RAW_BLOCK_PATTERN)
match = self._expect_match("{% raw %}...{% endraw %}", RAW_BLOCK_PATTERN)
self.advance(match.end())
return match.end()
@@ -220,13 +225,12 @@ class TagIterator:
"""
groups = match.groupdict()
# always a value
block_type_name = groups['block_type_name']
block_type_name = groups["block_type_name"]
# might be None
block_name = groups.get('block_name')
block_name = groups.get("block_name")
start_pos = self.pos
if block_type_name == 'raw':
match = self._expect_match('{% raw %}...{% endraw %}',
RAW_BLOCK_PATTERN)
if block_type_name == "raw":
match = self._expect_match("{% raw %}...{% endraw %}", RAW_BLOCK_PATTERN)
self.advance(match.end())
else:
self.advance(match.end())
@@ -235,15 +239,13 @@ class TagIterator:
block_type_name=block_type_name,
block_name=block_name,
start=start_pos,
end=self.pos
end=self.pos,
)
def find_tags(self):
while True:
match = self._first_match(
BLOCK_START_PATTERN,
COMMENT_START_PATTERN,
EXPR_START_PATTERN
BLOCK_START_PATTERN, COMMENT_START_PATTERN, EXPR_START_PATTERN
)
if match is None:
break
@@ -252,9 +254,9 @@ class TagIterator:
# start = self.pos
groups = match.groupdict()
comment_start = groups.get('comment_start')
expr_start = groups.get('expr_start')
block_type_name = groups.get('block_type_name')
comment_start = groups.get("comment_start")
expr_start = groups.get("expr_start")
block_type_name = groups.get("block_type_name")
if comment_start is not None:
self.handle_comment(match)
@@ -264,8 +266,8 @@ class TagIterator:
yield self.handle_tag(match)
else:
raise dbt.exceptions.InternalException(
'Invalid regex match in next_block, expected block start, '
'expr start, or comment start'
"Invalid regex match in next_block, expected block start, "
"expr start, or comment start"
)
def __iter__(self):
@@ -273,21 +275,18 @@ class TagIterator:
duplicate_tags = (
'Got nested tags: {outer.block_type_name} (started at {outer.start}) did '
'not have a matching {{% end{outer.block_type_name} %}} before a '
'subsequent {inner.block_type_name} was found (started at {inner.start})'
"Got nested tags: {outer.block_type_name} (started at {outer.start}) did "
"not have a matching {{% end{outer.block_type_name} %}} before a "
"subsequent {inner.block_type_name} was found (started at {inner.start})"
)
_CONTROL_FLOW_TAGS = {
'if': 'endif',
'for': 'endfor',
"if": "endif",
"for": "endfor",
}
_CONTROL_FLOW_END_TAGS = {
v: k
for k, v in _CONTROL_FLOW_TAGS.items()
}
_CONTROL_FLOW_END_TAGS = {v: k for k, v in _CONTROL_FLOW_TAGS.items()}
class BlockIterator:
@@ -310,15 +309,15 @@ class BlockIterator:
def is_current_end(self, tag):
return (
tag.block_type_name.startswith('end') and
self.current is not None and
tag.block_type_name[3:] == self.current.block_type_name
tag.block_type_name.startswith("end")
and self.current is not None
and tag.block_type_name[3:] == self.current.block_type_name
)
def find_blocks(self, allowed_blocks=None, collect_raw_data=True):
"""Find all top-level blocks in the data."""
if allowed_blocks is None:
allowed_blocks = {'snapshot', 'macro', 'materialization', 'docs'}
allowed_blocks = {"snapshot", "macro", "materialization", "docs"}
for tag in self.tag_parser.find_tags():
if tag.block_type_name in _CONTROL_FLOW_TAGS:
@@ -329,37 +328,43 @@ class BlockIterator:
found = self.stack.pop()
else:
expected = _CONTROL_FLOW_END_TAGS[tag.block_type_name]
dbt.exceptions.raise_compiler_error((
'Got an unexpected control flow end tag, got {} but '
'never saw a preceeding {} (@ {})'
).format(
tag.block_type_name,
expected,
self.tag_parser.linepos(tag.start)
))
dbt.exceptions.raise_compiler_error(
(
"Got an unexpected control flow end tag, got {} but "
"never saw a preceeding {} (@ {})"
).format(
tag.block_type_name,
expected,
self.tag_parser.linepos(tag.start),
)
)
expected = _CONTROL_FLOW_TAGS[found]
if expected != tag.block_type_name:
dbt.exceptions.raise_compiler_error((
'Got an unexpected control flow end tag, got {} but '
'expected {} next (@ {})'
).format(
tag.block_type_name,
expected,
self.tag_parser.linepos(tag.start)
))
dbt.exceptions.raise_compiler_error(
(
"Got an unexpected control flow end tag, got {} but "
"expected {} next (@ {})"
).format(
tag.block_type_name,
expected,
self.tag_parser.linepos(tag.start),
)
)
if tag.block_type_name in allowed_blocks:
if self.stack:
dbt.exceptions.raise_compiler_error((
'Got a block definition inside control flow at {}. '
'All dbt block definitions must be at the top level'
).format(self.tag_parser.linepos(tag.start)))
dbt.exceptions.raise_compiler_error(
(
"Got a block definition inside control flow at {}. "
"All dbt block definitions must be at the top level"
).format(self.tag_parser.linepos(tag.start))
)
if self.current is not None:
dbt.exceptions.raise_compiler_error(
duplicate_tags.format(outer=self.current, inner=tag)
)
if collect_raw_data:
raw_data = self.data[self.last_position:tag.start]
raw_data = self.data[self.last_position : tag.start]
self.last_position = tag.start
if raw_data:
yield BlockData(raw_data)
@@ -371,23 +376,28 @@ class BlockIterator:
yield BlockTag(
block_type_name=self.current.block_type_name,
block_name=self.current.block_name,
contents=self.data[self.current.end:tag.start],
full_block=self.data[self.current.start:tag.end]
contents=self.data[self.current.end : tag.start],
full_block=self.data[self.current.start : tag.end],
)
self.current = None
if self.current:
linecount = self.data[:self.current.end].count('\n') + 1
dbt.exceptions.raise_compiler_error((
'Reached EOF without finding a close tag for '
'{} (searched from line {})'
).format(self.current.block_type_name, linecount))
linecount = self.data[: self.current.end].count("\n") + 1
dbt.exceptions.raise_compiler_error(
(
"Reached EOF without finding a close tag for "
"{} (searched from line {})"
).format(self.current.block_type_name, linecount)
)
if collect_raw_data:
raw_data = self.data[self.last_position:]
raw_data = self.data[self.last_position :]
if raw_data:
yield BlockData(raw_data)
def lex_for_blocks(self, allowed_blocks=None, collect_raw_data=True):
return list(self.find_blocks(allowed_blocks=allowed_blocks,
collect_raw_data=collect_raw_data))
return list(
self.find_blocks(
allowed_blocks=allowed_blocks, collect_raw_data=collect_raw_data
)
)
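
The block lexer above is regex-driven: BLOCK_START_PATTERN pulls the block type name, and optionally a block name, out of a {% ... %} opener. Rebuilding an equivalent pattern standalone shows what the named groups capture (the sample template string is made up):

import re

NAME = r"[A-Za-z_][A-Za-z_0-9]*"
OPENER = r"(?:\s*\{\%\-|\{\%)\s*"

# same shape as BLOCK_START_PATTERN: opener, block type name, optional block name
BLOCK_START = re.compile(
    OPENER
    + r"(?P<block_type_name>(" + NAME + r"))"
    + r"(?:\s+(?P<block_name>(" + NAME + r")))?"
)

match = BLOCK_START.search("{% macro current_timestamp() %}")
print(match.group("block_type_name"), match.group("block_name"))
# -> macro current_timestamp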

View File

@@ -10,7 +10,7 @@ from typing import Iterable, List, Dict, Union, Optional, Any
from dbt.exceptions import RuntimeException
BOM = BOM_UTF8.decode('utf-8') # '\ufeff'
BOM = BOM_UTF8.decode("utf-8") # '\ufeff'
class ISODateTime(agate.data_types.DateTime):
@@ -30,28 +30,23 @@ class ISODateTime(agate.data_types.DateTime):
except: # noqa
pass
raise agate.exceptions.CastError(
'Can not parse value "%s" as datetime.' % d
)
raise agate.exceptions.CastError('Can not parse value "%s" as datetime.' % d)
def build_type_tester(text_columns: Iterable[str]) -> agate.TypeTester:
types = [
agate.data_types.Number(null_values=('null', '')),
agate.data_types.Date(null_values=('null', ''),
date_format='%Y-%m-%d'),
agate.data_types.DateTime(null_values=('null', ''),
datetime_format='%Y-%m-%d %H:%M:%S'),
ISODateTime(null_values=('null', '')),
agate.data_types.Boolean(true_values=('true',),
false_values=('false',),
null_values=('null', '')),
agate.data_types.Text(null_values=('null', ''))
agate.data_types.Number(null_values=("null", "")),
agate.data_types.Date(null_values=("null", ""), date_format="%Y-%m-%d"),
agate.data_types.DateTime(
null_values=("null", ""), datetime_format="%Y-%m-%d %H:%M:%S"
),
ISODateTime(null_values=("null", "")),
agate.data_types.Boolean(
true_values=("true",), false_values=("false",), null_values=("null", "")
),
agate.data_types.Text(null_values=("null", "")),
]
force = {
k: agate.data_types.Text(null_values=('null', ''))
for k in text_columns
}
force = {k: agate.data_types.Text(null_values=("null", "")) for k in text_columns}
return agate.TypeTester(force=force, types=types)
@@ -115,7 +110,7 @@ def as_matrix(table):
def from_csv(abspath, text_columns):
type_tester = build_type_tester(text_columns=text_columns)
with open(abspath, encoding='utf-8') as fp:
with open(abspath, encoding="utf-8") as fp:
if fp.read(1) != BOM:
fp.seek(0)
return agate.Table.from_csv(fp, column_types=type_tester)
@@ -147,8 +142,8 @@ class ColumnTypeBuilder(Dict[str, NullableAgateType]):
elif not isinstance(value, type(existing_type)):
# actual type mismatch!
raise RuntimeException(
f'Tables contain columns with the same names ({key}), '
f'but different types ({value} vs {existing_type})'
f"Tables contain columns with the same names ({key}), "
f"but different types ({value} vs {existing_type})"
)
def finalize(self) -> Dict[str, agate.data_types.DataType]:
@@ -163,7 +158,7 @@ class ColumnTypeBuilder(Dict[str, NullableAgateType]):
def _merged_column_types(
tables: List[agate.Table]
tables: List[agate.Table],
) -> Dict[str, agate.data_types.DataType]:
# this is a lot like agate.Table.merge, but with handling for all-null
# rows being "any type".
@@ -190,10 +185,7 @@ def merge_tables(tables: List[agate.Table]) -> agate.Table:
rows: List[agate.Row] = []
for table in tables:
if (
table.column_names == column_names and
table.column_types == column_types
):
if table.column_names == column_names and table.column_types == column_types:
rows.extend(table.rows)
else:
for row in table.rows:

View File

@@ -12,7 +12,7 @@ https://cloud.google.com/sdk/
def gcloud_installed():
try:
run_cmd('.', ['gcloud', '--version'])
run_cmd(".", ["gcloud", "--version"])
return True
except OSError as e:
logger.debug(e)
@@ -21,6 +21,6 @@ def gcloud_installed():
def setup_default_credentials():
if gcloud_installed():
run_cmd('.', ["gcloud", "auth", "application-default", "login"])
run_cmd(".", ["gcloud", "auth", "application-default", "login"])
else:
raise dbt.exceptions.RuntimeException(NOT_INSTALLED_MSG)

View File

@@ -7,77 +7,74 @@ import dbt.exceptions
def clone(repo, cwd, dirname=None, remove_git_dir=False, branch=None):
clone_cmd = ['git', 'clone', '--depth', '1']
clone_cmd = ["git", "clone", "--depth", "1"]
if branch is not None:
clone_cmd.extend(['--branch', branch])
clone_cmd.extend(["--branch", branch])
clone_cmd.append(repo)
if dirname is not None:
clone_cmd.append(dirname)
result = run_cmd(cwd, clone_cmd, env={'LC_ALL': 'C'})
result = run_cmd(cwd, clone_cmd, env={"LC_ALL": "C"})
if remove_git_dir:
rmdir(os.path.join(dirname, '.git'))
rmdir(os.path.join(dirname, ".git"))
return result
def list_tags(cwd):
out, err = run_cmd(cwd, ['git', 'tag', '--list'], env={'LC_ALL': 'C'})
tags = out.decode('utf-8').strip().split("\n")
out, err = run_cmd(cwd, ["git", "tag", "--list"], env={"LC_ALL": "C"})
tags = out.decode("utf-8").strip().split("\n")
return tags
def _checkout(cwd, repo, branch):
logger.debug(' Checking out branch {}.'.format(branch))
logger.debug(" Checking out branch {}.".format(branch))
run_cmd(cwd, ['git', 'remote', 'set-branches', 'origin', branch])
run_cmd(cwd, ['git', 'fetch', '--tags', '--depth', '1', 'origin', branch])
run_cmd(cwd, ["git", "remote", "set-branches", "origin", branch])
run_cmd(cwd, ["git", "fetch", "--tags", "--depth", "1", "origin", branch])
tags = list_tags(cwd)
# Prefer tags to branches if one exists
if branch in tags:
spec = 'tags/{}'.format(branch)
spec = "tags/{}".format(branch)
else:
spec = 'origin/{}'.format(branch)
spec = "origin/{}".format(branch)
out, err = run_cmd(cwd, ['git', 'reset', '--hard', spec],
env={'LC_ALL': 'C'})
out, err = run_cmd(cwd, ["git", "reset", "--hard", spec], env={"LC_ALL": "C"})
return out, err
def checkout(cwd, repo, branch=None):
if branch is None:
branch = 'HEAD'
branch = "HEAD"
try:
return _checkout(cwd, repo, branch)
except dbt.exceptions.CommandResultError as exc:
stderr = exc.stderr.decode('utf-8').strip()
stderr = exc.stderr.decode("utf-8").strip()
dbt.exceptions.bad_package_spec(repo, branch, stderr)
def get_current_sha(cwd):
out, err = run_cmd(cwd, ['git', 'rev-parse', 'HEAD'], env={'LC_ALL': 'C'})
out, err = run_cmd(cwd, ["git", "rev-parse", "HEAD"], env={"LC_ALL": "C"})
return out.decode('utf-8')
return out.decode("utf-8")
def remove_remote(cwd):
return run_cmd(cwd, ['git', 'remote', 'rm', 'origin'], env={'LC_ALL': 'C'})
return run_cmd(cwd, ["git", "remote", "rm", "origin"], env={"LC_ALL": "C"})
def clone_and_checkout(repo, cwd, dirname=None, remove_git_dir=False,
branch=None):
def clone_and_checkout(repo, cwd, dirname=None, remove_git_dir=False, branch=None):
exists = None
try:
_, err = clone(repo, cwd, dirname=dirname,
remove_git_dir=remove_git_dir)
_, err = clone(repo, cwd, dirname=dirname, remove_git_dir=remove_git_dir)
except dbt.exceptions.CommandResultError as exc:
err = exc.stderr.decode('utf-8')
err = exc.stderr.decode("utf-8")
exists = re.match("fatal: destination path '(.+)' already exists", err)
if not exists: # something else is wrong, raise it
raise
@@ -86,25 +83,26 @@ def clone_and_checkout(repo, cwd, dirname=None, remove_git_dir=False,
start_sha = None
if exists:
directory = exists.group(1)
logger.debug('Updating existing dependency {}.', directory)
logger.debug("Updating existing dependency {}.", directory)
else:
matches = re.match("Cloning into '(.+)'", err.decode('utf-8'))
matches = re.match("Cloning into '(.+)'", err.decode("utf-8"))
if matches is None:
raise dbt.exceptions.RuntimeException(
f'Error cloning {repo} - never saw "Cloning into ..." from git'
)
directory = matches.group(1)
logger.debug('Pulling new dependency {}.', directory)
logger.debug("Pulling new dependency {}.", directory)
full_path = os.path.join(cwd, directory)
start_sha = get_current_sha(full_path)
checkout(full_path, repo, branch)
end_sha = get_current_sha(full_path)
if exists:
if start_sha == end_sha:
logger.debug(' Already at {}, nothing to do.', start_sha[:7])
logger.debug(" Already at {}, nothing to do.", start_sha[:7])
else:
logger.debug(' Updated checkout from {} to {}.',
start_sha[:7], end_sha[:7])
logger.debug(
" Updated checkout from {} to {}.", start_sha[:7], end_sha[:7]
)
else:
logger.debug(' Checked out at {}.', end_sha[:7])
logger.debug(" Checked out at {}.", end_sha[:7])
return directory
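
_checkout prefers a tag over a branch of the same name: after a depth-1 fetch it lists tags and resets to tags/<name> when the requested revision is a tag, otherwise to origin/<name>. A small sketch of just that selection — checkout_spec is a hypothetical helper and no git commands are run:

def checkout_spec(requested: str, tags: list) -> str:
    # prefer tags to branches if one exists, mirroring _checkout above
    if requested in tags:
        return "tags/{}".format(requested)
    return "origin/{}".format(requested)


print(checkout_spec("0.19.1", ["0.19.0", "0.19.1"]))  # tags/0.19.1
print(checkout_spec("main", ["0.19.0", "0.19.1"]))    # origin/main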

View File

@@ -8,8 +8,17 @@ from ast import literal_eval
from contextlib import contextmanager
from itertools import chain, islice
from typing import (
List, Union, Set, Optional, Dict, Any, Iterator, Type, NoReturn, Tuple,
Callable
List,
Union,
Set,
Optional,
Dict,
Any,
Iterator,
Type,
NoReturn,
Tuple,
Callable,
)
import jinja2
@@ -20,16 +29,22 @@ import jinja2.parser
import jinja2.sandbox
from dbt.utils import (
get_dbt_macro_name, get_docs_macro_name, get_materialization_macro_name,
deep_map
get_dbt_macro_name,
get_docs_macro_name,
get_materialization_macro_name,
deep_map,
)
from dbt.clients._jinja_blocks import BlockIterator, BlockData, BlockTag
from dbt.contracts.graph.compiled import CompiledSchemaTestNode
from dbt.contracts.graph.parsed import ParsedSchemaTestNode
from dbt.exceptions import (
InternalException, raise_compiler_error, CompilationException,
invalid_materialization_argument, MacroReturn, JinjaRenderingException
InternalException,
raise_compiler_error,
CompilationException,
invalid_materialization_argument,
MacroReturn,
JinjaRenderingException,
)
from dbt import flags
from dbt.logger import GLOBAL_LOGGER as logger # noqa
@@ -40,26 +55,26 @@ def _linecache_inject(source, write):
# this is the only reliable way to accomplish this. Obviously, it's
# really darn noisy and will fill your temporary directory
tmp_file = tempfile.NamedTemporaryFile(
prefix='dbt-macro-compiled-',
suffix='.py',
prefix="dbt-macro-compiled-",
suffix=".py",
delete=False,
mode='w+',
encoding='utf-8',
mode="w+",
encoding="utf-8",
)
tmp_file.write(source)
filename = tmp_file.name
else:
# `codecs.encode` actually takes a `bytes` as the first argument if
# the second argument is 'hex' - mypy does not know this.
rnd = codecs.encode(os.urandom(12), 'hex') # type: ignore
filename = rnd.decode('ascii')
rnd = codecs.encode(os.urandom(12), "hex") # type: ignore
filename = rnd.decode("ascii")
# put ourselves in the cache
cache_entry = (
len(source),
None,
[line + '\n' for line in source.splitlines()],
filename
[line + "\n" for line in source.splitlines()],
filename,
)
# linecache does in fact have an attribute `cache`, thanks
linecache.cache[filename] = cache_entry # type: ignore
@@ -73,12 +88,10 @@ class MacroFuzzParser(jinja2.parser.Parser):
# modified to fuzz macros defined in the same file. this way
# dbt can understand the stack of macros being called.
# - @cmcarthur
node.name = get_dbt_macro_name(
self.parse_assign_target(name_only=True).name)
node.name = get_dbt_macro_name(self.parse_assign_target(name_only=True).name)
self.parse_signature(node)
node.body = self.parse_statements(('name:endmacro',),
drop_needle=True)
node.body = self.parse_statements(("name:endmacro",), drop_needle=True)
return node
@@ -94,8 +107,8 @@ class MacroFuzzEnvironment(jinja2.sandbox.SandboxedEnvironment):
If the value is 'write', also write the files to disk.
WARNING: This can write a ton of data if you aren't careful.
"""
if filename == '<template>' and flags.MACRO_DEBUGGING:
write = flags.MACRO_DEBUGGING == 'write'
if filename == "<template>" and flags.MACRO_DEBUGGING:
write = flags.MACRO_DEBUGGING == "write"
filename = _linecache_inject(source, write)
return super()._compile(source, filename) # type: ignore
@@ -138,7 +151,7 @@ def quoted_native_concat(nodes):
head = list(islice(nodes, 2))
if not head:
return ''
return ""
if len(head) == 1:
raw = head[0]
@@ -180,9 +193,7 @@ class NativeSandboxTemplate(jinja2.nativetypes.NativeTemplate): # mypy: ignore
vars = dict(*args, **kwargs)
try:
return quoted_native_concat(
self.root_render_func(self.new_context(vars))
)
return quoted_native_concat(self.root_render_func(self.new_context(vars)))
except Exception:
return self.environment.handle_exception()
@@ -221,10 +232,10 @@ class BaseMacroGenerator:
self.context: Optional[Dict[str, Any]] = context
def get_template(self):
raise NotImplementedError('get_template not implemented!')
raise NotImplementedError("get_template not implemented!")
def get_name(self) -> str:
raise NotImplementedError('get_name not implemented!')
raise NotImplementedError("get_name not implemented!")
def get_macro(self):
name = self.get_name()
@@ -247,9 +258,7 @@ class BaseMacroGenerator:
def call_macro(self, *args, **kwargs):
# called from __call__ methods
if self.context is None:
raise InternalException(
'Context is still None in call_macro!'
)
raise InternalException("Context is still None in call_macro!")
assert self.context is not None
macro = self.get_macro()
@@ -276,7 +285,7 @@ class MacroStack(threading.local):
def pop(self, name):
got = self.call_stack.pop()
if got != name:
raise InternalException(f'popped {got}, expected {name}')
raise InternalException(f"popped {got}, expected {name}")
class MacroGenerator(BaseMacroGenerator):
@@ -285,7 +294,7 @@ class MacroGenerator(BaseMacroGenerator):
macro,
context: Optional[Dict[str, Any]] = None,
node: Optional[Any] = None,
stack: Optional[MacroStack] = None
stack: Optional[MacroStack] = None,
) -> None:
super().__init__(context)
self.macro = macro
@@ -333,9 +342,7 @@ class MacroGenerator(BaseMacroGenerator):
class QueryStringGenerator(BaseMacroGenerator):
def __init__(
self, template_str: str, context: Dict[str, Any]
) -> None:
def __init__(self, template_str: str, context: Dict[str, Any]) -> None:
super().__init__(context)
self.template_str: str = template_str
env = get_environment()
@@ -345,7 +352,7 @@ class QueryStringGenerator(BaseMacroGenerator):
)
def get_name(self) -> str:
return 'query_comment_macro'
return "query_comment_macro"
def get_template(self):
"""Don't use the template cache, we don't have a node"""
@@ -356,45 +363,41 @@ class QueryStringGenerator(BaseMacroGenerator):
class MaterializationExtension(jinja2.ext.Extension):
tags = ['materialization']
tags = ["materialization"]
def parse(self, parser):
node = jinja2.nodes.Macro(lineno=next(parser.stream).lineno)
materialization_name = \
parser.parse_assign_target(name_only=True).name
materialization_name = parser.parse_assign_target(name_only=True).name
adapter_name = 'default'
adapter_name = "default"
node.args = []
node.defaults = []
while parser.stream.skip_if('comma'):
while parser.stream.skip_if("comma"):
target = parser.parse_assign_target(name_only=True)
if target.name == 'default':
if target.name == "default":
pass
elif target.name == 'adapter':
parser.stream.expect('assign')
elif target.name == "adapter":
parser.stream.expect("assign")
value = parser.parse_expression()
adapter_name = value.value
else:
invalid_materialization_argument(
materialization_name, target.name
)
invalid_materialization_argument(materialization_name, target.name)
node.name = get_materialization_macro_name(
materialization_name, adapter_name
node.name = get_materialization_macro_name(materialization_name, adapter_name)
node.body = parser.parse_statements(
("name:endmaterialization",), drop_needle=True
)
node.body = parser.parse_statements(('name:endmaterialization',),
drop_needle=True)
return node
class DocumentationExtension(jinja2.ext.Extension):
tags = ['docs']
tags = ["docs"]
def parse(self, parser):
node = jinja2.nodes.Macro(lineno=next(parser.stream).lineno)
@@ -403,13 +406,12 @@ class DocumentationExtension(jinja2.ext.Extension):
node.args = []
node.defaults = []
node.name = get_docs_macro_name(docs_name)
node.body = parser.parse_statements(('name:enddocs',),
drop_needle=True)
node.body = parser.parse_statements(("name:enddocs",), drop_needle=True)
return node
def _is_dunder_name(name):
return name.startswith('__') and name.endswith('__')
return name.startswith("__") and name.endswith("__")
def create_undefined(node=None):
@@ -430,10 +432,11 @@ def create_undefined(node=None):
return self
def __getattr__(self, name):
if name == 'name' or _is_dunder_name(name):
if name == "name" or _is_dunder_name(name):
raise AttributeError(
"'{}' object has no attribute '{}'"
.format(type(self).__name__, name)
"'{}' object has no attribute '{}'".format(
type(self).__name__, name
)
)
self.name = name
@@ -444,24 +447,24 @@ def create_undefined(node=None):
return self
def __reduce__(self):
raise_compiler_error(f'{self.name} is undefined', node=node)
raise_compiler_error(f"{self.name} is undefined", node=node)
return Undefined
NATIVE_FILTERS: Dict[str, Callable[[Any], Any]] = {
'as_text': TextMarker,
'as_bool': BoolMarker,
'as_native': NativeMarker,
'as_number': NumberMarker,
"as_text": TextMarker,
"as_bool": BoolMarker,
"as_native": NativeMarker,
"as_number": NumberMarker,
}
TEXT_FILTERS: Dict[str, Callable[[Any], Any]] = {
'as_text': lambda x: x,
'as_bool': lambda x: x,
'as_native': lambda x: x,
'as_number': lambda x: x,
"as_text": lambda x: x,
"as_bool": lambda x: x,
"as_native": lambda x: x,
"as_number": lambda x: x,
}
@@ -471,14 +474,14 @@ def get_environment(
native: bool = False,
) -> jinja2.Environment:
args: Dict[str, List[Union[str, Type[jinja2.ext.Extension]]]] = {
'extensions': ['jinja2.ext.do']
"extensions": ["jinja2.ext.do"]
}
if capture_macros:
args['undefined'] = create_undefined(node)
args["undefined"] = create_undefined(node)
args['extensions'].append(MaterializationExtension)
args['extensions'].append(DocumentationExtension)
args["extensions"].append(MaterializationExtension)
args["extensions"].append(DocumentationExtension)
env_cls: Type[jinja2.Environment]
text_filter: Type
@@ -541,8 +544,8 @@ def _requote_result(raw_value: str, rendered: str) -> str:
elif single_quoted:
quote_char = "'"
else:
quote_char = ''
return f'{quote_char}{rendered}{quote_char}'
quote_char = ""
return f"{quote_char}{rendered}{quote_char}"
# performance note: Local benchmarking (so take it with a big grain of salt!)
@@ -550,7 +553,7 @@ def _requote_result(raw_value: str, rendered: str) -> str:
# checking two separate patterns, but the standard deviation is smaller with
# one pattern. The time difference between the two was ~2 std deviations, which
# is small enough that I've just chosen the more readable option.
_HAS_RENDER_CHARS_PAT = re.compile(r'({[{%#]|[#}%]})')
_HAS_RENDER_CHARS_PAT = re.compile(r"({[{%#]|[#}%]})")
def get_rendered(
@@ -567,9 +570,9 @@ def get_rendered(
# native=True case by passing the input string to ast.literal_eval, like
# the native renderer does.
if (
not native and
isinstance(string, str) and
_HAS_RENDER_CHARS_PAT.search(string) is None
not native
and isinstance(string, str)
and _HAS_RENDER_CHARS_PAT.search(string) is None
):
return string
template = get_template(
@@ -606,12 +609,11 @@ def extract_toplevel_blocks(
`collect_raw_data` is `True`) `BlockData` objects.
"""
return BlockIterator(data).lex_for_blocks(
allowed_blocks=allowed_blocks,
collect_raw_data=collect_raw_data
allowed_blocks=allowed_blocks, collect_raw_data=collect_raw_data
)
SCHEMA_TEST_KWARGS_NAME = '_dbt_schema_test_kwargs'
SCHEMA_TEST_KWARGS_NAME = "_dbt_schema_test_kwargs"
def add_rendered_test_kwargs(
@@ -623,24 +625,21 @@ def add_rendered_test_kwargs(
renderer, then insert that value into the given context as the special test
keyword arguments member.
"""
looks_like_func = r'^\s*(env_var|ref|var|source|doc)\s*\(.+\)\s*$'
looks_like_func = r"^\s*(env_var|ref|var|source|doc)\s*\(.+\)\s*$"
def _convert_function(
value: Any, keypath: Tuple[Union[str, int], ...]
) -> Any:
def _convert_function(value: Any, keypath: Tuple[Union[str, int], ...]) -> Any:
if isinstance(value, str):
if keypath == ('column_name',):
if keypath == ("column_name",):
# special case: Don't render column names as native, make them
# be strings
return value
if re.match(looks_like_func, value) is not None:
# curly braces to make rendering happy
value = f'{{{{ {value} }}}}'
value = f"{{{{ {value} }}}}"
value = get_rendered(
value, context, node, capture_macros=capture_macros,
native=True
value, context, node, capture_macros=capture_macros, native=True
)
return value
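A minimal standalone sketch of the "looks like a function call" wrapping visible in the hunk above, assuming only the regex and the curly-brace wrapping shown there; `maybe_wrap` is a hypothetical helper for illustration, not a dbt function, and no rendering is performed here.

```python
import re

# Values that look like ref()/var()/source()/doc()/env_var() calls get wrapped
# in {{ ... }} so a renderer would evaluate them; plain strings pass through.
looks_like_func = r"^\s*(env_var|ref|var|source|doc)\s*\(.+\)\s*$"


def maybe_wrap(value: str) -> str:  # hypothetical helper, for illustration only
    if re.match(looks_like_func, value) is not None:
        return f"{{{{ {value} }}}}"  # curly braces to make rendering happy
    return value


assert maybe_wrap("ref('orders')") == "{{ ref('orders') }}"
assert maybe_wrap("not a call") == "not a call"
```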

View File

@@ -6,17 +6,17 @@ from dbt.logger import GLOBAL_LOGGER as logger
import os
import time
if os.getenv('DBT_PACKAGE_HUB_URL'):
DEFAULT_REGISTRY_BASE_URL = os.getenv('DBT_PACKAGE_HUB_URL')
if os.getenv("DBT_PACKAGE_HUB_URL"):
DEFAULT_REGISTRY_BASE_URL = os.getenv("DBT_PACKAGE_HUB_URL")
else:
DEFAULT_REGISTRY_BASE_URL = 'https://hub.getdbt.com/'
DEFAULT_REGISTRY_BASE_URL = "https://hub.getdbt.com/"
def _get_url(url, registry_base_url=None):
if registry_base_url is None:
registry_base_url = DEFAULT_REGISTRY_BASE_URL
return '{}{}'.format(registry_base_url, url)
return "{}{}".format(registry_base_url, url)
def _wrap_exceptions(fn):
@@ -33,42 +33,40 @@ def _wrap_exceptions(fn):
time.sleep(1)
continue
raise RegistryException(
'Unable to connect to registry hub'
) from exc
raise RegistryException("Unable to connect to registry hub") from exc
return wrapper
@_wrap_exceptions
def _get(path, registry_base_url=None):
url = _get_url(path, registry_base_url)
logger.debug('Making package registry request: GET {}'.format(url))
logger.debug("Making package registry request: GET {}".format(url))
resp = requests.get(url)
logger.debug('Response from registry: GET {} {}'.format(url,
resp.status_code))
logger.debug("Response from registry: GET {} {}".format(url, resp.status_code))
resp.raise_for_status()
return resp.json()
def index(registry_base_url=None):
return _get('api/v1/index.json', registry_base_url)
return _get("api/v1/index.json", registry_base_url)
index_cached = memoized(index)
def packages(registry_base_url=None):
return _get('api/v1/packages.json', registry_base_url)
return _get("api/v1/packages.json", registry_base_url)
def package(name, registry_base_url=None):
return _get('api/v1/{}.json'.format(name), registry_base_url)
return _get("api/v1/{}.json".format(name), registry_base_url)
def package_version(name, version, registry_base_url=None):
return _get('api/v1/{}/{}.json'.format(name, version), registry_base_url)
return _get("api/v1/{}/{}.json".format(name, version), registry_base_url)
def get_available_versions(name):
response = package(name)
return list(response['versions'])
return list(response["versions"])
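A short sketch of the URL construction shown in this registry module: the base URL can be overridden with `DBT_PACKAGE_HUB_URL`, otherwise the public hub is used, and endpoint paths are appended verbatim. This is illustrative only and makes no network call; `get_url` is a stand-in name.

```python
import os

# Base URL precedence: environment override first, then the public hub.
DEFAULT_REGISTRY_BASE_URL = os.getenv("DBT_PACKAGE_HUB_URL") or "https://hub.getdbt.com/"


def get_url(path: str, registry_base_url: str = DEFAULT_REGISTRY_BASE_URL) -> str:
    # Paths are simply concatenated onto the base URL.
    return "{}{}".format(registry_base_url, path)


print(get_url("api/v1/index.json"))
# -> https://hub.getdbt.com/api/v1/index.json (unless DBT_PACKAGE_HUB_URL is set)
```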

View File

@@ -10,16 +10,14 @@ import sys
import tarfile
import requests
import stat
from typing import (
Type, NoReturn, List, Optional, Dict, Any, Tuple, Callable, Union
)
from typing import Type, NoReturn, List, Optional, Dict, Any, Tuple, Callable, Union
import dbt.exceptions
import dbt.utils
from dbt.logger import GLOBAL_LOGGER as logger
if sys.platform == 'win32':
if sys.platform == "win32":
from ctypes import WinDLL, c_bool
else:
WinDLL = None
@@ -51,30 +49,29 @@ def find_matching(
reobj = re.compile(regex, re.IGNORECASE)
for relative_path_to_search in relative_paths_to_search:
absolute_path_to_search = os.path.join(
root_path, relative_path_to_search)
absolute_path_to_search = os.path.join(root_path, relative_path_to_search)
walk_results = os.walk(absolute_path_to_search)
for current_path, subdirectories, local_files in walk_results:
for local_file in local_files:
absolute_path = os.path.join(current_path, local_file)
relative_path = os.path.relpath(
absolute_path, absolute_path_to_search
)
relative_path = os.path.relpath(absolute_path, absolute_path_to_search)
if reobj.match(local_file):
matching.append({
'searched_path': relative_path_to_search,
'absolute_path': absolute_path,
'relative_path': relative_path,
})
matching.append(
{
"searched_path": relative_path_to_search,
"absolute_path": absolute_path,
"relative_path": relative_path,
}
)
return matching
def load_file_contents(path: str, strip: bool = True) -> str:
path = convert_path(path)
with open(path, 'rb') as handle:
to_return = handle.read().decode('utf-8')
with open(path, "rb") as handle:
to_return = handle.read().decode("utf-8")
if strip:
to_return = to_return.strip()
@@ -101,14 +98,14 @@ def make_directory(path: str) -> None:
raise e
def make_file(path: str, contents: str = '', overwrite: bool = False) -> bool:
def make_file(path: str, contents: str = "", overwrite: bool = False) -> bool:
"""
Make a file at `path` assuming that the directory it resides in already
exists. The file is saved with contents `contents`
"""
if overwrite or not os.path.exists(path):
path = convert_path(path)
with open(path, 'w') as fh:
with open(path, "w") as fh:
fh.write(contents)
return True
@@ -120,7 +117,7 @@ def make_symlink(source: str, link_path: str) -> None:
Create a symlink at `link_path` referring to `source`.
"""
if not supports_symlinks():
dbt.exceptions.system_error('create a symbolic link')
dbt.exceptions.system_error("create a symbolic link")
os.symlink(source, link_path)
@@ -129,11 +126,11 @@ def supports_symlinks() -> bool:
return getattr(os, "symlink", None) is not None
def write_file(path: str, contents: str = '') -> bool:
def write_file(path: str, contents: str = "") -> bool:
path = convert_path(path)
try:
make_directory(os.path.dirname(path))
with open(path, 'w', encoding='utf-8') as f:
with open(path, "w", encoding="utf-8") as f:
f.write(str(contents))
except Exception as exc:
# note that you can't just catch FileNotFound, because sometimes
@@ -142,20 +139,20 @@ def write_file(path: str, contents: str = '') -> bool:
# sometimes windows fails to write paths that are less than the length
# limit. So on windows, suppress all errors that happen from writing
# to disk.
if os.name == 'nt':
if os.name == "nt":
# sometimes we get a winerror of 3 which means the path was
# definitely too long, but other times we don't and it means the
# path was just probably too long. This is probably based on the
# windows/python version.
if getattr(exc, 'winerror', 0) == 3:
reason = 'Path was too long'
if getattr(exc, "winerror", 0) == 3:
reason = "Path was too long"
else:
reason = 'Path was possibly too long'
reason = "Path was possibly too long"
# all our hard work and the path was still too long. Log and
# continue.
logger.debug(
f'Could not write to path {path}({len(path)} characters): '
f'{reason}\nexception: {exc}'
f"Could not write to path {path}({len(path)} characters): "
f"{reason}\nexception: {exc}"
)
else:
raise
@@ -189,10 +186,7 @@ def resolve_path_from_base(path_to_resolve: str, base_path: str) -> str:
If path_to_resolve is an absolute path or a user path (~), just
resolve it to an absolute path and return.
"""
return os.path.abspath(
os.path.join(
base_path,
os.path.expanduser(path_to_resolve)))
return os.path.abspath(os.path.join(base_path, os.path.expanduser(path_to_resolve)))
def rmdir(path: str) -> None:
@@ -202,7 +196,7 @@ def rmdir(path: str) -> None:
cloned via git) can cause rmtree to throw a PermissionError exception
"""
path = convert_path(path)
if sys.platform == 'win32':
if sys.platform == "win32":
onerror = _windows_rmdir_readonly
else:
onerror = None
@@ -221,7 +215,7 @@ def _win_prepare_path(path: str) -> str:
# letter back in.
# Unless it starts with '\\'. In that case, the path is a UNC mount point
# and splitdrive will be fine.
if not path.startswith('\\\\') and path.startswith('\\'):
if not path.startswith("\\\\") and path.startswith("\\"):
curdrive = os.path.splitdrive(os.getcwd())[0]
path = curdrive + path
@@ -236,7 +230,7 @@ def _win_prepare_path(path: str) -> str:
def _supports_long_paths() -> bool:
if sys.platform != 'win32':
if sys.platform != "win32":
return True
# Eryk Sun says to use `WinDLL('ntdll')` instead of `windll.ntdll` because
# of pointer caching in a comment here:
@@ -244,11 +238,11 @@ def _supports_long_paths() -> bool:
# I don't know exactly what he means, but I am inclined to believe him as
# he's pretty active on Python windows bugs!
try:
dll = WinDLL('ntdll')
dll = WinDLL("ntdll")
except OSError: # I don't think this happens? you need ntdll to run python
return False
# not all windows versions have it at all
if not hasattr(dll, 'RtlAreLongPathsEnabled'):
if not hasattr(dll, "RtlAreLongPathsEnabled"):
return False
# tell windows we want to get back a single unsigned byte (a bool).
dll.RtlAreLongPathsEnabled.restype = c_bool
@@ -268,7 +262,7 @@ def convert_path(path: str) -> str:
if _supports_long_paths():
return path
prefix = '\\\\?\\'
prefix = "\\\\?\\"
# Nothing to do
if path.startswith(prefix):
return path
@@ -299,39 +293,35 @@ def path_is_symlink(path: str) -> bool:
def open_dir_cmd() -> str:
# https://docs.python.org/2/library/sys.html#sys.platform
if sys.platform == 'win32':
return 'start'
if sys.platform == "win32":
return "start"
elif sys.platform == 'darwin':
return 'open'
elif sys.platform == "darwin":
return "open"
else:
return 'xdg-open'
return "xdg-open"
def _handle_posix_cwd_error(
exc: OSError, cwd: str, cmd: List[str]
) -> NoReturn:
def _handle_posix_cwd_error(exc: OSError, cwd: str, cmd: List[str]) -> NoReturn:
if exc.errno == errno.ENOENT:
message = 'Directory does not exist'
message = "Directory does not exist"
elif exc.errno == errno.EACCES:
message = 'Current user cannot access directory, check permissions'
message = "Current user cannot access directory, check permissions"
elif exc.errno == errno.ENOTDIR:
message = 'Not a directory'
message = "Not a directory"
else:
message = 'Unknown OSError: {} - cwd'.format(str(exc))
message = "Unknown OSError: {} - cwd".format(str(exc))
raise dbt.exceptions.WorkingDirectoryError(cwd, cmd, message)
def _handle_posix_cmd_error(
exc: OSError, cwd: str, cmd: List[str]
) -> NoReturn:
def _handle_posix_cmd_error(exc: OSError, cwd: str, cmd: List[str]) -> NoReturn:
if exc.errno == errno.ENOENT:
message = "Could not find command, ensure it is in the user's PATH"
elif exc.errno == errno.EACCES:
message = 'User does not have permissions for this command'
message = "User does not have permissions for this command"
else:
message = 'Unknown OSError: {} - cmd'.format(str(exc))
message = "Unknown OSError: {} - cmd".format(str(exc))
raise dbt.exceptions.ExecutableError(cwd, cmd, message)
@@ -356,7 +346,7 @@ def _handle_posix_error(exc: OSError, cwd: str, cmd: List[str]) -> NoReturn:
- exc.errno == EACCES
- exc.filename == None(?)
"""
if getattr(exc, 'filename', None) == cwd:
if getattr(exc, "filename", None) == cwd:
_handle_posix_cwd_error(exc, cwd, cmd)
else:
_handle_posix_cmd_error(exc, cwd, cmd)
@@ -365,46 +355,48 @@ def _handle_posix_error(exc: OSError, cwd: str, cmd: List[str]) -> NoReturn:
def _handle_windows_error(exc: OSError, cwd: str, cmd: List[str]) -> NoReturn:
cls: Type[dbt.exceptions.Exception] = dbt.exceptions.CommandError
if exc.errno == errno.ENOENT:
message = ("Could not find command, ensure it is in the user's PATH "
"and that the user has permissions to run it")
message = (
"Could not find command, ensure it is in the user's PATH "
"and that the user has permissions to run it"
)
cls = dbt.exceptions.ExecutableError
elif exc.errno == errno.ENOEXEC:
message = ('Command was not executable, ensure it is valid')
message = "Command was not executable, ensure it is valid"
cls = dbt.exceptions.ExecutableError
elif exc.errno == errno.ENOTDIR:
message = ('Unable to cd: path does not exist, user does not have'
' permissions, or not a directory')
message = (
"Unable to cd: path does not exist, user does not have"
" permissions, or not a directory"
)
cls = dbt.exceptions.WorkingDirectoryError
else:
message = 'Unknown error: {} (errno={}: "{}")'.format(
str(exc), exc.errno, errno.errorcode.get(exc.errno, '<Unknown!>')
str(exc), exc.errno, errno.errorcode.get(exc.errno, "<Unknown!>")
)
raise cls(cwd, cmd, message)
def _interpret_oserror(exc: OSError, cwd: str, cmd: List[str]) -> NoReturn:
"""Interpret an OSError exc and raise the appropriate dbt exception.
"""
"""Interpret an OSError exc and raise the appropriate dbt exception."""
if len(cmd) == 0:
raise dbt.exceptions.CommandError(cwd, cmd)
# all of these functions raise unconditionally
if os.name == 'nt':
if os.name == "nt":
_handle_windows_error(exc, cwd, cmd)
else:
_handle_posix_error(exc, cwd, cmd)
# this should not be reachable, raise _something_ at least!
raise dbt.exceptions.InternalException(
'Unhandled exception in _interpret_oserror: {}'.format(exc)
"Unhandled exception in _interpret_oserror: {}".format(exc)
)
def run_cmd(
cwd: str, cmd: List[str], env: Optional[Dict[str, Any]] = None
) -> Tuple[bytes, bytes]:
logger.debug('Executing "{}"'.format(' '.join(cmd)))
logger.debug('Executing "{}"'.format(" ".join(cmd)))
if len(cmd) == 0:
raise dbt.exceptions.CommandError(cwd, cmd)
@@ -417,11 +409,8 @@ def run_cmd(
try:
proc = subprocess.Popen(
cmd,
cwd=cwd,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
env=full_env)
cmd, cwd=cwd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=full_env
)
out, err = proc.communicate()
except OSError as exc:
@@ -431,9 +420,8 @@ def run_cmd(
logger.debug('STDERR: "{!s}"'.format(err))
if proc.returncode != 0:
logger.debug('command return code={}'.format(proc.returncode))
raise dbt.exceptions.CommandResultError(cwd, cmd, proc.returncode,
out, err)
logger.debug("command return code={}".format(proc.returncode))
raise dbt.exceptions.CommandResultError(cwd, cmd, proc.returncode, out, err)
return out, err
@@ -442,9 +430,9 @@ def download(
url: str, path: str, timeout: Optional[Union[float, tuple]] = None
) -> None:
path = convert_path(path)
connection_timeout = timeout or float(os.getenv('DBT_HTTP_TIMEOUT', 10))
connection_timeout = timeout or float(os.getenv("DBT_HTTP_TIMEOUT", 10))
response = requests.get(url, timeout=connection_timeout)
with open(path, 'wb') as handle:
with open(path, "wb") as handle:
for block in response.iter_content(1024 * 64):
handle.write(block)
@@ -468,7 +456,7 @@ def untar_package(
) -> None:
tar_path = convert_path(tar_path)
tar_dir_name = None
with tarfile.open(tar_path, 'r') as tarball:
with tarfile.open(tar_path, "r") as tarball:
tarball.extractall(dest_dir)
tar_dir_name = os.path.commonprefix(tarball.getnames())
if rename_to:
@@ -484,7 +472,7 @@ def chmod_and_retry(func, path, exc_info):
We want to retry most operations here, but listdir is one that we know will
be useless.
"""
if func is os.listdir or os.name != 'nt':
if func is os.listdir or os.name != "nt":
raise
os.chmod(path, stat.S_IREAD | stat.S_IWRITE)
# on error, this will raise.
@@ -505,7 +493,7 @@ def move(src, dst):
"""
src = convert_path(src)
dst = convert_path(dst)
if os.name != 'nt':
if os.name != "nt":
return shutil.move(src, dst)
if os.path.isdir(dst):
@@ -513,7 +501,7 @@ def move(src, dst):
os.rename(src, dst)
return
dst = os.path.join(dst, os.path.basename(src.rstrip('/\\')))
dst = os.path.join(dst, os.path.basename(src.rstrip("/\\")))
if os.path.exists(dst):
raise EnvironmentError("Path '{}' already exists".format(dst))
@@ -522,11 +510,10 @@ def move(src, dst):
except OSError:
# probably different drives
if os.path.isdir(src):
if _absnorm(dst + '\\').startswith(_absnorm(src + '\\')):
if _absnorm(dst + "\\").startswith(_absnorm(src + "\\")):
# dst is inside src
raise EnvironmentError(
"Cannot move a directory '{}' into itself '{}'"
.format(src, dst)
"Cannot move a directory '{}' into itself '{}'".format(src, dst)
)
shutil.copytree(src, dst, symlinks=True)
rmtree(src)
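The Windows-specific hunks above are easier to follow with a toy version of the long-path handling. The sketch below only reproduces the `\\?\` prefixing step; the boolean flag is a stand-in for the `WinDLL("ntdll").RtlAreLongPathsEnabled()` probe, and `convert_path_sketch` is not dbt's actual `convert_path`.

```python
# The "\\?\" extended-length prefix, escaped as a Python string literal.
LONG_PATH_PREFIX = "\\\\?\\"


def convert_path_sketch(path: str, supports_long_paths: bool) -> str:
    if supports_long_paths:
        return path  # long paths enabled (or not on Windows): nothing to do
    if path.startswith(LONG_PATH_PREFIX):
        return path  # already prefixed
    return LONG_PATH_PREFIX + path


print(convert_path_sketch(r"C:\some\very\deep\path", supports_long_paths=False))
# -> \\?\C:\some\very\deep\path
```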

View File

@@ -5,15 +5,9 @@ import yaml.scanner
# the C version is faster, but it doesn't always exist
try:
from yaml import (
CLoader as Loader,
CSafeLoader as SafeLoader,
CDumper as Dumper
)
from yaml import CLoader as Loader, CSafeLoader as SafeLoader, CDumper as Dumper
except ImportError:
from yaml import ( # type: ignore # noqa: F401
Loader, SafeLoader, Dumper
)
from yaml import Loader, SafeLoader, Dumper # type: ignore # noqa: F401
YAML_ERROR_MESSAGE = """
@@ -33,14 +27,14 @@ def line_no(i, line, width=3):
def prefix_with_line_numbers(string, no_start, no_end):
line_list = string.split('\n')
line_list = string.split("\n")
numbers = range(no_start, no_end)
relevant_lines = line_list[no_start:no_end]
return "\n".join([
line_no(i + 1, line) for (i, line) in zip(numbers, relevant_lines)
])
return "\n".join(
[line_no(i + 1, line) for (i, line) in zip(numbers, relevant_lines)]
)
def contextualized_yaml_error(raw_contents, error):
@@ -51,9 +45,9 @@ def contextualized_yaml_error(raw_contents, error):
nice_error = prefix_with_line_numbers(raw_contents, min_line, max_line)
return YAML_ERROR_MESSAGE.format(line_number=mark.line + 1,
nice_error=nice_error,
raw_error=error)
return YAML_ERROR_MESSAGE.format(
line_number=mark.line + 1, nice_error=nice_error, raw_error=error
)
def safe_load(contents):
@@ -64,7 +58,7 @@ def load_yaml_text(contents):
try:
return safe_load(contents)
except (yaml.scanner.ScannerError, yaml.YAMLError) as e:
if hasattr(e, 'problem_mark'):
if hasattr(e, "problem_mark"):
error = contextualized_yaml_error(contents, e)
else:
error = str(e)
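To make the YAML error-context hunk above concrete, here is a self-contained sketch. Only `prefix_with_line_numbers` mirrors the diff; `line_no` is a guessed stand-in, since its body sits outside the hunk.

```python
def line_no(i, line, width=3):
    # Stand-in for the helper defined above this hunk: pad the line number,
    # then append the raw line.
    return "{}| {}".format(str(i).ljust(width), line)


def prefix_with_line_numbers(string, no_start, no_end):
    # Keep a window of lines around the problem mark and number them 1-based.
    line_list = string.split("\n")
    numbers = range(no_start, no_end)
    relevant_lines = line_list[no_start:no_end]
    return "\n".join(
        [line_no(i + 1, line) for (i, line) in zip(numbers, relevant_lines)]
    )


print(prefix_with_line_numbers("models:\n  - name: orders\n    columns: [", 0, 3))
```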

View File

@@ -32,28 +32,28 @@ from dbt.node_types import NodeType
from dbt.utils import pluralize
import dbt.tracking
graph_file_name = 'graph.gpickle'
graph_file_name = "graph.gpickle"
def _compiled_type_for(model: ParsedNode):
if type(model) not in COMPILED_TYPES:
raise InternalException(
f'Asked to compile {type(model)} node, but it has no compiled form'
f"Asked to compile {type(model)} node, but it has no compiled form"
)
return COMPILED_TYPES[type(model)]
def print_compile_stats(stats):
names = {
NodeType.Model: 'model',
NodeType.Test: 'test',
NodeType.Snapshot: 'snapshot',
NodeType.Analysis: 'analysis',
NodeType.Macro: 'macro',
NodeType.Operation: 'operation',
NodeType.Seed: 'seed file',
NodeType.Source: 'source',
NodeType.Exposure: 'exposure',
NodeType.Model: "model",
NodeType.Test: "test",
NodeType.Snapshot: "snapshot",
NodeType.Analysis: "analysis",
NodeType.Macro: "macro",
NodeType.Operation: "operation",
NodeType.Seed: "seed file",
NodeType.Source: "source",
NodeType.Exposure: "exposure",
}
results = {k: 0 for k in names.keys()}
@@ -64,10 +64,9 @@ def print_compile_stats(stats):
resource_counts = {k.pluralize(): v for k, v in results.items()}
dbt.tracking.track_resource_counts(resource_counts)
stat_line = ", ".join([
pluralize(ct, names.get(t)) for t, ct in results.items()
if t in names
])
stat_line = ", ".join(
[pluralize(ct, names.get(t)) for t, ct in results.items() if t in names]
)
logger.info("Found {}".format(stat_line))
@@ -166,9 +165,7 @@ class Compiler:
extra_context: Dict[str, Any],
) -> Dict[str, Any]:
context = generate_runtime_model(
node, self.config, manifest
)
context = generate_runtime_model(node, self.config, manifest)
context.update(extra_context)
if isinstance(node, CompiledSchemaTestNode):
# for test nodes, add a special keyword args value to the context
@@ -183,8 +180,7 @@ class Compiler:
def _get_relation_name(self, node: ParsedNode):
relation_name = None
if (node.resource_type in NodeType.refable() and
not node.is_ephemeral_model):
if node.resource_type in NodeType.refable() and not node.is_ephemeral_model:
adapter = get_adapter(self.config)
relation_cls = adapter.Relation
relation_name = str(relation_cls.create_from(self.config, node))
@@ -227,32 +223,29 @@ class Compiler:
with_stmt = None
for token in parsed.tokens:
if token.is_keyword and token.normalized == 'WITH':
if token.is_keyword and token.normalized == "WITH":
with_stmt = token
break
if with_stmt is None:
# no with stmt, add one, and inject CTEs right at the beginning
first_token = parsed.token_first()
with_stmt = sqlparse.sql.Token(sqlparse.tokens.Keyword, 'with')
with_stmt = sqlparse.sql.Token(sqlparse.tokens.Keyword, "with")
parsed.insert_before(first_token, with_stmt)
else:
# stmt exists, add a comma (which will come after injected CTEs)
trailing_comma = sqlparse.sql.Token(
sqlparse.tokens.Punctuation, ','
)
trailing_comma = sqlparse.sql.Token(sqlparse.tokens.Punctuation, ",")
parsed.insert_after(with_stmt, trailing_comma)
token = sqlparse.sql.Token(
sqlparse.tokens.Keyword,
", ".join(c.sql for c in ctes)
sqlparse.tokens.Keyword, ", ".join(c.sql for c in ctes)
)
parsed.insert_after(with_stmt, token)
return str(parsed)
def _get_dbt_test_name(self) -> str:
return 'dbt__cte__internal_test'
return "dbt__cte__internal_test"
# This method is called by the 'compile_node' method. Starting
# from the node that it is passed in, it will recursively call
@@ -268,9 +261,7 @@ class Compiler:
) -> Tuple[NonSourceCompiledNode, List[InjectedCTE]]:
if model.compiled_sql is None:
raise RuntimeException(
'Cannot inject ctes into an unparsed node', model
)
raise RuntimeException("Cannot inject ctes into an unparsed node", model)
if model.extra_ctes_injected:
return (model, model.extra_ctes)
@@ -296,19 +287,18 @@ class Compiler:
else:
if cte.id not in manifest.nodes:
raise InternalException(
f'During compilation, found a cte reference that '
f'could not be resolved: {cte.id}'
f"During compilation, found a cte reference that "
f"could not be resolved: {cte.id}"
)
cte_model = manifest.nodes[cte.id]
if not cte_model.is_ephemeral_model:
raise InternalException(f'{cte.id} is not ephemeral')
raise InternalException(f"{cte.id} is not ephemeral")
# This model has already been compiled, so it's been
# through here before
if getattr(cte_model, 'compiled', False):
assert isinstance(cte_model,
tuple(COMPILED_TYPES.values()))
if getattr(cte_model, "compiled", False):
assert isinstance(cte_model, tuple(COMPILED_TYPES.values()))
cte_model = cast(NonSourceCompiledNode, cte_model)
new_prepended_ctes = cte_model.extra_ctes
@@ -316,13 +306,11 @@ class Compiler:
else:
# This is an ephemeral parsed model that we can compile.
# Compile and update the node
cte_model = self._compile_node(
cte_model, manifest, extra_context)
cte_model = self._compile_node(cte_model, manifest, extra_context)
# recursively call this method
cte_model, new_prepended_ctes = \
self._recursively_prepend_ctes(
cte_model, manifest, extra_context
)
cte_model, new_prepended_ctes = self._recursively_prepend_ctes(
cte_model, manifest, extra_context
)
# Save compiled SQL file and sync manifest
self._write_node(cte_model)
manifest.sync_update_node(cte_model)
@@ -330,7 +318,7 @@ class Compiler:
_extend_prepended_ctes(prepended_ctes, new_prepended_ctes)
new_cte_name = self.add_ephemeral_prefix(cte_model.name)
sql = f' {new_cte_name} as (\n{cte_model.compiled_sql}\n)'
sql = f" {new_cte_name} as (\n{cte_model.compiled_sql}\n)"
_add_prepended_cte(prepended_ctes, InjectedCTE(id=cte.id, sql=sql))
@@ -371,11 +359,10 @@ class Compiler:
# compiled_sql, and do the regular prepend logic from CTEs.
name = self._get_dbt_test_name()
cte = InjectedCTE(
id=name,
sql=f' {name} as (\n{compiled_node.compiled_sql}\n)'
id=name, sql=f" {name} as (\n{compiled_node.compiled_sql}\n)"
)
compiled_node.extra_ctes.append(cte)
compiled_node.compiled_sql = f'\nselect count(*) from {name}'
compiled_node.compiled_sql = f"\nselect count(*) from {name}"
return compiled_node
@@ -395,17 +382,17 @@ class Compiler:
logger.debug("Compiling {}".format(node.unique_id))
data = node.to_dict(omit_none=True)
data.update({
'compiled': False,
'compiled_sql': None,
'extra_ctes_injected': False,
'extra_ctes': [],
})
data.update(
{
"compiled": False,
"compiled_sql": None,
"extra_ctes_injected": False,
"extra_ctes": [],
}
)
compiled_node = _compiled_type_for(node).from_dict(data)
context = self._create_node_context(
compiled_node, manifest, extra_context
)
context = self._create_node_context(compiled_node, manifest, extra_context)
compiled_node.compiled_sql = jinja.get_rendered(
node.raw_sql,
@@ -419,9 +406,7 @@ class Compiler:
# add ctes for specific test nodes, and also for
# possible future use in adapters
compiled_node = self._add_ctes(
compiled_node, manifest, extra_context
)
compiled_node = self._add_ctes(compiled_node, manifest, extra_context)
return compiled_node
@@ -431,21 +416,17 @@ class Compiler:
if flags.WRITE_JSON:
linker.write_graph(graph_path, manifest)
def link_node(
self, linker: Linker, node: GraphMemberNode, manifest: Manifest
):
def link_node(self, linker: Linker, node: GraphMemberNode, manifest: Manifest):
linker.add_node(node.unique_id)
for dependency in node.depends_on_nodes:
if dependency in manifest.nodes:
linker.dependency(
node.unique_id,
(manifest.nodes[dependency].unique_id)
node.unique_id, (manifest.nodes[dependency].unique_id)
)
elif dependency in manifest.sources:
linker.dependency(
node.unique_id,
(manifest.sources[dependency].unique_id)
node.unique_id, (manifest.sources[dependency].unique_id)
)
else:
dependency_not_found(node, dependency)
@@ -480,16 +461,13 @@ class Compiler:
# writes the "compiled_sql" into the target/compiled directory
def _write_node(self, node: NonSourceCompiledNode) -> ManifestNode:
if (not node.extra_ctes_injected or
node.resource_type == NodeType.Snapshot):
if not node.extra_ctes_injected or node.resource_type == NodeType.Snapshot:
return node
logger.debug(f'Writing injected SQL for node "{node.unique_id}"')
if node.compiled_sql:
node.build_path = node.write_node(
self.config.target_path,
'compiled',
node.compiled_sql
self.config.target_path, "compiled", node.compiled_sql
)
return node
@@ -507,9 +485,7 @@ class Compiler:
) -> NonSourceCompiledNode:
node = self._compile_node(node, manifest, extra_context)
node, _ = self._recursively_prepend_ctes(
node, manifest, extra_context
)
node, _ = self._recursively_prepend_ctes(node, manifest, extra_context)
if write:
self._write_node(node)
return node
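The WITH-injection hunks above are dense, so here is a standalone sketch of the same idea using sqlparse (a dbt dependency): find an existing `WITH` keyword, insert one before the first token if absent, then splice the CTE text in after it. This is illustrative only; it takes plain strings rather than `InjectedCTE` objects and handles spacing itself.

```python
import sqlparse


def inject_ctes_sketch(sql: str, cte_sqls: list) -> str:
    parsed = sqlparse.parse(sql)[0]
    with_stmt = None
    for token in parsed.tokens:
        if token.is_keyword and token.normalized == "WITH":
            with_stmt = token
            break
    if with_stmt is None:
        # No WITH yet: add one at the very front.
        first_token = parsed.token_first()
        with_stmt = sqlparse.sql.Token(sqlparse.tokens.Keyword, "with")
        parsed.insert_before(first_token, with_stmt)
    else:
        # WITH already present: injected CTEs will need a trailing comma.
        comma = sqlparse.sql.Token(sqlparse.tokens.Punctuation, ",")
        parsed.insert_after(with_stmt, comma)
    ctes = sqlparse.sql.Token(sqlparse.tokens.Keyword, " " + ", ".join(cte_sqls) + " ")
    parsed.insert_after(with_stmt, ctes)
    return str(parsed)


print(inject_ctes_sketch("select * from final", ["base as (select 1 as id)"]))
# -> with base as (select 1 as id) select * from final
```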

View File

@@ -20,10 +20,8 @@ from dbt.utils import coerce_dict_str
from .renderer import ProfileRenderer
DEFAULT_THREADS = 1
DEFAULT_PROFILES_DIR = os.path.join(os.path.expanduser('~'), '.dbt')
PROFILES_DIR = os.path.expanduser(
os.getenv('DBT_PROFILES_DIR', DEFAULT_PROFILES_DIR)
)
DEFAULT_PROFILES_DIR = os.path.join(os.path.expanduser("~"), ".dbt")
PROFILES_DIR = os.path.expanduser(os.getenv("DBT_PROFILES_DIR", DEFAULT_PROFILES_DIR))
INVALID_PROFILE_MESSAGE = """
dbt encountered an error while trying to read your profiles.yml file.
@@ -43,11 +41,13 @@ Here, [profile name] should be replaced with a profile name
defined in your profiles.yml file. You can find profiles.yml here:
{profiles_file}/profiles.yml
""".format(profiles_file=PROFILES_DIR)
""".format(
profiles_file=PROFILES_DIR
)
def read_profile(profiles_dir: str) -> Dict[str, Any]:
path = os.path.join(profiles_dir, 'profiles.yml')
path = os.path.join(profiles_dir, "profiles.yml")
contents = None
if os.path.isfile(path):
@@ -55,12 +55,8 @@ def read_profile(profiles_dir: str) -> Dict[str, Any]:
contents = load_file_contents(path, strip=False)
yaml_content = load_yaml_text(contents)
if not yaml_content:
msg = f'The profiles.yml file at {path} is empty'
raise DbtProfileError(
INVALID_PROFILE_MESSAGE.format(
error_string=msg
)
)
msg = f"The profiles.yml file at {path} is empty"
raise DbtProfileError(INVALID_PROFILE_MESSAGE.format(error_string=msg))
return yaml_content
except ValidationException as e:
msg = INVALID_PROFILE_MESSAGE.format(error_string=e)
@@ -73,7 +69,7 @@ def read_user_config(directory: str) -> UserConfig:
try:
profile = read_profile(directory)
if profile:
user_cfg = coerce_dict_str(profile.get('config', {}))
user_cfg = coerce_dict_str(profile.get("config", {}))
if user_cfg is not None:
UserConfig.validate(user_cfg)
return UserConfig.from_dict(user_cfg)
@@ -92,9 +88,7 @@ class Profile(HasCredentials):
threads: int
credentials: Credentials
def to_profile_info(
self, serialize_credentials: bool = False
) -> Dict[str, Any]:
def to_profile_info(self, serialize_credentials: bool = False) -> Dict[str, Any]:
"""Unlike to_project_config, this dict is not a mirror of any existing
on-disk data structure. It's used when creating a new profile from an
existing one.
@@ -104,34 +98,35 @@ class Profile(HasCredentials):
:returns dict: The serialized profile.
"""
result = {
'profile_name': self.profile_name,
'target_name': self.target_name,
'config': self.config,
'threads': self.threads,
'credentials': self.credentials,
"profile_name": self.profile_name,
"target_name": self.target_name,
"config": self.config,
"threads": self.threads,
"credentials": self.credentials,
}
if serialize_credentials:
result['config'] = self.config.to_dict(omit_none=True)
result['credentials'] = self.credentials.to_dict(omit_none=True)
result["config"] = self.config.to_dict(omit_none=True)
result["credentials"] = self.credentials.to_dict(omit_none=True)
return result
def to_target_dict(self) -> Dict[str, Any]:
target = dict(
self.credentials.connection_info(with_aliases=True)
target = dict(self.credentials.connection_info(with_aliases=True))
target.update(
{
"type": self.credentials.type,
"threads": self.threads,
"name": self.target_name,
"target_name": self.target_name,
"profile_name": self.profile_name,
"config": self.config.to_dict(omit_none=True),
}
)
target.update({
'type': self.credentials.type,
'threads': self.threads,
'name': self.target_name,
'target_name': self.target_name,
'profile_name': self.profile_name,
'config': self.config.to_dict(omit_none=True),
})
return target
def __eq__(self, other: object) -> bool:
if not (isinstance(other, self.__class__) and
isinstance(self, other.__class__)):
if not (
isinstance(other, self.__class__) and isinstance(self, other.__class__)
):
return NotImplemented
return self.to_profile_info() == other.to_profile_info()
@@ -151,14 +146,17 @@ class Profile(HasCredentials):
) -> Credentials:
# avoid an import cycle
from dbt.adapters.factory import load_plugin
# credentials carry their 'type' in their actual type, not their
# attributes. We do want this in order to pick our Credentials class.
if 'type' not in profile:
if "type" not in profile:
raise DbtProfileError(
'required field "type" not found in profile {} and target {}'
.format(profile_name, target_name))
'required field "type" not found in profile {} and target {}'.format(
profile_name, target_name
)
)
typename = profile.pop('type')
typename = profile.pop("type")
try:
cls = load_plugin(typename)
data = cls.translate_aliases(profile)
@@ -167,8 +165,9 @@ class Profile(HasCredentials):
except (RuntimeException, ValidationError) as e:
msg = str(e) if isinstance(e, RuntimeException) else e.message
raise DbtProfileError(
'Credentials in profile "{}", target "{}" invalid: {}'
.format(profile_name, target_name, msg)
'Credentials in profile "{}", target "{}" invalid: {}'.format(
profile_name, target_name, msg
)
) from e
return credentials
@@ -189,19 +188,21 @@ class Profile(HasCredentials):
def _get_profile_data(
profile: Dict[str, Any], profile_name: str, target_name: str
) -> Dict[str, Any]:
if 'outputs' not in profile:
if "outputs" not in profile:
raise DbtProfileError(
"outputs not specified in profile '{}'".format(profile_name)
)
outputs = profile['outputs']
outputs = profile["outputs"]
if target_name not in outputs:
outputs = '\n'.join(' - {}'.format(output)
for output in outputs)
msg = ("The profile '{}' does not have a target named '{}'. The "
"valid target names for this profile are:\n{}"
.format(profile_name, target_name, outputs))
raise DbtProfileError(msg, result_type='invalid_target')
outputs = "\n".join(" - {}".format(output) for output in outputs)
msg = (
"The profile '{}' does not have a target named '{}'. The "
"valid target names for this profile are:\n{}".format(
profile_name, target_name, outputs
)
)
raise DbtProfileError(msg, result_type="invalid_target")
profile_data = outputs[target_name]
if not isinstance(profile_data, dict):
@@ -209,7 +210,7 @@ class Profile(HasCredentials):
f"output '{target_name}' of profile '{profile_name}' is "
f"misconfigured in profiles.yml"
)
raise DbtProfileError(msg, result_type='invalid_target')
raise DbtProfileError(msg, result_type="invalid_target")
return profile_data
@@ -220,8 +221,8 @@ class Profile(HasCredentials):
threads: int,
profile_name: str,
target_name: str,
user_cfg: Optional[Dict[str, Any]] = None
) -> 'Profile':
user_cfg: Optional[Dict[str, Any]] = None,
) -> "Profile":
"""Create a profile from an existing set of Credentials and the
remaining information.
@@ -244,7 +245,7 @@ class Profile(HasCredentials):
target_name=target_name,
config=config,
threads=threads,
credentials=credentials
credentials=credentials,
)
profile.validate()
return profile
@@ -269,19 +270,18 @@ class Profile(HasCredentials):
# name to extract a profile that we can render.
if target_override is not None:
target_name = target_override
elif 'target' in raw_profile:
elif "target" in raw_profile:
# render the target if it was parsed from yaml
target_name = renderer.render_value(raw_profile['target'])
target_name = renderer.render_value(raw_profile["target"])
else:
target_name = 'default'
target_name = "default"
logger.debug(
"target not specified in profile '{}', using '{}'"
.format(profile_name, target_name)
"target not specified in profile '{}', using '{}'".format(
profile_name, target_name
)
)
raw_profile_data = cls._get_profile_data(
raw_profile, profile_name, target_name
)
raw_profile_data = cls._get_profile_data(raw_profile, profile_name, target_name)
try:
profile_data = renderer.render_data(raw_profile_data)
@@ -298,7 +298,7 @@ class Profile(HasCredentials):
user_cfg: Optional[Dict[str, Any]] = None,
target_override: Optional[str] = None,
threads_override: Optional[int] = None,
) -> 'Profile':
) -> "Profile":
"""Create a profile from its raw profile information.
(this is an intermediate step, mostly useful for unit testing)
@@ -319,7 +319,7 @@ class Profile(HasCredentials):
"""
# user_cfg is not rendered.
if user_cfg is None:
user_cfg = raw_profile.get('config')
user_cfg = raw_profile.get("config")
# TODO: should it be, and the values coerced to bool?
target_name, profile_data = cls.render_profile(
raw_profile, profile_name, target_override, renderer
@@ -327,7 +327,7 @@ class Profile(HasCredentials):
# valid connections never include the number of threads, but it's
# stored on a per-connection level in the raw configs
threads = profile_data.pop('threads', DEFAULT_THREADS)
threads = profile_data.pop("threads", DEFAULT_THREADS)
if threads_override is not None:
threads = threads_override
@@ -340,7 +340,7 @@ class Profile(HasCredentials):
profile_name=profile_name,
target_name=target_name,
threads=threads,
user_cfg=user_cfg
user_cfg=user_cfg,
)
@classmethod
@@ -351,7 +351,7 @@ class Profile(HasCredentials):
renderer: ProfileRenderer,
target_override: Optional[str] = None,
threads_override: Optional[int] = None,
) -> 'Profile':
) -> "Profile":
"""
:param raw_profiles: The profile data, from disk as yaml.
:param profile_name: The profile name to use.
@@ -375,15 +375,9 @@ class Profile(HasCredentials):
# don't render keys, so we can pluck that out
raw_profile = raw_profiles[profile_name]
if not raw_profile:
msg = (
f'Profile {profile_name} in profiles.yml is empty'
)
raise DbtProfileError(
INVALID_PROFILE_MESSAGE.format(
error_string=msg
)
)
user_cfg = raw_profiles.get('config')
msg = f"Profile {profile_name} in profiles.yml is empty"
raise DbtProfileError(INVALID_PROFILE_MESSAGE.format(error_string=msg))
user_cfg = raw_profiles.get("config")
return cls.from_raw_profile_info(
raw_profile=raw_profile,
@@ -400,7 +394,7 @@ class Profile(HasCredentials):
args: Any,
renderer: ProfileRenderer,
project_profile_name: Optional[str],
) -> 'Profile':
) -> "Profile":
"""Given the raw profiles as read from disk and the name of the desired
profile if specified, return the profile component of the runtime
config.
@@ -415,15 +409,16 @@ class Profile(HasCredentials):
target could not be found.
:returns Profile: The new Profile object.
"""
threads_override = getattr(args, 'threads', None)
target_override = getattr(args, 'target', None)
threads_override = getattr(args, "threads", None)
target_override = getattr(args, "target", None)
raw_profiles = read_profile(args.profiles_dir)
profile_name = cls.pick_profile_name(getattr(args, 'profile', None),
project_profile_name)
profile_name = cls.pick_profile_name(
getattr(args, "profile", None), project_profile_name
)
return cls.from_raw_profiles(
raw_profiles=raw_profiles,
profile_name=profile_name,
renderer=renderer,
target_override=target_override,
threads_override=threads_override
threads_override=threads_override,
)
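A small sketch of the target-selection precedence visible in the profile hunks above: an explicit `--target` override wins, then the profile's rendered `target` key, then `"default"`. `pick_target` and its `render` argument are hypothetical stand-ins for the profile renderer, not dbt APIs.

```python
def pick_target(raw_profile: dict, target_override=None, render=lambda v: v) -> str:
    if target_override is not None:
        return target_override
    if "target" in raw_profile:
        # In dbt this value is rendered (it may contain Jinja); identity here.
        return render(raw_profile["target"])
    return "default"


assert pick_target({"target": "prod"}) == "prod"
assert pick_target({"target": "prod"}, target_override="ci") == "ci"
assert pick_target({}) == "default"
```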

View File

@@ -2,7 +2,13 @@ from copy import deepcopy
from dataclasses import dataclass, field
from itertools import chain
from typing import (
List, Dict, Any, Optional, TypeVar, Union, Mapping,
List,
Dict,
Any,
Optional,
TypeVar,
Union,
Mapping,
)
from typing_extensions import Protocol, runtime_checkable
@@ -82,9 +88,7 @@ def _load_yaml(path):
def package_data_from_root(project_root):
package_filepath = resolve_path_from_base(
'packages.yml', project_root
)
package_filepath = resolve_path_from_base("packages.yml", project_root)
if path_exists(package_filepath):
packages_dict = _load_yaml(package_filepath)
@@ -95,7 +99,7 @@ def package_data_from_root(project_root):
def package_config_from_data(packages_data: Dict[str, Any]):
if not packages_data:
packages_data = {'packages': []}
packages_data = {"packages": []}
try:
PackageConfig.validate(packages_data)
@@ -118,7 +122,7 @@ def _parse_versions(versions: Union[List[str], str]) -> List[VersionSpecifier]:
Regardless, this will return a list of VersionSpecifiers
"""
if isinstance(versions, str):
versions = versions.split(',')
versions = versions.split(",")
return [VersionSpecifier.from_version_string(v) for v in versions]
@@ -129,11 +133,12 @@ def _all_source_paths(
analysis_paths: List[str],
macro_paths: List[str],
) -> List[str]:
return list(chain(source_paths, data_paths, snapshot_paths, analysis_paths,
macro_paths))
return list(
chain(source_paths, data_paths, snapshot_paths, analysis_paths, macro_paths)
)
T = TypeVar('T')
T = TypeVar("T")
def value_or(value: Optional[T], default: T) -> T:
@@ -146,30 +151,27 @@ def value_or(value: Optional[T], default: T) -> T:
def _raw_project_from(project_root: str) -> Dict[str, Any]:
project_root = os.path.normpath(project_root)
project_yaml_filepath = os.path.join(project_root, 'dbt_project.yml')
project_yaml_filepath = os.path.join(project_root, "dbt_project.yml")
# get the project.yml contents
if not path_exists(project_yaml_filepath):
raise DbtProjectError(
'no dbt_project.yml found at expected path {}'
.format(project_yaml_filepath)
"no dbt_project.yml found at expected path {}".format(project_yaml_filepath)
)
project_dict = _load_yaml(project_yaml_filepath)
if not isinstance(project_dict, dict):
raise DbtProjectError(
'dbt_project.yml does not parse to a dictionary'
)
raise DbtProjectError("dbt_project.yml does not parse to a dictionary")
return project_dict
def _query_comment_from_cfg(
cfg_query_comment: Union[QueryComment, NoValue, str, None]
cfg_query_comment: Union[QueryComment, NoValue, str, None]
) -> QueryComment:
if not cfg_query_comment:
return QueryComment(comment='')
return QueryComment(comment="")
if isinstance(cfg_query_comment, str):
return QueryComment(comment=cfg_query_comment)
@@ -186,9 +188,7 @@ def validate_version(dbt_version: List[VersionSpecifier], project_name: str):
if not versions_compatible(*dbt_version):
msg = IMPOSSIBLE_VERSION_ERROR.format(
package=project_name,
version_spec=[
x.to_version_string() for x in dbt_version
]
version_spec=[x.to_version_string() for x in dbt_version],
)
raise DbtProjectError(msg)
@@ -196,9 +196,7 @@ def validate_version(dbt_version: List[VersionSpecifier], project_name: str):
msg = INVALID_VERSION_ERROR.format(
package=project_name,
installed=installed.to_version_string(),
version_spec=[
x.to_version_string() for x in dbt_version
]
version_spec=[x.to_version_string() for x in dbt_version],
)
raise DbtProjectError(msg)
@@ -207,8 +205,8 @@ def _get_required_version(
project_dict: Dict[str, Any],
verify_version: bool,
) -> List[VersionSpecifier]:
dbt_raw_version: Union[List[str], str] = '>=0.0.0'
required = project_dict.get('require-dbt-version')
dbt_raw_version: Union[List[str], str] = ">=0.0.0"
required = project_dict.get("require-dbt-version")
if required is not None:
dbt_raw_version = required
@@ -219,11 +217,11 @@ def _get_required_version(
if verify_version:
# no name is also an error that we want to raise
if 'name' not in project_dict:
if "name" not in project_dict:
raise DbtProjectError(
'Required "name" field not present in project',
)
validate_version(dbt_version, project_dict['name'])
validate_version(dbt_version, project_dict["name"])
return dbt_version
@@ -231,34 +229,36 @@ def _get_required_version(
@dataclass
class RenderComponents:
project_dict: Dict[str, Any] = field(
metadata=dict(description='The project dictionary')
metadata=dict(description="The project dictionary")
)
packages_dict: Dict[str, Any] = field(
metadata=dict(description='The packages dictionary')
metadata=dict(description="The packages dictionary")
)
selectors_dict: Dict[str, Any] = field(
metadata=dict(description='The selectors dictionary')
metadata=dict(description="The selectors dictionary")
)
@dataclass
class PartialProject(RenderComponents):
profile_name: Optional[str] = field(metadata=dict(
description='The unrendered profile name in the project, if set'
))
project_name: Optional[str] = field(metadata=dict(
description=(
'The name of the project. This should always be set and will not '
'be rendered'
profile_name: Optional[str] = field(
metadata=dict(description="The unrendered profile name in the project, if set")
)
project_name: Optional[str] = field(
metadata=dict(
description=(
"The name of the project. This should always be set and will not "
"be rendered"
)
)
))
)
project_root: str = field(
metadata=dict(description='The root directory of the project'),
metadata=dict(description="The root directory of the project"),
)
verify_version: bool = field(
metadata=dict(description=(
'If True, verify the dbt version matches the required version'
))
metadata=dict(
description=("If True, verify the dbt version matches the required version")
)
)
def render_profile_name(self, renderer) -> Optional[str]:
@@ -271,9 +271,7 @@ class PartialProject(RenderComponents):
renderer: DbtProjectYamlRenderer,
) -> RenderComponents:
rendered_project = renderer.render_project(
self.project_dict, self.project_root
)
rendered_project = renderer.render_project(self.project_dict, self.project_root)
rendered_packages = renderer.render_packages(self.packages_dict)
rendered_selectors = renderer.render_selectors(self.selectors_dict)
@@ -283,16 +281,16 @@ class PartialProject(RenderComponents):
selectors_dict=rendered_selectors,
)
def render(self, renderer: DbtProjectYamlRenderer) -> 'Project':
def render(self, renderer: DbtProjectYamlRenderer) -> "Project":
try:
rendered = self.get_rendered(renderer)
return self.create_project(rendered)
except DbtProjectError as exc:
if exc.path is None:
exc.path = os.path.join(self.project_root, 'dbt_project.yml')
exc.path = os.path.join(self.project_root, "dbt_project.yml")
raise
def create_project(self, rendered: RenderComponents) -> 'Project':
def create_project(self, rendered: RenderComponents) -> "Project":
unrendered = RenderComponents(
project_dict=self.project_dict,
packages_dict=self.packages_dict,
@@ -305,9 +303,7 @@ class PartialProject(RenderComponents):
try:
ProjectContract.validate(rendered.project_dict)
cfg = ProjectContract.from_dict(
rendered.project_dict
)
cfg = ProjectContract.from_dict(rendered.project_dict)
except ValidationError as e:
raise DbtProjectError(validator_error_message(e)) from e
# name/version are required in the Project definition, so we can assume
@@ -317,31 +313,30 @@ class PartialProject(RenderComponents):
# this is added at project_dict parse time and should always be here
# once we see it.
if cfg.project_root is None:
raise DbtProjectError('cfg must have a project root!')
raise DbtProjectError("cfg must have a project root!")
else:
project_root = cfg.project_root
# this is only optional in the sense that if it's not present, it needs
# to have been a cli argument.
profile_name = cfg.profile
# these are all the defaults
source_paths: List[str] = value_or(cfg.source_paths, ['models'])
macro_paths: List[str] = value_or(cfg.macro_paths, ['macros'])
data_paths: List[str] = value_or(cfg.data_paths, ['data'])
test_paths: List[str] = value_or(cfg.test_paths, ['test'])
source_paths: List[str] = value_or(cfg.source_paths, ["models"])
macro_paths: List[str] = value_or(cfg.macro_paths, ["macros"])
data_paths: List[str] = value_or(cfg.data_paths, ["data"])
test_paths: List[str] = value_or(cfg.test_paths, ["test"])
analysis_paths: List[str] = value_or(cfg.analysis_paths, [])
snapshot_paths: List[str] = value_or(cfg.snapshot_paths, ['snapshots'])
snapshot_paths: List[str] = value_or(cfg.snapshot_paths, ["snapshots"])
all_source_paths: List[str] = _all_source_paths(
source_paths, data_paths, snapshot_paths, analysis_paths,
macro_paths
source_paths, data_paths, snapshot_paths, analysis_paths, macro_paths
)
docs_paths: List[str] = value_or(cfg.docs_paths, all_source_paths)
asset_paths: List[str] = value_or(cfg.asset_paths, [])
target_path: str = value_or(cfg.target_path, 'target')
target_path: str = value_or(cfg.target_path, "target")
clean_targets: List[str] = value_or(cfg.clean_targets, [target_path])
log_path: str = value_or(cfg.log_path, 'logs')
modules_path: str = value_or(cfg.modules_path, 'dbt_modules')
log_path: str = value_or(cfg.log_path, "logs")
modules_path: str = value_or(cfg.modules_path, "dbt_modules")
# in the default case we'll populate this once we know the adapter type
# It would be nice to just pass along a Quoting here, but that would
# break many things
@@ -373,11 +368,12 @@ class PartialProject(RenderComponents):
packages = package_config_from_data(rendered.packages_dict)
selectors = selector_config_from_data(rendered.selectors_dict)
manifest_selectors: Dict[str, Any] = {}
if rendered.selectors_dict and rendered.selectors_dict['selectors']:
if rendered.selectors_dict and rendered.selectors_dict["selectors"]:
# this is a dict with a single key 'selectors' pointing to a list
# of dicts.
manifest_selectors = SelectorDict.parse_from_selectors_list(
rendered.selectors_dict['selectors'])
rendered.selectors_dict["selectors"]
)
project = Project(
project_name=name,
@@ -426,10 +422,9 @@ class PartialProject(RenderComponents):
*,
verify_version: bool = False,
):
"""Construct a partial project from its constituent dicts.
"""
project_name = project_dict.get('name')
profile_name = project_dict.get('profile')
"""Construct a partial project from its constituent dicts."""
project_name = project_dict.get("name")
profile_name = project_dict.get("profile")
return cls(
profile_name=profile_name,
@@ -444,14 +439,14 @@ class PartialProject(RenderComponents):
@classmethod
def from_project_root(
cls, project_root: str, *, verify_version: bool = False
) -> 'PartialProject':
) -> "PartialProject":
project_root = os.path.normpath(project_root)
project_dict = _raw_project_from(project_root)
config_version = project_dict.get('config-version', 1)
config_version = project_dict.get("config-version", 1)
if config_version != 2:
raise DbtProjectError(
f'Invalid config version: {config_version}, expected 2',
path=os.path.join(project_root, 'dbt_project.yml')
f"Invalid config version: {config_version}, expected 2",
path=os.path.join(project_root, "dbt_project.yml"),
)
packages_dict = package_data_from_root(project_root)
@@ -468,15 +463,10 @@ class PartialProject(RenderComponents):
class VarProvider:
"""Var providers are tied to a particular Project."""
def __init__(
self,
vars: Dict[str, Dict[str, Any]]
) -> None:
def __init__(self, vars: Dict[str, Dict[str, Any]]) -> None:
self.vars = vars
def vars_for(
self, node: IsFQNResource, adapter_type: str
) -> Mapping[str, Any]:
def vars_for(self, node: IsFQNResource, adapter_type: str) -> Mapping[str, Any]:
# in v2, vars are only either project or globally scoped
merged = MultiDict([self.vars])
merged.add(self.vars.get(node.package_name, {}))
@@ -525,8 +515,11 @@ class Project:
@property
def all_source_paths(self) -> List[str]:
return _all_source_paths(
self.source_paths, self.data_paths, self.snapshot_paths,
self.analysis_paths, self.macro_paths
self.source_paths,
self.data_paths,
self.snapshot_paths,
self.analysis_paths,
self.macro_paths,
)
def __str__(self):
@@ -534,11 +527,13 @@ class Project:
return str(cfg)
def __eq__(self, other):
if not (isinstance(other, self.__class__) and
isinstance(self, other.__class__)):
if not (
isinstance(other, self.__class__) and isinstance(self, other.__class__)
):
return False
return self.to_project_config(with_packages=True) == \
other.to_project_config(with_packages=True)
return self.to_project_config(with_packages=True) == other.to_project_config(
with_packages=True
)
def to_project_config(self, with_packages=False):
"""Return a dict representation of the config that could be written to
@@ -548,38 +543,39 @@ class Project:
file in the root.
:returns dict: The serialized profile.
"""
result = deepcopy({
'name': self.project_name,
'version': self.version,
'project-root': self.project_root,
'profile': self.profile_name,
'source-paths': self.source_paths,
'macro-paths': self.macro_paths,
'data-paths': self.data_paths,
'test-paths': self.test_paths,
'analysis-paths': self.analysis_paths,
'docs-paths': self.docs_paths,
'asset-paths': self.asset_paths,
'target-path': self.target_path,
'snapshot-paths': self.snapshot_paths,
'clean-targets': self.clean_targets,
'log-path': self.log_path,
'quoting': self.quoting,
'models': self.models,
'on-run-start': self.on_run_start,
'on-run-end': self.on_run_end,
'seeds': self.seeds,
'snapshots': self.snapshots,
'sources': self.sources,
'vars': self.vars.to_dict(),
'require-dbt-version': [
v.to_version_string() for v in self.dbt_version
],
'config-version': self.config_version,
})
result = deepcopy(
{
"name": self.project_name,
"version": self.version,
"project-root": self.project_root,
"profile": self.profile_name,
"source-paths": self.source_paths,
"macro-paths": self.macro_paths,
"data-paths": self.data_paths,
"test-paths": self.test_paths,
"analysis-paths": self.analysis_paths,
"docs-paths": self.docs_paths,
"asset-paths": self.asset_paths,
"target-path": self.target_path,
"snapshot-paths": self.snapshot_paths,
"clean-targets": self.clean_targets,
"log-path": self.log_path,
"quoting": self.quoting,
"models": self.models,
"on-run-start": self.on_run_start,
"on-run-end": self.on_run_end,
"seeds": self.seeds,
"snapshots": self.snapshots,
"sources": self.sources,
"vars": self.vars.to_dict(),
"require-dbt-version": [
v.to_version_string() for v in self.dbt_version
],
"config-version": self.config_version,
}
)
if self.query_comment:
result['query-comment'] = \
self.query_comment.to_dict(omit_none=True)
result["query-comment"] = self.query_comment.to_dict(omit_none=True)
if with_packages:
result.update(self.packages.to_dict(omit_none=True))
@@ -610,8 +606,8 @@ class Project:
selectors_dict: Dict[str, Any],
renderer: DbtProjectYamlRenderer,
*,
verify_version: bool = False
) -> 'Project':
verify_version: bool = False,
) -> "Project":
partial = PartialProject.from_dicts(
project_root=project_root,
project_dict=project_dict,
@@ -628,17 +624,17 @@ class Project:
renderer: DbtProjectYamlRenderer,
*,
verify_version: bool = False,
) -> 'Project':
) -> "Project":
partial = cls.partial_load(project_root, verify_version=verify_version)
return partial.render(renderer)
def hashed_name(self):
return hashlib.md5(self.project_name.encode('utf-8')).hexdigest()
return hashlib.md5(self.project_name.encode("utf-8")).hexdigest()
def get_selector(self, name: str) -> SelectionSpec:
if name not in self.selectors:
raise RuntimeException(
f'Could not find selector named {name}, expected one of '
f'{list(self.selectors)}'
f"Could not find selector named {name}, expected one of "
f"{list(self.selectors)}"
)
return self.selectors[name]
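A quick sketch of the `value_or` defaulting pattern used above when filling in project paths, assuming (its body is outside the hunk) that it returns the default only when the parsed value is `None`, so explicit values from `dbt_project.yml` are kept as given.

```python
from typing import Optional, TypeVar

T = TypeVar("T")


def value_or(value: Optional[T], default: T) -> T:
    # Assumed semantics: only a missing (None) value falls back to the default.
    return default if value is None else value


source_paths = value_or(None, ["models"])       # -> ["models"]
test_paths = value_or(["my_tests"], ["test"])   # -> ["my_tests"]
```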

View File

@@ -2,9 +2,7 @@ from typing import Dict, Any, Tuple, Optional, Union, Callable
from dbt.clients.jinja import get_rendered, catch_jinja
from dbt.exceptions import (
DbtProjectError, CompilationException, RecursionException
)
from dbt.exceptions import DbtProjectError, CompilationException, RecursionException
from dbt.node_types import NodeType
from dbt.utils import deep_map
@@ -18,7 +16,7 @@ class BaseRenderer:
@property
def name(self):
return 'Rendering'
return "Rendering"
def should_render_keypath(self, keypath: Keypath) -> bool:
return True
@@ -29,9 +27,7 @@ class BaseRenderer:
return self.render_value(value, keypath)
def render_value(
self, value: Any, keypath: Optional[Keypath] = None
) -> Any:
def render_value(self, value: Any, keypath: Optional[Keypath] = None) -> Any:
# keypath is ignored.
# if it wasn't read as a string, ignore it
if not isinstance(value, str):
@@ -40,18 +36,16 @@ class BaseRenderer:
with catch_jinja():
return get_rendered(value, self.context, native=True)
except CompilationException as exc:
msg = f'Could not render {value}: {exc.msg}'
msg = f"Could not render {value}: {exc.msg}"
raise CompilationException(msg) from exc
def render_data(
self, data: Dict[str, Any]
) -> Dict[str, Any]:
def render_data(self, data: Dict[str, Any]) -> Dict[str, Any]:
try:
return deep_map(self.render_entry, data)
except RecursionException:
raise DbtProjectError(
f'Cycle detected: {self.name} input has a reference to itself',
project=data
f"Cycle detected: {self.name} input has a reference to itself",
project=data,
)
@@ -78,15 +72,15 @@ class ProjectPostprocessor(Dict[Keypath, Callable[[Any], Any]]):
def __init__(self):
super().__init__()
self[('on-run-start',)] = _list_if_none_or_string
self[('on-run-end',)] = _list_if_none_or_string
self[("on-run-start",)] = _list_if_none_or_string
self[("on-run-end",)] = _list_if_none_or_string
for k in ('models', 'seeds', 'snapshots'):
for k in ("models", "seeds", "snapshots"):
self[(k,)] = _dict_if_none
self[(k, 'vars')] = _dict_if_none
self[(k, 'pre-hook')] = _list_if_none_or_string
self[(k, 'post-hook')] = _list_if_none_or_string
self[('seeds', 'column_types')] = _dict_if_none
self[(k, "vars")] = _dict_if_none
self[(k, "pre-hook")] = _list_if_none_or_string
self[(k, "post-hook")] = _list_if_none_or_string
self[("seeds", "column_types")] = _dict_if_none
def postprocess(self, value: Any, key: Keypath) -> Any:
if key in self:
@@ -101,7 +95,7 @@ class DbtProjectYamlRenderer(BaseRenderer):
@property
def name(self):
'Project config'
"Project config"
def get_package_renderer(self) -> BaseRenderer:
return PackageRenderer(self.context)
@@ -116,7 +110,7 @@ class DbtProjectYamlRenderer(BaseRenderer):
) -> Dict[str, Any]:
"""Render the project and insert the project root after rendering."""
rendered_project = self.render_data(project)
rendered_project['project-root'] = project_root
rendered_project["project-root"] = project_root
return rendered_project
def render_packages(self, packages: Dict[str, Any]):
@@ -138,20 +132,19 @@ class DbtProjectYamlRenderer(BaseRenderer):
first = keypath[0]
# run hooks are not rendered
if first in {'on-run-start', 'on-run-end', 'query-comment'}:
if first in {"on-run-start", "on-run-end", "query-comment"}:
return False
# don't render vars blocks until runtime
if first == 'vars':
if first == "vars":
return False
if first in {'seeds', 'models', 'snapshots', 'seeds'}:
if first in {"seeds", "models", "snapshots", "seeds"}:
keypath_parts = {
(k.lstrip('+') if isinstance(k, str) else k)
for k in keypath
(k.lstrip("+") if isinstance(k, str) else k) for k in keypath
}
# model-level hooks
if 'pre-hook' in keypath_parts or 'post-hook' in keypath_parts:
if "pre-hook" in keypath_parts or "post-hook" in keypath_parts:
return False
return True
@@ -160,17 +153,15 @@ class DbtProjectYamlRenderer(BaseRenderer):
class ProfileRenderer(BaseRenderer):
@property
def name(self):
'Profile'
"Profile"
class SchemaYamlRenderer(BaseRenderer):
DOCUMENTABLE_NODES = frozenset(
n.pluralize() for n in NodeType.documentable()
)
DOCUMENTABLE_NODES = frozenset(n.pluralize() for n in NodeType.documentable())
@property
def name(self):
return 'Rendering yaml'
return "Rendering yaml"
def _is_norender_key(self, keypath: Keypath) -> bool:
"""
@@ -185,13 +176,13 @@ class SchemaYamlRenderer(BaseRenderer):
Return True if it's tests or description - those aren't rendered
"""
if len(keypath) >= 2 and keypath[1] in ('tests', 'description'):
if len(keypath) >= 2 and keypath[1] in ("tests", "description"):
return True
if (
len(keypath) >= 4 and
keypath[1] == 'columns' and
keypath[3] in ('tests', 'description')
len(keypath) >= 4
and keypath[1] == "columns"
and keypath[3] in ("tests", "description")
):
return True
@@ -209,13 +200,13 @@ class SchemaYamlRenderer(BaseRenderer):
return True
if keypath[0] == NodeType.Source.pluralize():
if keypath[2] == 'description':
if keypath[2] == "description":
return False
if keypath[2] == 'tables':
if keypath[2] == "tables":
if self._is_norender_key(keypath[3:]):
return False
elif keypath[0] == NodeType.Macro.pluralize():
if keypath[2] == 'arguments':
if keypath[2] == "arguments":
if self._is_norender_key(keypath[3:]):
return False
elif self._is_norender_key(keypath[1:]):
@@ -229,10 +220,10 @@ class SchemaYamlRenderer(BaseRenderer):
class PackageRenderer(BaseRenderer):
@property
def name(self):
return 'Packages config'
return "Packages config"
class SelectorRenderer(BaseRenderer):
@property
def name(self):
return 'Selector config'
return "Selector config"

View File

@@ -4,8 +4,16 @@ from copy import deepcopy
from dataclasses import dataclass, fields
from pathlib import Path
from typing import (
Dict, Any, Optional, Mapping, Iterator, Iterable, Tuple, List, MutableSet,
Type
Dict,
Any,
Optional,
Mapping,
Iterator,
Iterable,
Tuple,
List,
MutableSet,
Type,
)
from .profile import Profile
@@ -15,7 +23,7 @@ from .utils import parse_cli_vars
from dbt import tracking
from dbt.adapters.factory import get_relation_class_by_name, get_include_paths
from dbt.helper_types import FQNPath, PathSet
from dbt.context.base import generate_base_context
from dbt.context import generate_base_context
from dbt.context.target import generate_target_context
from dbt.contracts.connection import AdapterRequiredConfig, Credentials
from dbt.contracts.graph.manifest import ManifestMetadata
@@ -30,15 +38,13 @@ from dbt.exceptions import (
DbtProjectError,
validator_error_message,
warn_or_error,
raise_compiler_error
raise_compiler_error,
)
from dbt.dataclass_schema import ValidationError
def _project_quoting_dict(
proj: Project, profile: Profile
) -> Dict[ComponentName, bool]:
def _project_quoting_dict(proj: Project, profile: Profile) -> Dict[ComponentName, bool]:
src: Dict[str, Any] = profile.credentials.translate_aliases(proj.quoting)
result: Dict[ComponentName, bool] = {}
for key in ComponentName:
@@ -54,7 +60,7 @@ class RuntimeConfig(Project, Profile, AdapterRequiredConfig):
args: Any
profile_name: str
cli_vars: Dict[str, Any]
dependencies: Optional[Mapping[str, 'RuntimeConfig']] = None
dependencies: Optional[Mapping[str, "RuntimeConfig"]] = None
def __post_init__(self):
self.validate()
@@ -65,8 +71,8 @@ class RuntimeConfig(Project, Profile, AdapterRequiredConfig):
project: Project,
profile: Profile,
args: Any,
dependencies: Optional[Mapping[str, 'RuntimeConfig']] = None,
) -> 'RuntimeConfig':
dependencies: Optional[Mapping[str, "RuntimeConfig"]] = None,
) -> "RuntimeConfig":
"""Instantiate a RuntimeConfig from its components.
:param profile: A parsed dbt Profile.
@@ -80,7 +86,7 @@ class RuntimeConfig(Project, Profile, AdapterRequiredConfig):
.replace_dict(_project_quoting_dict(project, profile))
).to_dict(omit_none=True)
cli_vars: Dict[str, Any] = parse_cli_vars(getattr(args, 'vars', '{}'))
cli_vars: Dict[str, Any] = parse_cli_vars(getattr(args, "vars", "{}"))
return cls(
project_name=project.project_name,
@@ -123,7 +129,7 @@ class RuntimeConfig(Project, Profile, AdapterRequiredConfig):
dependencies=dependencies,
)
def new_project(self, project_root: str) -> 'RuntimeConfig':
def new_project(self, project_root: str) -> "RuntimeConfig":
"""Given a new project root, read in its project dictionary, supply the
existing project's profile info, and create a new project file.
@@ -142,7 +148,7 @@ class RuntimeConfig(Project, Profile, AdapterRequiredConfig):
project = Project.from_project_root(
project_root,
renderer,
verify_version=getattr(self.args, 'version_check', False),
verify_version=getattr(self.args, "version_check", False),
)
cfg = self.from_parts(
@@ -165,7 +171,7 @@ class RuntimeConfig(Project, Profile, AdapterRequiredConfig):
"""
result = self.to_project_config(with_packages=True)
result.update(self.to_profile_info(serialize_credentials=True))
result['cli_vars'] = deepcopy(self.cli_vars)
result["cli_vars"] = deepcopy(self.cli_vars)
return result
def validate(self):
@@ -185,30 +191,21 @@ class RuntimeConfig(Project, Profile, AdapterRequiredConfig):
profile_renderer: ProfileRenderer,
profile_name: Optional[str],
) -> Profile:
return Profile.render_from_args(
args, profile_renderer, profile_name
)
return Profile.render_from_args(args, profile_renderer, profile_name)
@classmethod
def collect_parts(
cls: Type['RuntimeConfig'], args: Any
) -> Tuple[Project, Profile]:
def collect_parts(cls: Type["RuntimeConfig"], args: Any) -> Tuple[Project, Profile]:
# profile_name from the project
project_root = args.project_dir if args.project_dir else os.getcwd()
version_check = getattr(args, 'version_check', False)
partial = Project.partial_load(
project_root,
verify_version=version_check
)
version_check = getattr(args, "version_check", False)
partial = Project.partial_load(project_root, verify_version=version_check)
# build the profile using the base renderer and the one fact we know
cli_vars: Dict[str, Any] = parse_cli_vars(getattr(args, 'vars', '{}'))
cli_vars: Dict[str, Any] = parse_cli_vars(getattr(args, "vars", "{}"))
profile_renderer = ProfileRenderer(generate_base_context(cli_vars))
profile_name = partial.render_profile_name(profile_renderer)
profile = cls._get_rendered_profile(
args, profile_renderer, profile_name
)
profile = cls._get_rendered_profile(args, profile_renderer, profile_name)
# get a new renderer using our target information and render the
# project
@@ -218,7 +215,7 @@ class RuntimeConfig(Project, Profile, AdapterRequiredConfig):
return (project, profile)
@classmethod
def from_args(cls, args: Any) -> 'RuntimeConfig':
def from_args(cls, args: Any) -> "RuntimeConfig":
"""Given arguments, read in dbt_project.yml from the current directory,
read in packages.yml if it exists, and use them to find the profile to
load.
@@ -238,8 +235,7 @@ class RuntimeConfig(Project, Profile, AdapterRequiredConfig):
def get_metadata(self) -> ManifestMetadata:
return ManifestMetadata(
project_id=self.hashed_name(),
adapter_type=self.credentials.type
project_id=self.hashed_name(), adapter_type=self.credentials.type
)
def _get_v2_config_paths(
@@ -249,7 +245,7 @@ class RuntimeConfig(Project, Profile, AdapterRequiredConfig):
paths: MutableSet[FQNPath],
) -> PathSet:
for key, value in config.items():
if isinstance(value, dict) and not key.startswith('+'):
if isinstance(value, dict) and not key.startswith("+"):
self._get_v2_config_paths(value, path + (key,), paths)
else:
paths.add(path)
@@ -265,7 +261,7 @@ class RuntimeConfig(Project, Profile, AdapterRequiredConfig):
paths = set()
for key, value in config.items():
if isinstance(value, dict) and not key.startswith('+'):
if isinstance(value, dict) and not key.startswith("+"):
self._get_v2_config_paths(value, path + (key,), paths)
else:
paths.add(path)
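For intuition, an illustrative sketch of what the path collection above produces (the config values below are assumed examples, not taken from this commit):

# Hypothetical models block: "+"-prefixed keys are settings, plain keys are path segments.
models = {"my_project": {"staging": {"+materialized": "view"}}}
# Recursion stops at "+materialized", so the collected configured path would be:
# {("my_project", "staging")}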
@@ -277,10 +273,10 @@ class RuntimeConfig(Project, Profile, AdapterRequiredConfig):
a configured path in the resource.
"""
return {
'models': self._get_config_paths(self.models),
'seeds': self._get_config_paths(self.seeds),
'snapshots': self._get_config_paths(self.snapshots),
'sources': self._get_config_paths(self.sources),
"models": self._get_config_paths(self.models),
"seeds": self._get_config_paths(self.seeds),
"snapshots": self._get_config_paths(self.snapshots),
"sources": self._get_config_paths(self.sources),
}
def get_unused_resource_config_paths(
@@ -301,9 +297,7 @@ class RuntimeConfig(Project, Profile, AdapterRequiredConfig):
for config_path in config_paths:
if not _is_config_used(config_path, fqns):
unused_resource_config_paths.append(
(resource_type,) + config_path
)
unused_resource_config_paths.append((resource_type,) + config_path)
return unused_resource_config_paths
def warn_for_unused_resource_config_paths(
@@ -316,27 +310,25 @@ class RuntimeConfig(Project, Profile, AdapterRequiredConfig):
return
msg = UNUSED_RESOURCE_CONFIGURATION_PATH_MESSAGE.format(
len(unused),
'\n'.join('- {}'.format('.'.join(u)) for u in unused)
len(unused), "\n".join("- {}".format(".".join(u)) for u in unused)
)
warn_or_error(msg, log_fmt=warning_tag('{}'))
warn_or_error(msg, log_fmt=warning_tag("{}"))
def load_dependencies(self) -> Mapping[str, 'RuntimeConfig']:
def load_dependencies(self) -> Mapping[str, "RuntimeConfig"]:
if self.dependencies is None:
all_projects = {self.project_name: self}
internal_packages = get_include_paths(self.credentials.type)
project_paths = itertools.chain(
internal_packages,
self._get_project_directories()
internal_packages, self._get_project_directories()
)
for project_name, project in self.load_projects(project_paths):
if project_name in all_projects:
raise_compiler_error(
f'dbt found more than one package with the name '
f"dbt found more than one package with the name "
f'"{project_name}" included in this project. Package '
f'names must be unique in a project. Please rename '
f'one of these packages.'
f"names must be unique in a project. Please rename "
f"one of these packages."
)
all_projects[project_name] = project
self.dependencies = all_projects
@@ -347,14 +339,14 @@ class RuntimeConfig(Project, Profile, AdapterRequiredConfig):
def load_projects(
self, paths: Iterable[Path]
) -> Iterator[Tuple[str, 'RuntimeConfig']]:
) -> Iterator[Tuple[str, "RuntimeConfig"]]:
for path in paths:
try:
project = self.new_project(str(path))
except DbtProjectError as e:
raise DbtProjectError(
f'Failed to read package: {e}',
result_type='invalid_project',
f"Failed to read package: {e}",
result_type="invalid_project",
path=path,
) from e
else:
@@ -365,13 +357,13 @@ class RuntimeConfig(Project, Profile, AdapterRequiredConfig):
if root.exists():
for path in root.iterdir():
if path.is_dir() and not path.name.startswith('__'):
if path.is_dir() and not path.name.startswith("__"):
yield path
class UnsetCredentials(Credentials):
def __init__(self):
super().__init__('', '')
super().__init__("", "")
@property
def type(self):
@@ -387,9 +379,7 @@ class UnsetCredentials(Credentials):
class UnsetConfig(UserConfig):
def __getattribute__(self, name):
if name in {f.name for f in fields(UserConfig)}:
raise AttributeError(
f"'UnsetConfig' object has no attribute {name}"
)
raise AttributeError(f"'UnsetConfig' object has no attribute {name}")
def __post_serialize__(self, dct):
return {}
@@ -399,15 +389,15 @@ class UnsetProfile(Profile):
def __init__(self):
self.credentials = UnsetCredentials()
self.config = UnsetConfig()
self.profile_name = ''
self.target_name = ''
self.profile_name = ""
self.target_name = ""
self.threads = -1
def to_target_dict(self):
return {}
def __getattribute__(self, name):
if name in {'profile_name', 'target_name', 'threads'}:
if name in {"profile_name", "target_name", "threads"}:
raise RuntimeException(
f'Error: disallowed attribute "{name}" - no profile!'
)
@@ -431,7 +421,7 @@ class UnsetProfileConfig(RuntimeConfig):
def __getattribute__(self, name):
# Override __getattribute__ to check that the attribute isn't 'banned'.
if name in {'profile_name', 'target_name'}:
if name in {"profile_name", "target_name"}:
raise RuntimeException(
f'Error: disallowed attribute "{name}" - no profile!'
)
@@ -449,8 +439,8 @@ class UnsetProfileConfig(RuntimeConfig):
project: Project,
profile: Profile,
args: Any,
dependencies: Optional[Mapping[str, 'RuntimeConfig']] = None,
) -> 'RuntimeConfig':
dependencies: Optional[Mapping[str, "RuntimeConfig"]] = None,
) -> "RuntimeConfig":
"""Instantiate a RuntimeConfig from its components.
:param profile: Ignored.
@@ -458,7 +448,7 @@ class UnsetProfileConfig(RuntimeConfig):
:param args: The parsed command-line arguments.
:returns RuntimeConfig: The new configuration.
"""
cli_vars: Dict[str, Any] = parse_cli_vars(getattr(args, 'vars', '{}'))
cli_vars: Dict[str, Any] = parse_cli_vars(getattr(args, "vars", "{}"))
return cls(
project_name=project.project_name,
@@ -491,10 +481,10 @@ class UnsetProfileConfig(RuntimeConfig):
vars=project.vars,
config_version=project.config_version,
unrendered=project.unrendered,
profile_name='',
target_name='',
profile_name="",
target_name="",
config=UnsetConfig(),
threads=getattr(args, 'threads', 1),
threads=getattr(args, "threads", 1),
credentials=UnsetCredentials(),
args=args,
cli_vars=cli_vars,
@@ -509,16 +499,11 @@ class UnsetProfileConfig(RuntimeConfig):
profile_name: Optional[str],
) -> Profile:
try:
profile = Profile.render_from_args(
args, profile_renderer, profile_name
)
profile = Profile.render_from_args(args, profile_renderer, profile_name)
except (DbtProjectError, DbtProfileError) as exc:
logger.debug(
'Profile not loaded due to error: {}', exc, exc_info=True
)
logger.debug("Profile not loaded due to error: {}", exc, exc_info=True)
logger.info(
'No profile "{}" found, continuing with no target',
profile_name
'No profile "{}" found, continuing with no target', profile_name
)
# return the poisoned form
profile = UnsetProfile()
@@ -527,7 +512,7 @@ class UnsetProfileConfig(RuntimeConfig):
return profile
@classmethod
def from_args(cls: Type[RuntimeConfig], args: Any) -> 'RuntimeConfig':
def from_args(cls: Type[RuntimeConfig], args: Any) -> "RuntimeConfig":
"""Given arguments, read in dbt_project.yml from the current directory,
read in packages.yml if it exists, and use them to find the profile to
load.
@@ -542,11 +527,7 @@ class UnsetProfileConfig(RuntimeConfig):
# if it's a real profile, return a real config
cls = RuntimeConfig
return cls.from_parts(
project=project,
profile=profile,
args=args
)
return cls.from_parts(project=project, profile=profile, args=args)
UNUSED_RESOURCE_CONFIGURATION_PATH_MESSAGE = """\
@@ -560,6 +541,6 @@ There are {} unused configuration paths:
def _is_config_used(path, fqns):
if fqns:
for fqn in fqns:
if len(path) <= len(fqn) and fqn[:len(path)] == path:
if len(path) <= len(fqn) and fqn[: len(path)] == path:
return True
return False
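A minimal illustration of the prefix test above (example tuples assumed):

path = ("models", "staging")
fqn = ("models", "staging", "stg_orders")
# The configured path counts as "used" because it is a leading slice of the fqn:
assert len(path) <= len(fqn) and fqn[: len(path)] == path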

View File

@@ -1,8 +1,6 @@
from pathlib import Path
from typing import Dict, Any
from dbt.clients.yaml_helper import ( # noqa: F401
yaml, Loader, Dumper, load_yaml_text
)
from dbt.clients.yaml_helper import yaml, Loader, Dumper, load_yaml_text # noqa: F401
from dbt.dataclass_schema import ValidationError
from .renderer import SelectorRenderer
@@ -30,9 +28,8 @@ Validator Error:
class SelectorConfig(Dict[str, SelectionSpec]):
@classmethod
def selectors_from_dict(cls, data: Dict[str, Any]) -> 'SelectorConfig':
def selectors_from_dict(cls, data: Dict[str, Any]) -> "SelectorConfig":
try:
SelectorFile.validate(data)
selector_file = SelectorFile.from_dict(data)
@@ -45,12 +42,12 @@ class SelectorConfig(Dict[str, SelectionSpec]):
f"union, intersection, string, dictionary. No lists. "
f"\nhttps://docs.getdbt.com/reference/node-selection/"
f"yaml-selectors",
result_type='invalid_selector'
result_type="invalid_selector",
) from exc
except RuntimeException as exc:
raise DbtSelectorsError(
f'Could not read selector file data: {exc}',
result_type='invalid_selector',
f"Could not read selector file data: {exc}",
result_type="invalid_selector",
) from exc
return cls(selectors)
@@ -60,26 +57,28 @@ class SelectorConfig(Dict[str, SelectionSpec]):
cls,
data: Dict[str, Any],
renderer: SelectorRenderer,
) -> 'SelectorConfig':
) -> "SelectorConfig":
try:
rendered = renderer.render_data(data)
except (ValidationError, RuntimeException) as exc:
raise DbtSelectorsError(
f'Could not render selector data: {exc}',
result_type='invalid_selector',
f"Could not render selector data: {exc}",
result_type="invalid_selector",
) from exc
return cls.selectors_from_dict(rendered)
@classmethod
def from_path(
cls, path: Path, renderer: SelectorRenderer,
) -> 'SelectorConfig':
cls,
path: Path,
renderer: SelectorRenderer,
) -> "SelectorConfig":
try:
data = load_yaml_text(load_file_contents(str(path)))
except (ValidationError, RuntimeException) as exc:
raise DbtSelectorsError(
f'Could not read selector file: {exc}',
result_type='invalid_selector',
f"Could not read selector file: {exc}",
result_type="invalid_selector",
path=path,
) from exc
@@ -91,9 +90,7 @@ class SelectorConfig(Dict[str, SelectionSpec]):
def selector_data_from_root(project_root: str) -> Dict[str, Any]:
selector_filepath = resolve_path_from_base(
'selectors.yml', project_root
)
selector_filepath = resolve_path_from_base("selectors.yml", project_root)
if path_exists(selector_filepath):
selectors_dict = load_yaml_text(load_file_contents(selector_filepath))
@@ -102,18 +99,16 @@ def selector_data_from_root(project_root: str) -> Dict[str, Any]:
return selectors_dict
def selector_config_from_data(
selectors_data: Dict[str, Any]
) -> SelectorConfig:
def selector_config_from_data(selectors_data: Dict[str, Any]) -> SelectorConfig:
if not selectors_data:
selectors_data = {'selectors': []}
selectors_data = {"selectors": []}
try:
selectors = SelectorConfig.selectors_from_dict(selectors_data)
except ValidationError as e:
raise DbtSelectorsError(
MALFORMED_SELECTOR_ERROR.format(error=str(e.message)),
result_type='invalid_selector',
result_type="invalid_selector",
) from e
return selectors
@@ -125,7 +120,6 @@ def selector_config_from_data(
# be necessary to make changes here. Ideally it would be
# good to combine the two flows into one at some point.
class SelectorDict:
@classmethod
def parse_dict_definition(cls, definition):
key = list(definition)[0]
@@ -136,10 +130,10 @@ class SelectorDict:
new_value = cls.parse_from_definition(sel_def)
new_values.append(new_value)
value = new_values
if key == 'exclude':
if key == "exclude":
definition = {key: value}
elif len(definition) == 1:
definition = {'method': key, 'value': value}
definition = {"method": key, "value": value}
return definition
@classmethod
@@ -161,10 +155,10 @@ class SelectorDict:
def parse_from_definition(cls, definition):
if isinstance(definition, str):
definition = SelectionCriteria.dict_from_single_spec(definition)
elif 'union' in definition:
definition = cls.parse_a_definition('union', definition)
elif 'intersection' in definition:
definition = cls.parse_a_definition('intersection', definition)
elif "union" in definition:
definition = cls.parse_a_definition("union", definition)
elif "intersection" in definition:
definition = cls.parse_a_definition("intersection", definition)
elif isinstance(definition, dict):
definition = cls.parse_dict_definition(definition)
return definition
@@ -175,8 +169,8 @@ class SelectorDict:
def parse_from_selectors_list(cls, selectors):
selector_dict = {}
for selector in selectors:
sel_name = selector['name']
sel_name = selector["name"]
selector_dict[sel_name] = selector
definition = cls.parse_from_definition(selector['definition'])
selector_dict[sel_name]['definition'] = definition
definition = cls.parse_from_definition(selector["definition"])
selector_dict[sel_name]["definition"] = definition
return selector_dict
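An illustrative trace of the conversion above (the selector spec is assumed; the exact parsed shape of each criterion comes from SelectionCriteria and is not shown in this hunk):

selectors = [{"name": "nightly", "definition": {"union": ["tag:nightly", "my_model"]}}]
# parse_from_selectors_list keys the result by name and parses the definition recursively:
# {"nightly": {"name": "nightly",
#              "definition": {"union": [<parsed "tag:nightly">, <parsed "my_model">]}}}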

View File

@@ -15,9 +15,8 @@ def parse_cli_vars(var_string: str) -> Dict[str, Any]:
type_name = var_type.__name__
raise_compiler_error(
"The --vars argument must be a YAML dictionary, but was "
"of type '{}'".format(type_name))
"of type '{}'".format(type_name)
)
except ValidationException:
logger.error(
"The YAML provided in the --vars argument is not valid.\n"
)
logger.error("The YAML provided in the --vars argument is not valid.\n")
raise
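For orientation, a sketch of the contract enforced above (example strings assumed; behavior inferred from the error handling shown in this hunk):

parse_cli_vars("{my_var: 42}")   # -> {"my_var": 42}
parse_cli_vars("[1, 2, 3]")      # raises: the value parses to a 'list', not a YAML dictionary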

View File

@@ -1,14 +1,16 @@
import json
import os
from typing import (
Any, Dict, NoReturn, Optional, Mapping
)
from typing import Any, Dict, NoReturn, Optional, Mapping
from dbt import flags
from dbt import tracking
from dbt.clients.jinja import undefined_error, get_rendered
from dbt.clients.yaml_helper import ( # noqa: F401
yaml, safe_load, SafeLoader, Loader, Dumper
yaml,
safe_load,
SafeLoader,
Loader,
Dumper,
)
from dbt.contracts.graph.compiled import CompiledResource
from dbt.exceptions import raise_compiler_error, MacroReturn
@@ -25,38 +27,26 @@ import re
def get_pytz_module_context() -> Dict[str, Any]:
context_exports = pytz.__all__ # type: ignore
return {
name: getattr(pytz, name) for name in context_exports
}
return {name: getattr(pytz, name) for name in context_exports}
def get_datetime_module_context() -> Dict[str, Any]:
context_exports = [
'date',
'datetime',
'time',
'timedelta',
'tzinfo'
]
context_exports = ["date", "datetime", "time", "timedelta", "tzinfo"]
return {
name: getattr(datetime, name) for name in context_exports
}
return {name: getattr(datetime, name) for name in context_exports}
def get_re_module_context() -> Dict[str, Any]:
context_exports = re.__all__
return {
name: getattr(re, name) for name in context_exports
}
return {name: getattr(re, name) for name in context_exports}
def get_context_modules() -> Dict[str, Dict[str, Any]]:
return {
'pytz': get_pytz_module_context(),
'datetime': get_datetime_module_context(),
're': get_re_module_context(),
"pytz": get_pytz_module_context(),
"datetime": get_datetime_module_context(),
"re": get_re_module_context(),
}
@@ -90,8 +80,8 @@ class ContextMeta(type):
new_dct = {}
for base in bases:
context_members.update(getattr(base, '_context_members_', {}))
context_attrs.update(getattr(base, '_context_attrs_', {}))
context_members.update(getattr(base, "_context_members_", {}))
context_attrs.update(getattr(base, "_context_attrs_", {}))
for key, value in dct.items():
if isinstance(value, ContextMember):
@@ -100,21 +90,22 @@ class ContextMeta(type):
context_attrs[context_key] = key
value = value.inner
new_dct[key] = value
new_dct['_context_members_'] = context_members
new_dct['_context_attrs_'] = context_attrs
new_dct["_context_members_"] = context_members
new_dct["_context_attrs_"] = context_attrs
return type.__new__(mcls, name, bases, new_dct)
class Var:
UndefinedVarError = "Required var '{}' not found in config:\nVars "\
"supplied to {} = {}"
UndefinedVarError = (
"Required var '{}' not found in config:\nVars " "supplied to {} = {}"
)
_VAR_NOTSET = object()
def __init__(
self,
context: Mapping[str, Any],
cli_vars: Mapping[str, Any],
node: Optional[CompiledResource] = None
node: Optional[CompiledResource] = None,
) -> None:
self._context: Mapping[str, Any] = context
self._cli_vars: Mapping[str, Any] = cli_vars
@@ -129,14 +120,12 @@ class Var:
if self._node is not None:
return self._node.name
else:
return '<Configuration>'
return "<Configuration>"
def get_missing_var(self, var_name):
dct = {k: self._merged[k] for k in self._merged}
pretty_vars = json.dumps(dct, sort_keys=True, indent=4)
msg = self.UndefinedVarError.format(
var_name, self.node_name, pretty_vars
)
msg = self.UndefinedVarError.format(var_name, self.node_name, pretty_vars)
raise_compiler_error(msg, self._node)
def has_var(self, var_name: str):
@@ -167,7 +156,7 @@ class BaseContext(metaclass=ContextMeta):
def generate_builtins(self):
builtins: Dict[str, Any] = {}
for key, value in self._context_members_.items():
if hasattr(value, '__get__'):
if hasattr(value, "__get__"):
# handle properties, bound methods, etc
value = value.__get__(self)
builtins[key] = value
@@ -175,9 +164,9 @@ class BaseContext(metaclass=ContextMeta):
# no dbtClassMixin so this is not an actual override
def to_dict(self):
self._ctx['context'] = self._ctx
self._ctx["context"] = self._ctx
builtins = self.generate_builtins()
self._ctx['builtins'] = builtins
self._ctx["builtins"] = builtins
self._ctx.update(builtins)
return self._ctx
@@ -286,18 +275,20 @@ class BaseContext(metaclass=ContextMeta):
msg = f"Env var required but not provided: '{var}'"
undefined_error(msg)
if os.environ.get('DBT_MACRO_DEBUGGING'):
if os.environ.get("DBT_MACRO_DEBUGGING"):
@contextmember
@staticmethod
def debug():
"""Enter a debugger at this line in the compiled jinja code."""
import sys
import ipdb # type: ignore
frame = sys._getframe(3)
ipdb.set_trace(frame)
return ''
return ""
@contextmember('return')
@contextmember("return")
@staticmethod
def _return(data: Any) -> NoReturn:
"""The `return` function can be used in macros to return data to the
@@ -348,9 +339,7 @@ class BaseContext(metaclass=ContextMeta):
@contextmember
@staticmethod
def tojson(
value: Any, default: Any = None, sort_keys: bool = False
) -> Any:
def tojson(value: Any, default: Any = None, sort_keys: bool = False) -> Any:
"""The `tojson` context method can be used to serialize a Python
object primitive, eg. a `dict` or `list` to a json string.
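A minimal sketch of the behavior the docstring describes (the fallback-to-default handling is assumed here, not shown in this hunk):

import json

def tojson_sketch(value, default=None, sort_keys=False):
    # Serialize a primitive (dict, list, str, ...) to a JSON string;
    # return `default` if the value cannot be serialized.
    try:
        return json.dumps(value, sort_keys=sort_keys)
    except (TypeError, ValueError):
        return default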
@@ -446,7 +435,7 @@ class BaseContext(metaclass=ContextMeta):
logger.info(msg)
else:
logger.debug(msg)
return ''
return ""
@contextproperty
def run_started_at(self) -> Optional[datetime.datetime]:

View File

@@ -4,16 +4,14 @@ from dbt.contracts.connection import AdapterRequiredConfig
from dbt.node_types import NodeType
from dbt.utils import MultiDict
from dbt.context.base import contextproperty, Var
from dbt.context import contextproperty, Var
from dbt.context.target import TargetContext
class ConfiguredContext(TargetContext):
config: AdapterRequiredConfig
def __init__(
self, config: AdapterRequiredConfig
) -> None:
def __init__(self, config: AdapterRequiredConfig) -> None:
super().__init__(config, config.cli_vars)
@contextproperty
@@ -70,9 +68,7 @@ class SchemaYamlContext(ConfiguredContext):
@contextproperty
def var(self) -> ConfiguredVar:
return ConfiguredVar(
self._ctx, self.config, self._project_name
)
return ConfiguredVar(self._ctx, self.config, self._project_name)
def generate_schema_yml(

View File

@@ -17,8 +17,8 @@ class ModelParts(IsFQNResource):
package_name: str
T = TypeVar('T') # any old type
C = TypeVar('C', bound=BaseConfig)
T = TypeVar("T") # any old type
C = TypeVar("C", bound=BaseConfig)
class ConfigSource:
@@ -36,13 +36,13 @@ class UnrenderedConfig(ConfigSource):
def get_config_dict(self, resource_type: NodeType) -> Dict[str, Any]:
unrendered = self.project.unrendered.project_dict
if resource_type == NodeType.Seed:
model_configs = unrendered.get('seeds')
model_configs = unrendered.get("seeds")
elif resource_type == NodeType.Snapshot:
model_configs = unrendered.get('snapshots')
model_configs = unrendered.get("snapshots")
elif resource_type == NodeType.Source:
model_configs = unrendered.get('sources')
model_configs = unrendered.get("sources")
else:
model_configs = unrendered.get('models')
model_configs = unrendered.get("models")
if model_configs is None:
return {}
@@ -79,8 +79,8 @@ class BaseContextConfigGenerator(Generic[T]):
dependencies = self._active_project.load_dependencies()
if project_name not in dependencies:
raise InternalException(
f'Project name {project_name} not found in dependencies '
f'(found {list(dependencies)})'
f"Project name {project_name} not found in dependencies "
f"(found {list(dependencies)})"
)
return dependencies[project_name]
@@ -92,7 +92,7 @@ class BaseContextConfigGenerator(Generic[T]):
for level_config in fqn_search(model_configs, fqn):
result = {}
for key, value in level_config.items():
if key.startswith('+'):
if key.startswith("+"):
result[key[1:]] = deepcopy(value)
elif not isinstance(value, dict):
result[key] = deepcopy(value)
@@ -171,13 +171,9 @@ class ContextConfigGenerator(BaseContextConfigGenerator[C]):
def _update_from_config(
self, result: C, partial: Dict[str, Any], validate: bool = False
) -> C:
translated = self._active_project.credentials.translate_aliases(
partial
)
translated = self._active_project.credentials.translate_aliases(partial)
return result.update_from(
translated,
self._active_project.credentials.type,
validate=validate
translated, self._active_project.credentials.type, validate=validate
)
def calculate_node_config_dict(
@@ -219,11 +215,7 @@ class UnrenderedConfigGenerator(BaseContextConfigGenerator[Dict[str, Any]]):
base=base,
)
def initial_result(
self,
resource_type: NodeType,
base: bool
) -> Dict[str, Any]:
def initial_result(self, resource_type: NodeType, base: bool) -> Dict[str, Any]:
return {}
def _update_from_config(
@@ -232,9 +224,7 @@ class UnrenderedConfigGenerator(BaseContextConfigGenerator[Dict[str, Any]]):
partial: Dict[str, Any],
validate: bool = False,
) -> Dict[str, Any]:
translated = self._active_project.credentials.translate_aliases(
partial
)
translated = self._active_project.credentials.translate_aliases(partial)
result.update(translated)
return result

View File

@@ -1,6 +1,4 @@
from typing import (
Any, Dict, Union
)
from typing import Any, Dict, Union
from dbt.exceptions import (
doc_invalid_args,
@@ -11,7 +9,7 @@ from dbt.contracts.graph.compiled import CompileResultNode
from dbt.contracts.graph.manifest import Manifest
from dbt.contracts.graph.parsed import ParsedMacro
from dbt.context.base import contextmember
from dbt.context import contextmember
from dbt.context.configured import SchemaYamlContext

View File

@@ -1,6 +1,4 @@
from typing import (
Dict, MutableMapping, Optional
)
from typing import Dict, MutableMapping, Optional
from dbt.contracts.graph.parsed import ParsedMacro
from dbt.exceptions import raise_duplicate_macro_name, raise_compiler_error
from dbt.include.global_project import PROJECT_NAME as GLOBAL_PROJECT_NAME
@@ -45,8 +43,7 @@ class MacroResolver:
for pkg in reversed(self.internal_package_names):
if pkg in self.internal_packages:
# Turn the internal packages into a flat namespace
self.internal_packages_namespace.update(
self.internal_packages[pkg])
self.internal_packages_namespace.update(self.internal_packages[pkg])
def _build_macros_by_name(self):
macros_by_name = {}
@@ -74,9 +71,7 @@ class MacroResolver:
package_namespaces[macro.package_name] = namespace
if macro.name in namespace:
raise_duplicate_macro_name(
macro, macro, macro.package_name
)
raise_duplicate_macro_name(macro, macro, macro.package_name)
package_namespaces[macro.package_name][macro.name] = macro
def add_macro(self, macro: ParsedMacro):
@@ -99,8 +94,10 @@ class MacroResolver:
def get_macro_id(self, local_package, macro_name):
local_package_macros = {}
if (local_package not in self.internal_package_names and
local_package in self.packages):
if (
local_package not in self.internal_package_names
and local_package in self.packages
):
local_package_macros = self.packages[local_package]
# First: search the local packages for this macro
if macro_name in local_package_macros:
@@ -117,9 +114,7 @@ class MacroResolver:
# is that you can limit the number of macros provided to the
# context dictionary in the 'to_dict' manifest method.
class TestMacroNamespace:
def __init__(
self, macro_resolver, ctx, node, thread_ctx, depends_on_macros
):
def __init__(self, macro_resolver, ctx, node, thread_ctx, depends_on_macros):
self.macro_resolver = macro_resolver
self.ctx = ctx
self.node = node
@@ -129,7 +124,10 @@ class TestMacroNamespace:
for macro_unique_id in depends_on_macros:
macro = self.manifest.macros[macro_unique_id]
local_namespace[macro.name] = MacroGenerator(
macro, self.ctx, self.node, self.thread_ctx,
macro,
self.ctx,
self.node,
self.thread_ctx,
)
self.local_namespace = local_namespace
@@ -144,10 +142,6 @@ class TestMacroNamespace:
elif package_name in self.resolver.packages:
macro = self.macro_resolver.packages[package_name].get(name)
else:
raise_compiler_error(
f"Could not find package '{package_name}'"
)
macro_func = MacroGenerator(
macro, self.ctx, self.node, self.thread_ctx
)
raise_compiler_error(f"Could not find package '{package_name}'")
macro_func = MacroGenerator(macro, self.ctx, self.node, self.thread_ctx)
return macro_func

View File

@@ -1,13 +1,9 @@
from typing import (
Any, Dict, Iterable, Union, Optional, List, Iterator, Mapping, Set
)
from typing import Any, Dict, Iterable, Union, Optional, List, Iterator, Mapping, Set
from dbt.clients.jinja import MacroGenerator, MacroStack
from dbt.contracts.graph.parsed import ParsedMacro
from dbt.include.global_project import PROJECT_NAME as GLOBAL_PROJECT_NAME
from dbt.exceptions import (
raise_duplicate_macro_name, raise_compiler_error
)
from dbt.exceptions import raise_duplicate_macro_name, raise_compiler_error
FlatNamespace = Dict[str, MacroGenerator]
@@ -75,9 +71,7 @@ class MacroNamespace(Mapping):
elif package_name in self.packages:
return self.packages[package_name].get(name)
else:
raise_compiler_error(
f"Could not find package '{package_name}'"
)
raise_compiler_error(f"Could not find package '{package_name}'")
# This class builds the MacroNamespace by adding macros to
@@ -122,9 +116,7 @@ class MacroNamespaceBuilder:
hierarchy[macro.package_name] = namespace
if macro.name in namespace:
raise_duplicate_macro_name(
macro_func.macro, macro, macro.package_name
)
raise_duplicate_macro_name(macro_func.macro, macro, macro.package_name)
hierarchy[macro.package_name][macro.name] = macro_func
def add_macro(self, macro: ParsedMacro, ctx: Dict[str, Any]):

View File

@@ -17,6 +17,7 @@ class ManifestContext(ConfiguredContext):
The given macros can override any previous context values, which will be
available as if they were accessed relative to the package name.
"""
def __init__(
self,
config: AdapterRequiredConfig,
@@ -37,13 +38,12 @@ class ManifestContext(ConfiguredContext):
# this takes all the macros in the manifest and adds them
# to the MacroNamespaceBuilder stored in self.namespace
builder = self._get_namespace_builder()
return builder.build_namespace(
self.manifest.macros.values(), self._ctx
)
return builder.build_namespace(self.manifest.macros.values(), self._ctx)
def _get_namespace_builder(self) -> MacroNamespaceBuilder:
# avoid an import loop
from dbt.adapters.factory import get_adapter_package_names
internal_packages: List[str] = get_adapter_package_names(
self.config.credentials.type
)
@@ -68,14 +68,10 @@ class ManifestContext(ConfiguredContext):
class QueryHeaderContext(ManifestContext):
def __init__(
self, config: AdapterRequiredConfig, manifest: Manifest
) -> None:
def __init__(self, config: AdapterRequiredConfig, manifest: Manifest) -> None:
super().__init__(config, manifest, config.project_name)
def generate_query_header_context(
config: AdapterRequiredConfig, manifest: Manifest
):
def generate_query_header_context(config: AdapterRequiredConfig, manifest: Manifest):
ctx = QueryHeaderContext(config, manifest)
return ctx.to_dict()

View File

@@ -1,7 +1,15 @@
import abc
import os
from typing import (
Callable, Any, Dict, Optional, Union, List, TypeVar, Type, Iterable,
Callable,
Any,
Dict,
Optional,
Union,
List,
TypeVar,
Type,
Iterable,
Mapping,
)
from typing_extensions import Protocol
@@ -12,16 +20,14 @@ from dbt.adapters.factory import get_adapter, get_adapter_package_names
from dbt.clients import agate_helper
from dbt.clients.jinja import get_rendered, MacroGenerator, MacroStack
from dbt.config import RuntimeConfig, Project
from .base import contextmember, contextproperty, Var
from dbt.context import contextmember, contextproperty, Var
from .configured import FQNLookup
from .context_config import ContextConfig
from dbt.context.macro_resolver import MacroResolver, TestMacroNamespace
from .macros import MacroNamespaceBuilder, MacroNamespace
from .manifest import ManifestContext
from dbt.contracts.connection import AdapterResponse
from dbt.contracts.graph.manifest import (
Manifest, AnyManifest, Disabled, MacroManifest
)
from dbt.contracts.graph.manifest import Manifest, AnyManifest, Disabled, MacroManifest
from dbt.contracts.graph.compiled import (
CompiledResource,
CompiledSeedNode,
@@ -50,9 +56,7 @@ from dbt.config import IsFQNResource
from dbt.logger import GLOBAL_LOGGER as logger # noqa
from dbt.node_types import NodeType
from dbt.utils import (
merge, AttrDict, MultiDict
)
from dbt.utils import merge, AttrDict, MultiDict
import agate
@@ -75,9 +79,8 @@ class RelationProxy:
return self._relation_type.create_from_source(*args, **kwargs)
def create(self, *args, **kwargs):
kwargs['quote_policy'] = merge(
self._quoting_config,
kwargs.pop('quote_policy', {})
kwargs["quote_policy"] = merge(
self._quoting_config, kwargs.pop("quote_policy", {})
)
return self._relation_type.create(*args, **kwargs)
@@ -94,7 +97,7 @@ class BaseDatabaseWrapper:
self._namespace = namespace
def __getattr__(self, name):
raise NotImplementedError('subclasses need to implement this')
raise NotImplementedError("subclasses need to implement this")
@property
def config(self):
@@ -110,7 +113,7 @@ class BaseDatabaseWrapper:
# a future version of this could have plugins automatically call fall
# back to their dependencies' dependencies by using
# `get_adapter_type_names` instead of `[self.config.credentials.type]`
search_prefixes = [self._adapter.type(), 'default']
search_prefixes = [self._adapter.type(), "default"]
return search_prefixes
def dispatch(
@@ -118,8 +121,8 @@ class BaseDatabaseWrapper:
) -> MacroGenerator:
search_packages: List[Optional[str]]
if '.' in macro_name:
suggest_package, suggest_macro_name = macro_name.split('.', 1)
if "." in macro_name:
suggest_package, suggest_macro_name = macro_name.split(".", 1)
msg = (
f'In adapter.dispatch, got a macro name of "{macro_name}", '
f'but "." is not a valid macro name component. Did you mean '
@@ -132,7 +135,7 @@ class BaseDatabaseWrapper:
search_packages = [None]
elif isinstance(packages, str):
raise CompilationException(
f'In adapter.dispatch, got a string packages argument '
f"In adapter.dispatch, got a string packages argument "
f'("{packages}"), but packages should be None or a list.'
)
else:
@@ -142,26 +145,24 @@ class BaseDatabaseWrapper:
for package_name in search_packages:
for prefix in self._get_adapter_macro_prefixes():
search_name = f'{prefix}__{macro_name}'
search_name = f"{prefix}__{macro_name}"
try:
# this uses the namespace from the context
macro = self._namespace.get_from_package(
package_name, search_name
)
macro = self._namespace.get_from_package(package_name, search_name)
except CompilationException as exc:
raise CompilationException(
f'In dispatch: {exc.msg}',
f"In dispatch: {exc.msg}",
) from exc
if package_name is None:
attempts.append(search_name)
else:
attempts.append(f'{package_name}.{search_name}')
attempts.append(f"{package_name}.{search_name}")
if macro is not None:
return macro
searched = ', '.join(repr(a) for a in attempts)
searched = ", ".join(repr(a) for a in attempts)
msg = (
f"In dispatch: No macro named '{macro_name}' found\n"
f" Searched for: {searched}"
@@ -191,14 +192,10 @@ class BaseResolver(metaclass=abc.ABCMeta):
class BaseRefResolver(BaseResolver):
@abc.abstractmethod
def resolve(
self, name: str, package: Optional[str] = None
) -> RelationProxy:
def resolve(self, name: str, package: Optional[str] = None) -> RelationProxy:
...
def _repack_args(
self, name: str, package: Optional[str]
) -> List[str]:
def _repack_args(self, name: str, package: Optional[str]) -> List[str]:
if package is None:
return [name]
else:
@@ -207,14 +204,13 @@ class BaseRefResolver(BaseResolver):
def validate_args(self, name: str, package: Optional[str]):
if not isinstance(name, str):
raise CompilationException(
f'The name argument to ref() must be a string, got '
f'{type(name)}'
f"The name argument to ref() must be a string, got " f"{type(name)}"
)
if package is not None and not isinstance(package, str):
raise CompilationException(
f'The package argument to ref() must be a string or None, got '
f'{type(package)}'
f"The package argument to ref() must be a string or None, got "
f"{type(package)}"
)
def __call__(self, *args: str) -> RelationProxy:
@@ -239,20 +235,19 @@ class BaseSourceResolver(BaseResolver):
def validate_args(self, source_name: str, table_name: str):
if not isinstance(source_name, str):
raise CompilationException(
f'The source name (first) argument to source() must be a '
f'string, got {type(source_name)}'
f"The source name (first) argument to source() must be a "
f"string, got {type(source_name)}"
)
if not isinstance(table_name, str):
raise CompilationException(
f'The table name (second) argument to source() must be a '
f'string, got {type(table_name)}'
f"The table name (second) argument to source() must be a "
f"string, got {type(table_name)}"
)
def __call__(self, *args: str) -> RelationProxy:
if len(args) != 2:
raise_compiler_error(
f"source() takes exactly two arguments ({len(args)} given)",
self.model
f"source() takes exactly two arguments ({len(args)} given)", self.model
)
self.validate_args(args[0], args[1])
return self.resolve(args[0], args[1])
@@ -270,14 +265,15 @@ class ParseConfigObject(Config):
self.context_config = context_config
def _transform_config(self, config):
for oldkey in ('pre_hook', 'post_hook'):
for oldkey in ("pre_hook", "post_hook"):
if oldkey in config:
newkey = oldkey.replace('_', '-')
newkey = oldkey.replace("_", "-")
if newkey in config:
raise_compiler_error(
'Invalid config, has conflicting keys "{}" and "{}"'
.format(oldkey, newkey),
self.model
'Invalid config, has conflicting keys "{}" and "{}"'.format(
oldkey, newkey
),
self.model,
)
config[newkey] = config.pop(oldkey)
return config
@@ -288,29 +284,25 @@ class ParseConfigObject(Config):
elif len(args) == 0 and len(kwargs) > 0:
opts = kwargs
else:
raise_compiler_error(
"Invalid inline model config",
self.model)
raise_compiler_error("Invalid inline model config", self.model)
opts = self._transform_config(opts)
# it's ok to have a parse context with no context config, but you must
# not call it!
if self.context_config is None:
raise RuntimeException(
'At parse time, did not receive a context config'
)
raise RuntimeException("At parse time, did not receive a context config")
self.context_config.update_in_model_config(opts)
return ''
return ""
def set(self, name, value):
return self.__call__({name: value})
def require(self, name, validator=None):
return ''
return ""
def get(self, name, validator=None, default=None):
return ''
return ""
def persist_relation_docs(self) -> bool:
return False
@@ -320,14 +312,12 @@ class ParseConfigObject(Config):
class RuntimeConfigObject(Config):
def __init__(
self, model, context_config: Optional[ContextConfig] = None
):
def __init__(self, model, context_config: Optional[ContextConfig] = None):
self.model = model
# we never use or get a config, only the parser cares
def __call__(self, *args, **kwargs):
return ''
return ""
def set(self, name, value):
return self.__call__({name: value})
@@ -337,7 +327,7 @@ class RuntimeConfigObject(Config):
def _lookup(self, name, default=_MISSING):
# if this is a macro, there might be no `model.config`.
if not hasattr(self.model, 'config'):
if not hasattr(self.model, "config"):
result = default
else:
result = self.model.config.get(name, default)
@@ -362,22 +352,24 @@ class RuntimeConfigObject(Config):
return to_return
def persist_relation_docs(self) -> bool:
persist_docs = self.get('persist_docs', default={})
persist_docs = self.get("persist_docs", default={})
if not isinstance(persist_docs, dict):
raise_compiler_error(
f"Invalid value provided for 'persist_docs'. Expected dict "
f"but received {type(persist_docs)}")
f"but received {type(persist_docs)}"
)
return persist_docs.get('relation', False)
return persist_docs.get("relation", False)
def persist_column_docs(self) -> bool:
persist_docs = self.get('persist_docs', default={})
persist_docs = self.get("persist_docs", default={})
if not isinstance(persist_docs, dict):
raise_compiler_error(
f"Invalid value provided for 'persist_docs'. Expected dict "
f"but received {type(persist_docs)}")
f"but received {type(persist_docs)}"
)
return persist_docs.get('columns', False)
return persist_docs.get("columns", False)
# `adapter` implementations
@@ -387,8 +379,10 @@ class ParseDatabaseWrapper(BaseDatabaseWrapper):
"""
def __getattr__(self, name):
override = (name in self._adapter._available_ and
name in self._adapter._parse_replacements_)
override = (
name in self._adapter._available_
and name in self._adapter._parse_replacements_
)
if override:
return self._adapter._parse_replacements_[name]
@@ -420,9 +414,7 @@ class RuntimeDatabaseWrapper(BaseDatabaseWrapper):
# `ref` implementations
class ParseRefResolver(BaseRefResolver):
def resolve(
self, name: str, package: Optional[str] = None
) -> RelationProxy:
def resolve(self, name: str, package: Optional[str] = None) -> RelationProxy:
self.model.refs.append(self._repack_args(name, package))
return self.Relation.create_from(self.config, self.model)
@@ -452,22 +444,15 @@ class RuntimeRefResolver(BaseRefResolver):
self.validate(target_model, target_name, target_package)
return self.create_relation(target_model, target_name)
def create_relation(
self, target_model: ManifestNode, name: str
) -> RelationProxy:
def create_relation(self, target_model: ManifestNode, name: str) -> RelationProxy:
if target_model.is_ephemeral_model:
self.model.set_cte(target_model.unique_id, None)
return self.Relation.create_ephemeral_from_node(
self.config, target_model
)
return self.Relation.create_ephemeral_from_node(self.config, target_model)
else:
return self.Relation.create_from(self.config, target_model)
def validate(
self,
resolved: ManifestNode,
target_name: str,
target_package: Optional[str]
self, resolved: ManifestNode, target_name: str, target_package: Optional[str]
) -> None:
if resolved.unique_id not in self.model.depends_on.nodes:
args = self._repack_args(target_name, target_package)
@@ -483,16 +468,15 @@ class OperationRefResolver(RuntimeRefResolver):
) -> None:
pass
def create_relation(
self, target_model: ManifestNode, name: str
) -> RelationProxy:
def create_relation(self, target_model: ManifestNode, name: str) -> RelationProxy:
if target_model.is_ephemeral_model:
# In operations, we can't ref() ephemeral nodes, because
# ParsedMacros do not support set_cte
raise_compiler_error(
'Operations can not ref() ephemeral nodes, but {} is ephemeral'
.format(target_model.name),
self.model
"Operations can not ref() ephemeral nodes, but {} is ephemeral".format(
target_model.name
),
self.model,
)
else:
return super().create_relation(target_model, name)
@@ -544,8 +528,7 @@ class ModelConfiguredVar(Var):
if package_name not in dependencies:
# I don't think this is actually reachable
raise_compiler_error(
f'Node package named {package_name} not found!',
self._node
f"Node package named {package_name} not found!", self._node
)
yield dependencies[package_name]
yield self._config
@@ -617,7 +600,7 @@ class OperationProvider(RuntimeProvider):
ref = OperationRefResolver
T = TypeVar('T')
T = TypeVar("T")
# Base context collection, used for parsing configs.
@@ -631,9 +614,7 @@ class ProviderContext(ManifestContext):
context_config: Optional[ContextConfig],
) -> None:
if provider is None:
raise InternalException(
f"Invalid provider given to context: {provider}"
)
raise InternalException(f"Invalid provider given to context: {provider}")
# mypy appeasement - we know it'll be a RuntimeConfig
self.config: RuntimeConfig
self.model: Union[ParsedMacro, ManifestNode] = model
@@ -643,16 +624,12 @@ class ProviderContext(ManifestContext):
self.provider: Provider = provider
self.adapter = get_adapter(self.config)
# The macro namespace is used in creating the DatabaseWrapper
self.db_wrapper = self.provider.DatabaseWrapper(
self.adapter, self.namespace
)
self.db_wrapper = self.provider.DatabaseWrapper(self.adapter, self.namespace)
# This overrides the method in ManifestContext, and provides
# a model, which the ManifestContext builder does not
def _get_namespace_builder(self):
internal_packages = get_adapter_package_names(
self.config.credentials.type
)
internal_packages = get_adapter_package_names(self.config.credentials.type)
return MacroNamespaceBuilder(
self.config.project_name,
self.search_package,
@@ -671,19 +648,19 @@ class ProviderContext(ManifestContext):
@contextmember
def store_result(
self, name: str,
response: Any,
agate_table: Optional[agate.Table] = None
self, name: str, response: Any, agate_table: Optional[agate.Table] = None
) -> str:
if agate_table is None:
agate_table = agate_helper.empty_table()
self.sql_results[name] = AttrDict({
'response': response,
'data': agate_helper.as_matrix(agate_table),
'table': agate_table
})
return ''
self.sql_results[name] = AttrDict(
{
"response": response,
"data": agate_helper.as_matrix(agate_table),
"table": agate_table,
}
)
return ""
@contextmember
def store_raw_result(
@@ -692,10 +669,11 @@ class ProviderContext(ManifestContext):
message=Optional[str],
code=Optional[str],
rows_affected=Optional[str],
agate_table: Optional[agate.Table] = None
agate_table: Optional[agate.Table] = None,
) -> str:
response = AdapterResponse(
_message=message, code=code, rows_affected=rows_affected)
_message=message, code=code, rows_affected=rows_affected
)
return self.store_result(name, response, agate_table)
@contextproperty
@@ -708,25 +686,28 @@ class ProviderContext(ManifestContext):
elif value == arg:
return
raise ValidationException(
'Expected value "{}" to be one of {}'
.format(value, ','.join(map(str, args))))
'Expected value "{}" to be one of {}'.format(
value, ",".join(map(str, args))
)
)
return inner
return AttrDict({
'any': validate_any,
})
return AttrDict(
{
"any": validate_any,
}
)
@contextmember
def write(self, payload: str) -> str:
# macros/source defs aren't 'writeable'.
if isinstance(self.model, (ParsedMacro, ParsedSourceDefinition)):
raise_compiler_error(
'cannot "write" macros or sources'
)
raise_compiler_error('cannot "write" macros or sources')
self.model.build_path = self.model.write_node(
self.config.target_path, 'run', payload
self.config.target_path, "run", payload
)
return ''
return ""
@contextmember
def render(self, string: str) -> str:
@@ -739,20 +720,17 @@ class ProviderContext(ManifestContext):
try:
return func(*args, **kwargs)
except Exception:
raise_compiler_error(
message_if_exception, self.model
)
raise_compiler_error(message_if_exception, self.model)
@contextmember
def load_agate_table(self) -> agate.Table:
if not isinstance(self.model, (ParsedSeedNode, CompiledSeedNode)):
raise_compiler_error(
'can only load_agate_table for seeds (got a {})'
.format(self.model.resource_type)
"can only load_agate_table for seeds (got a {})".format(
self.model.resource_type
)
)
path = os.path.join(
self.model.root_path, self.model.original_file_path
)
path = os.path.join(self.model.root_path, self.model.original_file_path)
column_types = self.model.config.column_types
try:
table = agate_helper.from_csv(path, text_columns=column_types)
@@ -810,7 +788,7 @@ class ProviderContext(ManifestContext):
self.db_wrapper, self.model, self.config, self.manifest
)
@contextproperty('config')
@contextproperty("config")
def ctx_config(self) -> Config:
"""The `config` variable exists to handle end-user configuration for
custom materializations. Configs like `unique_key` can be implemented
@@ -982,7 +960,7 @@ class ProviderContext(ManifestContext):
node=self.model,
)
@contextproperty('adapter')
@contextproperty("adapter")
def ctx_adapter(self) -> BaseDatabaseWrapper:
"""`adapter` is a wrapper around the internal database adapter used by
dbt. It allows users to make calls to the database in their dbt models.
@@ -994,8 +972,8 @@ class ProviderContext(ManifestContext):
@contextproperty
def api(self) -> Dict[str, Any]:
return {
'Relation': self.db_wrapper.Relation,
'Column': self.adapter.Column,
"Relation": self.db_wrapper.Relation,
"Column": self.adapter.Column,
}
@contextproperty
@@ -1113,7 +1091,7 @@ class ProviderContext(ManifestContext):
""" # noqa
return self.manifest.flat_graph
@contextproperty('model')
@contextproperty("model")
def ctx_model(self) -> Dict[str, Any]:
return self.model.to_dict(omit_none=True)
@@ -1177,22 +1155,20 @@ class ProviderContext(ManifestContext):
...
{%- endmacro %}
"""
deprecations.warn('adapter-macro', macro_name=name)
deprecations.warn("adapter-macro", macro_name=name)
original_name = name
package_names: Optional[List[str]] = None
if '.' in name:
package_name, name = name.split('.', 1)
if "." in name:
package_name, name = name.split(".", 1)
package_names = [package_name]
try:
macro = self.db_wrapper.dispatch(
macro_name=name, packages=package_names
)
macro = self.db_wrapper.dispatch(macro_name=name, packages=package_names)
except CompilationException as exc:
raise CompilationException(
f'In adapter_macro: {exc.msg}\n'
f"In adapter_macro: {exc.msg}\n"
f" Original name: '{original_name}'",
node=self.model
node=self.model,
) from exc
return macro(*args, **kwargs)
@@ -1200,10 +1176,10 @@ class ProviderContext(ManifestContext):
class MacroContext(ProviderContext):
"""Internally, macros can be executed like nodes, with some restrictions:
- they don't have all values available that nodes do:
- 'this', 'pre_hooks', 'post_hooks', and 'sql' are missing
- 'schema' does not use any 'model' information
- they can't be configured with config() directives
- they don't have all values available that nodes do:
- 'this', 'pre_hooks', 'post_hooks', and 'sql' are missing
- 'schema' does not use any 'model' information
- they can't be configured with config() directives
"""
def __init__(
@@ -1230,35 +1206,27 @@ class ModelContext(ProviderContext):
def pre_hooks(self) -> List[Dict[str, Any]]:
if isinstance(self.model, ParsedSourceDefinition):
return []
return [
h.to_dict(omit_none=True) for h in self.model.config.pre_hook
]
return [h.to_dict(omit_none=True) for h in self.model.config.pre_hook]
@contextproperty
def post_hooks(self) -> List[Dict[str, Any]]:
if isinstance(self.model, ParsedSourceDefinition):
return []
return [
h.to_dict(omit_none=True) for h in self.model.config.post_hook
]
return [h.to_dict(omit_none=True) for h in self.model.config.post_hook]
@contextproperty
def sql(self) -> Optional[str]:
if getattr(self.model, 'extra_ctes_injected', None):
if getattr(self.model, "extra_ctes_injected", None):
return self.model.compiled_sql
return None
@contextproperty
def database(self) -> str:
return getattr(
self.model, 'database', self.config.credentials.database
)
return getattr(self.model, "database", self.config.credentials.database)
@contextproperty
def schema(self) -> str:
return getattr(
self.model, 'schema', self.config.credentials.schema
)
return getattr(self.model, "schema", self.config.credentials.schema)
@contextproperty
def this(self) -> Optional[RelationProxy]:
@@ -1306,9 +1274,7 @@ def generate_parser_model(
# The __init__ method of ModelContext also initializes
# a ManifestContext object which creates a MacroNamespaceBuilder
# which adds every macro in the Manifest.
ctx = ModelContext(
model, config, manifest, ParseProvider(), context_config
)
ctx = ModelContext(model, config, manifest, ParseProvider(), context_config)
# The 'to_dict' method in ManifestContext moves all of the macro names
# in the macro 'namespace' up to top level keys
return ctx.to_dict()
@@ -1319,9 +1285,7 @@ def generate_generate_component_name_macro(
config: RuntimeConfig,
manifest: MacroManifest,
) -> Dict[str, Any]:
ctx = MacroContext(
macro, config, manifest, GenerateNameProvider(), None
)
ctx = MacroContext(macro, config, manifest, GenerateNameProvider(), None)
return ctx.to_dict()
@@ -1330,9 +1294,7 @@ def generate_runtime_model(
config: RuntimeConfig,
manifest: Manifest,
) -> Dict[str, Any]:
ctx = ModelContext(
model, config, manifest, RuntimeProvider(), None
)
ctx = ModelContext(model, config, manifest, RuntimeProvider(), None)
return ctx.to_dict()
@@ -1342,9 +1304,7 @@ def generate_runtime_macro(
manifest: Manifest,
package_name: Optional[str],
) -> Dict[str, Any]:
ctx = MacroContext(
macro, config, manifest, OperationProvider(), package_name
)
ctx = MacroContext(macro, config, manifest, OperationProvider(), package_name)
return ctx.to_dict()
@@ -1353,18 +1313,17 @@ class ExposureRefResolver(BaseResolver):
if len(args) not in (1, 2):
ref_invalid_args(self.model, args)
self.model.refs.append(list(args))
return ''
return ""
class ExposureSourceResolver(BaseResolver):
def __call__(self, *args) -> str:
if len(args) != 2:
raise_compiler_error(
f"source() takes exactly two arguments ({len(args)} given)",
self.model
f"source() takes exactly two arguments ({len(args)} given)", self.model
)
self.model.sources.append(list(args))
return ''
return ""
def generate_parse_exposure(
@@ -1375,18 +1334,18 @@ def generate_parse_exposure(
) -> Dict[str, Any]:
project = config.load_dependencies()[package_name]
return {
'ref': ExposureRefResolver(
"ref": ExposureRefResolver(
None,
exposure,
project,
manifest,
),
'source': ExposureSourceResolver(
"source": ExposureSourceResolver(
None,
exposure,
project,
manifest,
)
),
}
@@ -1422,8 +1381,7 @@ class TestContext(ProviderContext):
if self.model.depends_on and self.model.depends_on.macros:
depends_on_macros = self.model.depends_on.macros
macro_namespace = TestMacroNamespace(
self.macro_resolver, self.ctx, self.node, self.thread_ctx,
depends_on_macros
self.macro_resolver, self.ctx, self.node, self.thread_ctx, depends_on_macros
)
self._namespace = macro_namespace
@@ -1433,11 +1391,10 @@ def generate_test_context(
config: RuntimeConfig,
manifest: Manifest,
context_config: ContextConfig,
macro_resolver: MacroResolver
macro_resolver: MacroResolver,
) -> Dict[str, Any]:
ctx = TestContext(
model, config, manifest, ParseProvider(), context_config,
macro_resolver
model, config, manifest, ParseProvider(), context_config, macro_resolver
)
# The 'to_dict' method in ManifestContext moves all of the macro names
# in the macro 'namespace' up to top level keys

View File

@@ -2,9 +2,7 @@ from typing import Any, Dict
from dbt.contracts.connection import HasCredentials
from dbt.context.base import (
BaseContext, contextproperty
)
from dbt.context import BaseContext, contextproperty
class TargetContext(BaseContext):

View File

@@ -2,25 +2,35 @@ import abc
import itertools
from dataclasses import dataclass, field
from typing import (
Any, ClassVar, Dict, Tuple, Iterable, Optional, List, Callable,
Any,
ClassVar,
Dict,
Tuple,
Iterable,
Optional,
List,
Callable,
)
from dbt.exceptions import InternalException
from dbt.utils import translate_aliases
from dbt.logger import GLOBAL_LOGGER as logger
from typing_extensions import Protocol
from dbt.dataclass_schema import (
dbtClassMixin, StrEnum, ExtensibleDbtClassMixin,
ValidatedStringMixin, register_pattern
dbtClassMixin,
StrEnum,
ExtensibleDbtClassMixin,
ValidatedStringMixin,
register_pattern,
)
from dbt.contracts.util import Replaceable
class Identifier(ValidatedStringMixin):
ValidationRegex = r'^[A-Za-z_][A-Za-z0-9_]+$'
ValidationRegex = r"^[A-Za-z_][A-Za-z0-9_]+$"
# we need register_pattern for jsonschema validation
register_pattern(Identifier, r'^[A-Za-z_][A-Za-z0-9_]+$')
register_pattern(Identifier, r"^[A-Za-z_][A-Za-z0-9_]+$")
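The same pattern is registered for jsonschema validation and backs ValidatedStringMixin. A small sketch of what the regex accepts; note the `+` after the first character class, so single-character identifiers do not match:

import re

IDENTIFIER_RE = re.compile(r"^[A-Za-z_][A-Za-z0-9_]+$")

# Hypothetical inputs, chosen only to illustrate the pattern above.
assert IDENTIFIER_RE.match("my_profile")    # letters, digits, underscores are fine
assert IDENTIFIER_RE.match("_target2")      # may start with an underscore
assert not IDENTIFIER_RE.match("2fast")     # must not start with a digit
assert not IDENTIFIER_RE.match("x")         # '+' requires at least two characters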
@dataclass
@@ -34,10 +44,10 @@ class AdapterResponse(dbtClassMixin):
class ConnectionState(StrEnum):
INIT = 'init'
OPEN = 'open'
CLOSED = 'closed'
FAIL = 'fail'
INIT = "init"
OPEN = "open"
CLOSED = "closed"
FAIL = "fail"
@dataclass(init=False)
@@ -81,8 +91,7 @@ class Connection(ExtensibleDbtClassMixin, Replaceable):
self._handle.resolve(self)
except RecursionError as exc:
raise InternalException(
"A connection's open() method attempted to read the "
"handle value"
"A connection's open() method attempted to read the " "handle value"
) from exc
return self._handle
@@ -101,8 +110,7 @@ class LazyHandle:
def resolve(self, connection: Connection) -> Connection:
logger.debug(
'Opening a new connection, currently in state {}'
.format(connection.state)
"Opening a new connection, currently in state {}".format(connection.state)
)
return self.opener(connection)
@@ -112,33 +120,24 @@ class LazyHandle:
# for why we have type: ignore. Maybe someday dataclasses + abstract classes
# will work.
@dataclass # type: ignore
class Credentials(
ExtensibleDbtClassMixin,
Replaceable,
metaclass=abc.ABCMeta
):
class Credentials(ExtensibleDbtClassMixin, Replaceable, metaclass=abc.ABCMeta):
database: str
schema: str
_ALIASES: ClassVar[Dict[str, str]] = field(default={}, init=False)
@abc.abstractproperty
def type(self) -> str:
raise NotImplementedError(
'type not implemented for base credentials class'
)
raise NotImplementedError("type not implemented for base credentials class")
def connection_info(
self, *, with_aliases: bool = False
) -> Iterable[Tuple[str, Any]]:
"""Return an ordered iterator of key/value pairs for pretty-printing.
"""
"""Return an ordered iterator of key/value pairs for pretty-printing."""
as_dict = self.to_dict(omit_none=False)
connection_keys = set(self._connection_keys())
aliases: List[str] = []
if with_aliases:
aliases = [
k for k, v in self._ALIASES.items() if v in connection_keys
]
aliases = [k for k, v in self._ALIASES.items() if v in connection_keys]
for key in itertools.chain(self._connection_keys(), aliases):
if key in as_dict:
yield key, as_dict[key]
@@ -162,11 +161,13 @@ class Credentials(
def __post_serialize__(self, dct):
# no super() -- do we need it?
if self._ALIASES:
dct.update({
new_name: dct[canonical_name]
for new_name, canonical_name in self._ALIASES.items()
if canonical_name in dct
})
dct.update(
{
new_name: dct[canonical_name]
for new_name, canonical_name in self._ALIASES.items()
if canonical_name in dct
}
)
return dct
@@ -188,10 +189,10 @@ class HasCredentials(Protocol):
threads: int
def to_target_dict(self):
raise NotImplementedError('to_target_dict not implemented')
raise NotImplementedError("to_target_dict not implemented")
DEFAULT_QUERY_COMMENT = '''
DEFAULT_QUERY_COMMENT = """
{%- set comment_dict = {} -%}
{%- do comment_dict.update(
app='dbt',
@@ -208,7 +209,7 @@ DEFAULT_QUERY_COMMENT = '''
{%- do comment_dict.update(connection_name=connection_name) -%}
{%- endif -%}
{{ return(tojson(comment_dict)) }}
'''
"""
@dataclass

View File

@@ -11,7 +11,7 @@ from .util import MacroKey, SourceKey
MAXIMUM_SEED_SIZE = 1 * 1024 * 1024
MAXIMUM_SEED_SIZE_NAME = '1MB'
MAXIMUM_SEED_SIZE_NAME = "1MB"
@dataclass
@@ -28,9 +28,7 @@ class FilePath(dbtClassMixin):
@property
def full_path(self) -> str:
# useful for symlink preservation
return os.path.join(
self.project_root, self.searched_path, self.relative_path
)
return os.path.join(self.project_root, self.searched_path, self.relative_path)
@property
def absolute_path(self) -> str:
@@ -40,13 +38,10 @@ class FilePath(dbtClassMixin):
def original_file_path(self) -> str:
# this is mostly used for reporting errors. It doesn't show the project
# name, should it?
return os.path.join(
self.searched_path, self.relative_path
)
return os.path.join(self.searched_path, self.relative_path)
def seed_too_large(self) -> bool:
"""Return whether the file this represents is over the seed size limit
"""
"""Return whether the file this represents is over the seed size limit"""
return os.stat(self.full_path).st_size > MAXIMUM_SEED_SIZE
@@ -57,35 +52,35 @@ class FileHash(dbtClassMixin):
@classmethod
def empty(cls):
return FileHash(name='none', checksum='')
return FileHash(name="none", checksum="")
@classmethod
def path(cls, path: str):
return FileHash(name='path', checksum=path)
return FileHash(name="path", checksum=path)
def __eq__(self, other):
if not isinstance(other, FileHash):
return NotImplemented
if self.name == 'none' or self.name != other.name:
if self.name == "none" or self.name != other.name:
return False
return self.checksum == other.checksum
def compare(self, contents: str) -> bool:
"""Compare the file contents with the given hash"""
if self.name == 'none':
if self.name == "none":
return False
return self.from_contents(contents, name=self.name) == self.checksum
@classmethod
def from_contents(cls, contents: str, name='sha256') -> 'FileHash':
def from_contents(cls, contents: str, name="sha256") -> "FileHash":
"""Create a file hash from the given file contents. The hash is always
the utf-8 encoding of the contents given, because dbt only reads files
as utf-8.
"""
data = contents.encode('utf-8')
data = contents.encode("utf-8")
checksum = hashlib.new(name, data).hexdigest()
return cls(name=name, checksum=checksum)
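A short sketch of the hashing behaviour encoded above: contents are always hashed as UTF-8 with the named hashlib algorithm, and compare() re-hashes the given contents with that same algorithm to detect changes. The sample contents here are hypothetical:

import hashlib

contents = "id,name\n1,alice\n"  # hypothetical seed contents
checksum = hashlib.new("sha256", contents.encode("utf-8")).hexdigest()
# FileHash.from_contents(contents) would yield FileHash(name="sha256", checksum=checksum).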
@@ -94,24 +89,25 @@ class FileHash(dbtClassMixin):
class RemoteFile(dbtClassMixin):
@property
def searched_path(self) -> str:
return 'from remote system'
return "from remote system"
@property
def relative_path(self) -> str:
return 'from remote system'
return "from remote system"
@property
def absolute_path(self) -> str:
return 'from remote system'
return "from remote system"
@property
def original_file_path(self):
return 'from remote system'
return "from remote system"
@dataclass
class SourceFile(dbtClassMixin):
"""Define a source file in dbt"""
path: Union[FilePath, RemoteFile] # the path information
checksum: FileHash
# we don't want to serialize this
@@ -133,14 +129,14 @@ class SourceFile(dbtClassMixin):
def search_key(self) -> Optional[str]:
if isinstance(self.path, RemoteFile):
return None
if self.checksum.name == 'none':
if self.checksum.name == "none":
return None
return self.path.search_key
@property
def contents(self) -> str:
if self._contents is None:
raise InternalException('SourceFile has no contents!')
raise InternalException("SourceFile has no contents!")
return self._contents
@contents.setter
@@ -148,20 +144,20 @@ class SourceFile(dbtClassMixin):
self._contents = value
@classmethod
def empty(cls, path: FilePath) -> 'SourceFile':
def empty(cls, path: FilePath) -> "SourceFile":
self = cls(path=path, checksum=FileHash.empty())
self.contents = ''
self.contents = ""
return self
@classmethod
def big_seed(cls, path: FilePath) -> 'SourceFile':
def big_seed(cls, path: FilePath) -> "SourceFile":
"""Parse seeds over the size limit with just the path"""
self = cls(path=path, checksum=FileHash.path(path.original_file_path))
self.contents = ''
self.contents = ""
return self
@classmethod
def remote(cls, contents: str) -> 'SourceFile':
def remote(cls, contents: str) -> "SourceFile":
self = cls(path=RemoteFile(), checksum=FileHash.empty())
self.contents = contents
return self

View File

@@ -58,31 +58,29 @@ class CompiledNode(ParsedNode, CompiledNodeMixin):
@dataclass
class CompiledAnalysisNode(CompiledNode):
resource_type: NodeType = field(metadata={'restrict': [NodeType.Analysis]})
resource_type: NodeType = field(metadata={"restrict": [NodeType.Analysis]})
@dataclass
class CompiledHookNode(CompiledNode):
resource_type: NodeType = field(
metadata={'restrict': [NodeType.Operation]}
)
resource_type: NodeType = field(metadata={"restrict": [NodeType.Operation]})
index: Optional[int] = None
@dataclass
class CompiledModelNode(CompiledNode):
resource_type: NodeType = field(metadata={'restrict': [NodeType.Model]})
resource_type: NodeType = field(metadata={"restrict": [NodeType.Model]})
@dataclass
class CompiledRPCNode(CompiledNode):
resource_type: NodeType = field(metadata={'restrict': [NodeType.RPCCall]})
resource_type: NodeType = field(metadata={"restrict": [NodeType.RPCCall]})
@dataclass
class CompiledSeedNode(CompiledNode):
# keep this in sync with ParsedSeedNode!
resource_type: NodeType = field(metadata={'restrict': [NodeType.Seed]})
resource_type: NodeType = field(metadata={"restrict": [NodeType.Seed]})
config: SeedConfig = field(default_factory=SeedConfig)
@property
@@ -96,26 +94,25 @@ class CompiledSeedNode(CompiledNode):
@dataclass
class CompiledSnapshotNode(CompiledNode):
resource_type: NodeType = field(metadata={'restrict': [NodeType.Snapshot]})
resource_type: NodeType = field(metadata={"restrict": [NodeType.Snapshot]})
@dataclass
class CompiledDataTestNode(CompiledNode):
resource_type: NodeType = field(metadata={'restrict': [NodeType.Test]})
resource_type: NodeType = field(metadata={"restrict": [NodeType.Test]})
config: TestConfig = field(default_factory=TestConfig)
@dataclass
class CompiledSchemaTestNode(CompiledNode, HasTestMetadata):
# keep this in sync with ParsedSchemaTestNode!
resource_type: NodeType = field(metadata={'restrict': [NodeType.Test]})
resource_type: NodeType = field(metadata={"restrict": [NodeType.Test]})
column_name: Optional[str] = None
config: TestConfig = field(default_factory=TestConfig)
def same_config(self, other) -> bool:
return (
self.unrendered_config.get('severity') ==
other.unrendered_config.get('severity')
return self.unrendered_config.get("severity") == other.unrendered_config.get(
"severity"
)
def same_column_name(self, other) -> bool:
@@ -125,11 +122,7 @@ class CompiledSchemaTestNode(CompiledNode, HasTestMetadata):
if other is None:
return False
return (
self.same_config(other) and
self.same_fqn(other) and
True
)
return self.same_config(other) and self.same_fqn(other) and True
CompiledTestNode = Union[CompiledDataTestNode, CompiledSchemaTestNode]
@@ -175,8 +168,7 @@ def parsed_instance_for(compiled: CompiledNode) -> ParsedResource:
cls = PARSED_TYPES.get(type(compiled))
if cls is None:
# how???
raise ValueError('invalid resource_type: {}'
.format(compiled.resource_type))
raise ValueError("invalid resource_type: {}".format(compiled.resource_type))
return cls.from_dict(compiled.to_dict(omit_none=True))

View File

@@ -4,25 +4,51 @@ from dataclasses import dataclass, field
from itertools import chain, islice
from multiprocessing.synchronize import Lock
from typing import (
Dict, List, Optional, Union, Mapping, MutableMapping, Any, Set, Tuple,
TypeVar, Callable, Iterable, Generic, cast, AbstractSet
Dict,
List,
Optional,
Union,
Mapping,
MutableMapping,
Any,
Set,
Tuple,
TypeVar,
Callable,
Iterable,
Generic,
cast,
AbstractSet,
)
from typing_extensions import Protocol
from uuid import UUID
from dbt.contracts.graph.compiled import (
CompileResultNode, ManifestNode, NonSourceCompiledNode, GraphMemberNode
CompileResultNode,
ManifestNode,
NonSourceCompiledNode,
GraphMemberNode,
)
from dbt.contracts.graph.parsed import (
ParsedMacro, ParsedDocumentation, ParsedNodePatch, ParsedMacroPatch,
ParsedSourceDefinition, ParsedExposure
ParsedMacro,
ParsedDocumentation,
ParsedNodePatch,
ParsedMacroPatch,
ParsedSourceDefinition,
ParsedExposure,
)
from dbt.contracts.files import SourceFile
from dbt.contracts.util import (
BaseArtifactMetadata, MacroKey, SourceKey, ArtifactMixin, schema_version
BaseArtifactMetadata,
MacroKey,
SourceKey,
ArtifactMixin,
schema_version,
)
from dbt.exceptions import (
raise_duplicate_resource_name, raise_compiler_error, warn_or_error,
raise_duplicate_resource_name,
raise_compiler_error,
warn_or_error,
raise_invalid_patch,
)
from dbt.helper_types import PathSet
@@ -40,12 +66,12 @@ RefName = str
UniqueID = str
K_T = TypeVar('K_T')
V_T = TypeVar('V_T')
K_T = TypeVar("K_T")
V_T = TypeVar("V_T")
class PackageAwareCache(Generic[K_T, V_T]):
def __init__(self, manifest: 'Manifest'):
def __init__(self, manifest: "Manifest"):
self.storage: Dict[K_T, Dict[PackageName, UniqueID]] = {}
self._manifest = manifest
self.populate()
@@ -95,12 +121,10 @@ class DocCache(PackageAwareCache[DocName, ParsedDocumentation]):
for doc in self._manifest.docs.values():
self.add_doc(doc)
def perform_lookup(
self, unique_id: UniqueID
) -> ParsedDocumentation:
def perform_lookup(self, unique_id: UniqueID) -> ParsedDocumentation:
if unique_id not in self._manifest.docs:
raise dbt.exceptions.InternalException(
f'Doc {unique_id} found in cache but not found in manifest'
f"Doc {unique_id} found in cache but not found in manifest"
)
return self._manifest.docs[unique_id]
@@ -117,12 +141,10 @@ class SourceCache(PackageAwareCache[SourceKey, ParsedSourceDefinition]):
for source in self._manifest.sources.values():
self.add_source(source)
def perform_lookup(
self, unique_id: UniqueID
) -> ParsedSourceDefinition:
def perform_lookup(self, unique_id: UniqueID) -> ParsedSourceDefinition:
if unique_id not in self._manifest.sources:
raise dbt.exceptions.InternalException(
f'Source {unique_id} found in cache but not found in manifest'
f"Source {unique_id} found in cache but not found in manifest"
)
return self._manifest.sources[unique_id]
@@ -131,7 +153,7 @@ class RefableCache(PackageAwareCache[RefName, ManifestNode]):
# refables are actually unique, so the Dict[PackageName, UniqueID] will
# only ever have exactly one value, but doing 3 dict lookups instead of 1
# is not a big deal at all and retains consistency
def __init__(self, manifest: 'Manifest'):
def __init__(self, manifest: "Manifest"):
self._cached_types = set(NodeType.refable())
super().__init__(manifest)
@@ -145,12 +167,10 @@ class RefableCache(PackageAwareCache[RefName, ManifestNode]):
for node in self._manifest.nodes.values():
self.add_node(node)
def perform_lookup(
self, unique_id: UniqueID
) -> ManifestNode:
def perform_lookup(self, unique_id: UniqueID) -> ManifestNode:
if unique_id not in self._manifest.nodes:
raise dbt.exceptions.InternalException(
f'Node {unique_id} found in cache but not found in manifest'
f"Node {unique_id} found in cache but not found in manifest"
)
return self._manifest.nodes[unique_id]
@@ -171,30 +191,31 @@ def _search_packages(
@dataclass
class ManifestMetadata(BaseArtifactMetadata):
"""Metadata for the manifest."""
dbt_schema_version: str = field(
default_factory=lambda: str(WritableManifest.dbt_schema_version)
)
project_id: Optional[str] = field(
default=None,
metadata={
'description': 'A unique identifier for the project',
"description": "A unique identifier for the project",
},
)
user_id: Optional[UUID] = field(
default=None,
metadata={
'description': 'A unique identifier for the user',
"description": "A unique identifier for the user",
},
)
send_anonymous_usage_stats: Optional[bool] = field(
default=None,
metadata=dict(description=(
'Whether dbt is configured to send anonymous usage statistics'
)),
metadata=dict(
description=("Whether dbt is configured to send anonymous usage statistics")
),
)
adapter_type: Optional[str] = field(
default=None,
metadata=dict(description='The type name of the adapter'),
metadata=dict(description="The type name of the adapter"),
)
def __post_init__(self):
@@ -205,9 +226,7 @@ class ManifestMetadata(BaseArtifactMetadata):
self.user_id = tracking.active_user.id
if self.send_anonymous_usage_stats is None:
self.send_anonymous_usage_stats = (
not tracking.active_user.do_not_track
)
self.send_anonymous_usage_stats = not tracking.active_user.do_not_track
@classmethod
def default(cls):
@@ -281,7 +300,7 @@ class MaterializationCandidate(MacroCandidate):
@classmethod
def from_macro(
cls, candidate: MacroCandidate, specificity: Specificity
) -> 'MaterializationCandidate':
) -> "MaterializationCandidate":
return cls(
locality=candidate.locality,
macro=candidate.macro,
@@ -292,15 +311,14 @@ class MaterializationCandidate(MacroCandidate):
if not isinstance(other, MaterializationCandidate):
return NotImplemented
equal = (
self.specificity == other.specificity and
self.locality == other.locality
self.specificity == other.specificity and self.locality == other.locality
)
if equal:
raise_compiler_error(
'Found two materializations with the name {} (packages {} and '
'{}). dbt cannot resolve this ambiguity'
.format(self.macro.name, self.macro.package_name,
other.macro.package_name)
"Found two materializations with the name {} (packages {} and "
"{}). dbt cannot resolve this ambiguity".format(
self.macro.name, self.macro.package_name, other.macro.package_name
)
)
return equal
@@ -319,7 +337,7 @@ class MaterializationCandidate(MacroCandidate):
return False
M = TypeVar('M', bound=MacroCandidate)
M = TypeVar("M", bound=MacroCandidate)
class CandidateList(List[M]):
@@ -347,10 +365,10 @@ class Searchable(Protocol):
@property
def search_name(self) -> str:
raise NotImplementedError('search_name not implemented')
raise NotImplementedError("search_name not implemented")
N = TypeVar('N', bound=Searchable)
N = TypeVar("N", bound=Searchable)
@dataclass
@@ -382,7 +400,7 @@ class NameSearcher(Generic[N]):
return None
D = TypeVar('D')
D = TypeVar("D")
@dataclass
@@ -393,19 +411,18 @@ class Disabled(Generic[D]):
MaybeDocumentation = Optional[ParsedDocumentation]
MaybeParsedSource = Optional[Union[
ParsedSourceDefinition,
Disabled[ParsedSourceDefinition],
]]
MaybeParsedSource = Optional[
Union[
ParsedSourceDefinition,
Disabled[ParsedSourceDefinition],
]
]
MaybeNonSource = Optional[Union[
ManifestNode,
Disabled[ManifestNode]
]]
MaybeNonSource = Optional[Union[ManifestNode, Disabled[ManifestNode]]]
T = TypeVar('T', bound=GraphMemberNode)
T = TypeVar("T", bound=GraphMemberNode)
def _update_into(dest: MutableMapping[str, T], new_item: T):
@@ -416,14 +433,13 @@ def _update_into(dest: MutableMapping[str, T], new_item: T):
unique_id = new_item.unique_id
if unique_id not in dest:
raise dbt.exceptions.RuntimeException(
f'got an update_{new_item.resource_type} call with an '
f'unrecognized {new_item.resource_type}: {new_item.unique_id}'
f"got an update_{new_item.resource_type} call with an "
f"unrecognized {new_item.resource_type}: {new_item.unique_id}"
)
existing = dest[unique_id]
if new_item.original_file_path != existing.original_file_path:
raise dbt.exceptions.RuntimeException(
f'cannot update a {new_item.resource_type} to have a new file '
f'path!'
f"cannot update a {new_item.resource_type} to have a new file " f"path!"
)
dest[unique_id] = new_item
@@ -447,6 +463,7 @@ class MacroMethods:
"""
filter: Optional[Callable[[MacroCandidate], bool]] = None
if package is not None:
def filter(candidate: MacroCandidate) -> bool:
return package == candidate.macro.package_name
@@ -469,11 +486,12 @@ class MacroMethods:
- return the `generate_{component}_name` macro from the 'dbt'
internal project
"""
def filter(candidate: MacroCandidate) -> bool:
return candidate.locality != Locality.Imported
candidates: CandidateList = self._find_macros_by_name(
name=f'generate_{component}_name',
name=f"generate_{component}_name",
root_project_name=root_project_name,
# filter out imported packages
filter=filter,
@@ -484,12 +502,12 @@ class MacroMethods:
self,
name: str,
root_project_name: str,
filter: Optional[Callable[[MacroCandidate], bool]] = None
filter: Optional[Callable[[MacroCandidate], bool]] = None,
) -> CandidateList:
"""Find macros by their name.
"""
"""Find macros by their name."""
# avoid an import cycle
from dbt.adapters.factory import get_adapter_package_names
candidates: CandidateList = CandidateList()
packages = set(get_adapter_package_names(self.metadata.adapter_type))
for unique_id, macro in self.macros.items():
@@ -507,8 +525,8 @@ class MacroMethods:
@dataclass
class Manifest(MacroMethods):
"""The manifest for the full graph, after parsing and during compilation.
"""
"""The manifest for the full graph, after parsing and during compilation."""
# These attributes are both positional and by keyword. If an attribute
# is added it must all be added in the __reduce_ex__ method in the
# args tuple in the right position.
@@ -541,7 +559,7 @@ class Manifest(MacroMethods):
"""
with self._lock:
existing = self.nodes[new_node.unique_id]
if getattr(existing, 'compiled', False):
if getattr(existing, "compiled", False):
# already compiled -> must be a NonSourceCompiledNode
return cast(NonSourceCompiledNode, existing)
_update_into(self.nodes, new_node)
@@ -563,39 +581,30 @@ class Manifest(MacroMethods):
manifest!
"""
self.flat_graph = {
'nodes': {
k: v.to_dict(omit_none=False)
for k, v in self.nodes.items()
},
'sources': {
k: v.to_dict(omit_none=False)
for k, v in self.sources.items()
}
"nodes": {k: v.to_dict(omit_none=False) for k, v in self.nodes.items()},
"sources": {k: v.to_dict(omit_none=False) for k, v in self.sources.items()},
}
def find_disabled_by_name(
self, name: str, package: Optional[str] = None
) -> Optional[ManifestNode]:
searcher: NameSearcher = NameSearcher(
name, package, NodeType.refable()
)
searcher: NameSearcher = NameSearcher(name, package, NodeType.refable())
result = searcher.search(self.disabled)
return result
def find_disabled_source_by_name(
self, source_name: str, table_name: str, package: Optional[str] = None
) -> Optional[ParsedSourceDefinition]:
search_name = f'{source_name}.{table_name}'
searcher: NameSearcher = NameSearcher(
search_name, package, [NodeType.Source]
)
search_name = f"{source_name}.{table_name}"
searcher: NameSearcher = NameSearcher(search_name, package, [NodeType.Source])
result = searcher.search(self.disabled)
if result is not None:
assert isinstance(result, ParsedSourceDefinition)
return result
def _materialization_candidates_for(
self, project_name: str,
self,
project_name: str,
materialization_name: str,
adapter_type: Optional[str],
) -> CandidateList:
@@ -618,13 +627,16 @@ class Manifest(MacroMethods):
def find_materialization_macro_by_name(
self, project_name: str, materialization_name: str, adapter_type: str
) -> Optional[ParsedMacro]:
candidates: CandidateList = CandidateList(chain.from_iterable(
self._materialization_candidates_for(
project_name=project_name,
materialization_name=materialization_name,
adapter_type=atype,
) for atype in (adapter_type, None)
))
candidates: CandidateList = CandidateList(
chain.from_iterable(
self._materialization_candidates_for(
project_name=project_name,
materialization_name=materialization_name,
adapter_type=atype,
)
for atype in (adapter_type, None)
)
)
return candidates.last()
def get_resource_fqns(self) -> Mapping[str, PathSet]:
@@ -648,9 +660,7 @@ class Manifest(MacroMethods):
if node.resource_type in NodeType.refable():
self._refs_cache.add_node(node)
def patch_macros(
self, patches: MutableMapping[MacroKey, ParsedMacroPatch]
) -> None:
def patch_macros(self, patches: MutableMapping[MacroKey, ParsedMacroPatch]) -> None:
for macro in self.macros.values():
key = (macro.package_name, macro.name)
patch = patches.pop(key, None)
@@ -662,12 +672,10 @@ class Manifest(MacroMethods):
for patch in patches.values():
warn_or_error(
f'WARNING: Found documentation for macro "{patch.name}" '
f'which was not found'
f"which was not found"
)
def patch_nodes(
self, patches: MutableMapping[str, ParsedNodePatch]
) -> None:
def patch_nodes(self, patches: MutableMapping[str, ParsedNodePatch]) -> None:
"""Patch nodes with the given dict of patches. Note that this consumes
the input!
This relies on the fact that all nodes have unique _name_ fields, not
@@ -684,15 +692,15 @@ class Manifest(MacroMethods):
expected_key = node.resource_type.pluralize()
if expected_key != patch.yaml_key:
if patch.yaml_key == 'models':
if patch.yaml_key == "models":
deprecations.warn(
'models-key-mismatch',
patch=patch, node=node, expected_key=expected_key
"models-key-mismatch",
patch=patch,
node=node,
expected_key=expected_key,
)
else:
raise_invalid_patch(
node, patch.yaml_key, patch.original_file_path
)
raise_invalid_patch(node, patch.yaml_key, patch.original_file_path)
node.patch(patch)
@@ -701,22 +709,25 @@ class Manifest(MacroMethods):
for patch in patches.values():
# since patches aren't nodes, we can't use the existing
# target_not_found warning
logger.debug((
'WARNING: Found documentation for resource "{}" which was '
'not found or is disabled').format(patch.name)
logger.debug(
(
'WARNING: Found documentation for resource "{}" which was '
"not found or is disabled"
).format(patch.name)
)
def get_used_schemas(self, resource_types=None):
return frozenset({
(node.database, node.schema) for node in
chain(self.nodes.values(), self.sources.values())
if not resource_types or node.resource_type in resource_types
})
return frozenset(
{
(node.database, node.schema)
for node in chain(self.nodes.values(), self.sources.values())
if not resource_types or node.resource_type in resource_types
}
)
def get_used_databases(self):
return frozenset(
x.database for x in
chain(self.nodes.values(), self.sources.values())
x.database for x in chain(self.nodes.values(), self.sources.values())
)
def deepcopy(self):
@@ -733,11 +744,13 @@ class Manifest(MacroMethods):
)
def writable_manifest(self):
edge_members = list(chain(
self.nodes.values(),
self.sources.values(),
self.exposures.values(),
))
edge_members = list(
chain(
self.nodes.values(),
self.sources.values(),
self.exposures.values(),
)
)
forward_edges, backward_edges = build_edges(edge_members)
return WritableManifest(
@@ -771,7 +784,7 @@ class Manifest(MacroMethods):
else:
# something terrible has happened
raise dbt.exceptions.InternalException(
'Expected node {} not found in manifest'.format(unique_id)
"Expected node {} not found in manifest".format(unique_id)
)
@property
@@ -820,9 +833,7 @@ class Manifest(MacroMethods):
# it's possible that the node is disabled
if disabled is None:
disabled = self.find_disabled_by_name(
target_model_name, pkg
)
disabled = self.find_disabled_by_name(target_model_name, pkg)
if disabled is not None:
return Disabled(disabled)
@@ -833,7 +844,7 @@ class Manifest(MacroMethods):
target_source_name: str,
target_table_name: str,
current_project: str,
node_package: str
node_package: str,
) -> MaybeParsedSource:
key = (target_source_name, target_table_name)
candidates = _search_packages(current_project, node_package)
@@ -866,9 +877,7 @@ class Manifest(MacroMethods):
resolve_ref except the is_enabled checks are unnecessary as docs are
always enabled.
"""
candidates = _search_packages(
current_project, node_package, package
)
candidates = _search_packages(current_project, node_package, package)
for pkg in candidates:
result = self.docs_cache.find_cached_value(name, pkg)
@@ -879,7 +888,7 @@ class Manifest(MacroMethods):
def merge_from_artifact(
self,
adapter,
other: 'WritableManifest',
other: "WritableManifest",
selected: AbstractSet[UniqueID],
) -> None:
"""Given the selected unique IDs and a writable manifest, update this
@@ -892,10 +901,10 @@ class Manifest(MacroMethods):
for unique_id, node in other.nodes.items():
current = self.nodes.get(unique_id)
if current and (
node.resource_type in refables and
not node.is_ephemeral and
unique_id not in selected and
not adapter.get_relation(
node.resource_type in refables
and not node.is_ephemeral
and unique_id not in selected
and not adapter.get_relation(
current.database, current.schema, current.identifier
)
):
@@ -904,9 +913,7 @@ class Manifest(MacroMethods):
# log up to 5 items
sample = list(islice(merged, 5))
logger.debug(
f'Merged {len(merged)} items from state (sample: {sample})'
)
logger.debug(f"Merged {len(merged)} items from state (sample: {sample})")
# Provide support for copy.deepcopy() - we just need to avoid the lock!
# pickle and deepcopy use this. It returns a callable object used to
@@ -948,47 +955,53 @@ AnyManifest = Union[Manifest, MacroManifest]
@dataclass
@schema_version('manifest', 1)
@schema_version("manifest", 1)
class WritableManifest(ArtifactMixin):
nodes: Mapping[UniqueID, ManifestNode] = field(
metadata=dict(description=(
'The nodes defined in the dbt project and its dependencies'
))
metadata=dict(
description=("The nodes defined in the dbt project and its dependencies")
)
)
sources: Mapping[UniqueID, ParsedSourceDefinition] = field(
metadata=dict(description=(
'The sources defined in the dbt project and its dependencies'
))
metadata=dict(
description=("The sources defined in the dbt project and its dependencies")
)
)
macros: Mapping[UniqueID, ParsedMacro] = field(
metadata=dict(description=(
'The macros defined in the dbt project and its dependencies'
))
metadata=dict(
description=("The macros defined in the dbt project and its dependencies")
)
)
docs: Mapping[UniqueID, ParsedDocumentation] = field(
metadata=dict(description=(
'The docs defined in the dbt project and its dependencies'
))
metadata=dict(
description=("The docs defined in the dbt project and its dependencies")
)
)
exposures: Mapping[UniqueID, ParsedExposure] = field(
metadata=dict(description=(
'The exposures defined in the dbt project and its dependencies'
))
metadata=dict(
description=(
"The exposures defined in the dbt project and its dependencies"
)
)
)
selectors: Mapping[UniqueID, Any] = field(
metadata=dict(description=(
'The selectors defined in selectors.yml'
))
metadata=dict(description=("The selectors defined in selectors.yml"))
)
disabled: Optional[List[CompileResultNode]] = field(
metadata=dict(description="A list of the disabled nodes in the target")
)
parent_map: Optional[NodeEdgeMap] = field(
metadata=dict(
description="A mapping from child nodes to their dependencies",
)
)
child_map: Optional[NodeEdgeMap] = field(
metadata=dict(
description="A mapping from parent nodes to their dependents",
)
)
metadata: ManifestMetadata = field(
metadata=dict(
description="Metadata about the manifest",
)
)
disabled: Optional[List[CompileResultNode]] = field(metadata=dict(
description='A list of the disabled nodes in the target'
))
parent_map: Optional[NodeEdgeMap] = field(metadata=dict(
description='A mapping from child nodes to their dependencies',
))
child_map: Optional[NodeEdgeMap] = field(metadata=dict(
description='A mapping from parent nodes to their dependents',
))
metadata: ManifestMetadata = field(metadata=dict(
description='Metadata about the manifest',
))

View File

@@ -2,11 +2,20 @@ from dataclasses import field, Field, dataclass
from enum import Enum
from itertools import chain
from typing import (
Any, List, Optional, Dict, MutableMapping, Union, Type,
TypeVar, Callable,
Any,
List,
Optional,
Dict,
MutableMapping,
Union,
Type,
TypeVar,
Callable,
)
from dbt.dataclass_schema import (
dbtClassMixin, ValidationError, register_pattern,
dbtClassMixin,
ValidationError,
register_pattern,
)
from dbt.contracts.graph.unparsed import AdditionalPropertiesAllowed
from dbt.exceptions import CompilationException, InternalException
@@ -15,7 +24,7 @@ from dbt import hooks
from dbt.node_types import NodeType
M = TypeVar('M', bound='Metadata')
M = TypeVar("M", bound="Metadata")
def _get_meta_value(cls: Type[M], fld: Field, key: str, default: Any) -> M:
@@ -30,9 +39,7 @@ def _get_meta_value(cls: Type[M], fld: Field, key: str, default: Any) -> M:
try:
return cls(value)
except ValueError as exc:
raise InternalException(
f'Invalid {cls} value: {value}'
) from exc
raise InternalException(f"Invalid {cls} value: {value}") from exc
def _set_meta_value(
@@ -54,19 +61,17 @@ class Metadata(Enum):
return _get_meta_value(cls, fld, key, default)
def meta(
self, existing: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
def meta(self, existing: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
key = self.metadata_key()
return _set_meta_value(self, key, existing)
@classmethod
def default_field(cls) -> 'Metadata':
raise NotImplementedError('Not implemented')
def default_field(cls) -> "Metadata":
raise NotImplementedError("Not implemented")
@classmethod
def metadata_key(cls) -> str:
raise NotImplementedError('Not implemented')
raise NotImplementedError("Not implemented")
class MergeBehavior(Metadata):
@@ -75,12 +80,12 @@ class MergeBehavior(Metadata):
Clobber = 3
@classmethod
def default_field(cls) -> 'MergeBehavior':
def default_field(cls) -> "MergeBehavior":
return cls.Clobber
@classmethod
def metadata_key(cls) -> str:
return 'merge'
return "merge"
class ShowBehavior(Metadata):
@@ -88,12 +93,12 @@ class ShowBehavior(Metadata):
Hide = 2
@classmethod
def default_field(cls) -> 'ShowBehavior':
def default_field(cls) -> "ShowBehavior":
return cls.Show
@classmethod
def metadata_key(cls) -> str:
return 'show_hide'
return "show_hide"
@classmethod
def should_show(cls, fld: Field) -> bool:
@@ -105,12 +110,12 @@ class CompareBehavior(Metadata):
Exclude = 2
@classmethod
def default_field(cls) -> 'CompareBehavior':
def default_field(cls) -> "CompareBehavior":
return cls.Include
@classmethod
def metadata_key(cls) -> str:
return 'compare'
return "compare"
@classmethod
def should_include(cls, fld: Field) -> bool:
@@ -142,32 +147,30 @@ def _merge_field_value(
return _listify(self_value) + _listify(other_value)
elif merge_behavior == MergeBehavior.Update:
if not isinstance(self_value, dict):
raise InternalException(f'expected dict, got {self_value}')
raise InternalException(f"expected dict, got {self_value}")
if not isinstance(other_value, dict):
raise InternalException(f'expected dict, got {other_value}')
raise InternalException(f"expected dict, got {other_value}")
value = self_value.copy()
value.update(other_value)
return value
else:
raise InternalException(
f'Got an invalid merge_behavior: {merge_behavior}'
)
raise InternalException(f"Got an invalid merge_behavior: {merge_behavior}")
def insensitive_patterns(*patterns: str):
lowercased = []
for pattern in patterns:
lowercased.append(
''.join('[{}{}]'.format(s.upper(), s.lower()) for s in pattern)
"".join("[{}{}]".format(s.upper(), s.lower()) for s in pattern)
)
return '^({})$'.format('|'.join(lowercased))
return "^({})$".format("|".join(lowercased))
class Severity(str):
pass
register_pattern(Severity, insensitive_patterns('warn', 'error'))
register_pattern(Severity, insensitive_patterns("warn", "error"))
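A sketch of what the helper above expands to for the Severity registration; the resulting pattern string and the sample inputs are reconstructed from the code shown, not quoted from dbt's tests:

import re

def insensitive_patterns(*patterns: str) -> str:
    # Mirrors the helper above: every character becomes a [Xx] class.
    lowercased = []
    for pattern in patterns:
        lowercased.append("".join("[{}{}]".format(s.upper(), s.lower()) for s in pattern))
    return "^({})$".format("|".join(lowercased))

severity_re = re.compile(insensitive_patterns("warn", "error"))
# -> "^([wW][aA][rR][nN]|[eE][rR][rR][oO][rR])$"
assert severity_re.match("WARN") and severity_re.match("Error")
assert not severity_re.match("fatal")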
@dataclass
@@ -177,13 +180,11 @@ class Hook(dbtClassMixin, Replaceable):
index: Optional[int] = None
T = TypeVar('T', bound='BaseConfig')
T = TypeVar("T", bound="BaseConfig")
@dataclass
class BaseConfig(
AdditionalPropertiesAllowed, Replaceable, MutableMapping[str, Any]
):
class BaseConfig(AdditionalPropertiesAllowed, Replaceable, MutableMapping[str, Any]):
# Implement MutableMapping so this config will behave as some macros expect
# during parsing (notably, syntax like `{{ node.config['schema'] }}`)
def __getitem__(self, key):
@@ -204,8 +205,7 @@ class BaseConfig(
def __delitem__(self, key):
if hasattr(self, key):
msg = (
'Error, tried to delete config key "{}": Cannot delete '
'built-in keys'
'Error, tried to delete config key "{}": Cannot delete ' "built-in keys"
).format(key)
raise CompilationException(msg)
else:
@@ -245,9 +245,7 @@ class BaseConfig(
return unrendered[key] == other[key]
@classmethod
def same_contents(
cls, unrendered: Dict[str, Any], other: Dict[str, Any]
) -> bool:
def same_contents(cls, unrendered: Dict[str, Any], other: Dict[str, Any]) -> bool:
"""This is like __eq__, except it ignores some fields."""
seen = set()
for fld, target_name in cls._get_fields():
@@ -265,9 +263,7 @@ class BaseConfig(
return True
@classmethod
def _extract_dict(
cls, src: Dict[str, Any], data: Dict[str, Any]
) -> Dict[str, Any]:
def _extract_dict(cls, src: Dict[str, Any], data: Dict[str, Any]) -> Dict[str, Any]:
"""Find all the items in data that match a target_field on this class,
and merge them with the data found in `src` for target_field, using the
field's specified merge behavior. Matching items will be removed from
@@ -307,6 +303,7 @@ class BaseConfig(
"""
# sadly, this is a circular import
from dbt.adapters.factory import get_config_class_by_name
dct = self.to_dict(omit_none=False)
adapter_config_cls = get_config_class_by_name(adapter_type)
@@ -348,7 +345,7 @@ class SourceConfig(BaseConfig):
@dataclass
class NodeConfig(BaseConfig):
enabled: bool = True
materialized: str = 'view'
materialized: str = "view"
persist_docs: Dict[str, Any] = field(default_factory=dict)
post_hook: List[Hook] = field(
default_factory=list,
@@ -389,16 +386,16 @@ class NodeConfig(BaseConfig):
)
tags: Union[List[str], str] = field(
default_factory=list_str,
metadata=metas(ShowBehavior.Hide,
MergeBehavior.Append,
CompareBehavior.Exclude),
metadata=metas(
ShowBehavior.Hide, MergeBehavior.Append, CompareBehavior.Exclude
),
)
full_refresh: Optional[bool] = None
@classmethod
def __pre_deserialize__(cls, data):
data = super().__pre_deserialize__(data)
field_map = {'post-hook': 'post_hook', 'pre-hook': 'pre_hook'}
field_map = {"post-hook": "post_hook", "pre-hook": "pre_hook"}
# create a new dict because otherwise it gets overwritten in
# tests
new_dict = {}
@@ -416,7 +413,7 @@ class NodeConfig(BaseConfig):
def __post_serialize__(self, dct):
dct = super().__post_serialize__(dct)
field_map = {'post_hook': 'post-hook', 'pre_hook': 'pre-hook'}
field_map = {"post_hook": "post-hook", "pre_hook": "pre-hook"}
for field_name in field_map:
if field_name in dct:
dct[field_map[field_name]] = dct.pop(field_name)
@@ -425,24 +422,24 @@ class NodeConfig(BaseConfig):
# this is still used by jsonschema validation
@classmethod
def field_mapping(cls):
return {'post_hook': 'post-hook', 'pre_hook': 'pre-hook'}
return {"post_hook": "post-hook", "pre_hook": "pre-hook"}
@dataclass
class SeedConfig(NodeConfig):
materialized: str = 'seed'
materialized: str = "seed"
quote_columns: Optional[bool] = None
@dataclass
class TestConfig(NodeConfig):
materialized: str = 'test'
severity: Severity = Severity('ERROR')
materialized: str = "test"
severity: Severity = Severity("ERROR")
@dataclass
class EmptySnapshotConfig(NodeConfig):
materialized: str = 'snapshot'
materialized: str = "snapshot"
@dataclass
@@ -457,25 +454,28 @@ class SnapshotConfig(EmptySnapshotConfig):
@classmethod
def validate(cls, data):
super().validate(data)
if data.get('strategy') == 'check':
if not data.get('check_cols'):
if data.get("strategy") == "check":
if not data.get("check_cols"):
raise ValidationError(
"A snapshot configured with the check strategy must "
"specify a check_cols configuration.")
if (isinstance(data['check_cols'], str) and
data['check_cols'] != 'all'):
"specify a check_cols configuration."
)
if isinstance(data["check_cols"], str) and data["check_cols"] != "all":
raise ValidationError(
f"Invalid value for 'check_cols': {data['check_cols']}. "
"Expected 'all' or a list of strings.")
"Expected 'all' or a list of strings."
)
elif data.get('strategy') == 'timestamp':
if not data.get('updated_at'):
elif data.get("strategy") == "timestamp":
if not data.get("updated_at"):
raise ValidationError(
"A snapshot configured with the timestamp strategy "
"must specify an updated_at configuration.")
if data.get('check_cols'):
"must specify an updated_at configuration."
)
if data.get("check_cols"):
raise ValidationError(
"A 'timestamp' snapshot should not have 'check_cols'")
"A 'timestamp' snapshot should not have 'check_cols'"
)
# If the strategy is not 'check' or 'timestamp' it's a custom strategy,
# formerly supported with GenericSnapshotConfig
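Two hypothetical config dicts illustrating the validation rules above; key names follow the code, but other required SnapshotConfig fields are omitted, so these are sketches rather than complete payloads:

# Satisfies the checks above: 'check' strategy with an explicit column list.
check_snapshot = {"strategy": "check", "check_cols": ["status", "amount"]}

# Would raise ValidationError above: the timestamp strategy requires updated_at
# and must not specify check_cols.
bad_snapshot = {"strategy": "timestamp", "check_cols": ["status"]}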
@@ -497,9 +497,7 @@ RESOURCE_TYPES: Dict[NodeType, Type[BaseConfig]] = {
# base resource types are like resource types, except nothing has mandatory
# configs.
BASE_RESOURCE_TYPES: Dict[NodeType, Type[BaseConfig]] = RESOURCE_TYPES.copy()
BASE_RESOURCE_TYPES.update({
NodeType.Snapshot: EmptySnapshotConfig
})
BASE_RESOURCE_TYPES.update({NodeType.Snapshot: EmptySnapshotConfig})
def get_config_for(resource_type: NodeType, base=False) -> Type[BaseConfig]:

View File

@@ -13,18 +13,27 @@ from typing import (
TypeVar,
)
from dbt.dataclass_schema import (
dbtClassMixin, ExtensibleDbtClassMixin
)
from dbt.dataclass_schema import dbtClassMixin, ExtensibleDbtClassMixin
from dbt.clients.system import write_file
from dbt.contracts.files import FileHash, MAXIMUM_SEED_SIZE_NAME
from dbt.contracts.graph.unparsed import (
UnparsedNode, UnparsedDocumentation, Quoting, Docs,
UnparsedBaseNode, FreshnessThreshold, ExternalTable,
HasYamlMetadata, MacroArgument, UnparsedSourceDefinition,
UnparsedSourceTableDefinition, UnparsedColumn, TestDef,
ExposureOwner, ExposureType, MaturityType
UnparsedNode,
UnparsedDocumentation,
Quoting,
Docs,
UnparsedBaseNode,
FreshnessThreshold,
ExternalTable,
HasYamlMetadata,
MacroArgument,
UnparsedSourceDefinition,
UnparsedSourceTableDefinition,
UnparsedColumn,
TestDef,
ExposureOwner,
ExposureType,
MaturityType,
)
from dbt.contracts.util import Replaceable, AdditionalPropertiesMixin
from dbt.exceptions import warn_or_error
@@ -44,13 +53,9 @@ from .model_config import (
@dataclass
class ColumnInfo(
AdditionalPropertiesMixin,
ExtensibleDbtClassMixin,
Replaceable
):
class ColumnInfo(AdditionalPropertiesMixin, ExtensibleDbtClassMixin, Replaceable):
name: str
description: str = ''
description: str = ""
meta: Dict[str, Any] = field(default_factory=dict)
data_type: Optional[str] = None
quote: Optional[bool] = None
@@ -62,7 +67,7 @@ class ColumnInfo(
class HasFqn(dbtClassMixin, Replaceable):
fqn: List[str]
def same_fqn(self, other: 'HasFqn') -> bool:
def same_fqn(self, other: "HasFqn") -> bool:
return self.fqn == other.fqn
@@ -101,8 +106,8 @@ class HasRelationMetadata(dbtClassMixin, Replaceable):
@classmethod
def __pre_deserialize__(cls, data):
data = super().__pre_deserialize__(data)
if 'database' not in data:
data['database'] = None
if "database" not in data:
data["database"] = None
return data
@@ -117,7 +122,7 @@ class ParsedNodeMixins(dbtClassMixin):
@property
def is_ephemeral(self):
return self.config.materialized == 'ephemeral'
return self.config.materialized == "ephemeral"
@property
def is_ephemeral_model(self):
@@ -127,7 +132,7 @@ class ParsedNodeMixins(dbtClassMixin):
def depends_on_nodes(self):
return self.depends_on.nodes
def patch(self, patch: 'ParsedNodePatch'):
def patch(self, patch: "ParsedNodePatch"):
"""Given a ParsedNodePatch, add the new information to the node."""
# explicitly pick out the parts to update so we don't inadvertently
# step on the model name or anything
@@ -153,11 +158,7 @@ class ParsedNodeMixins(dbtClassMixin):
@dataclass
class ParsedNodeMandatory(
UnparsedNode,
HasUniqueID,
HasFqn,
HasRelationMetadata,
Replaceable
UnparsedNode, HasUniqueID, HasFqn, HasRelationMetadata, Replaceable
):
alias: str
checksum: FileHash
@@ -174,7 +175,7 @@ class ParsedNodeDefaults(ParsedNodeMandatory):
refs: List[List[str]] = field(default_factory=list)
sources: List[List[Any]] = field(default_factory=list)
depends_on: DependsOn = field(default_factory=DependsOn)
description: str = field(default='')
description: str = field(default="")
columns: Dict[str, ColumnInfo] = field(default_factory=dict)
meta: Dict[str, Any] = field(default_factory=dict)
docs: Docs = field(default_factory=Docs)
@@ -184,31 +185,28 @@ class ParsedNodeDefaults(ParsedNodeMandatory):
unrendered_config: Dict[str, Any] = field(default_factory=dict)
def write_node(self, target_path: str, subdirectory: str, payload: str):
if (os.path.basename(self.path) ==
os.path.basename(self.original_file_path)):
if os.path.basename(self.path) == os.path.basename(self.original_file_path):
# One-to-one relationship of nodes to files.
path = self.original_file_path
else:
# Many-to-one relationship of nodes to files.
path = os.path.join(self.original_file_path, self.path)
full_path = os.path.join(
target_path, subdirectory, self.package_name, path
)
full_path = os.path.join(target_path, subdirectory, self.package_name, path)
write_file(full_path, payload)
return full_path
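A sketch of the path construction above with hypothetical names, assuming the one-to-one case where the node's path shares its basename with the original file:

import os

# Hypothetical project and file names, chosen only to illustrate write_node's join above.
target_path, subdirectory, package_name = "target", "run", "my_project"
original_file_path = os.path.join("models", "staging", "stg_orders.sql")

full_path = os.path.join(target_path, subdirectory, package_name, original_file_path)
# -> target/run/my_project/models/staging/stg_orders.sql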
T = TypeVar('T', bound='ParsedNode')
T = TypeVar("T", bound="ParsedNode")
@dataclass
class ParsedNode(ParsedNodeDefaults, ParsedNodeMixins):
def _persist_column_docs(self) -> bool:
return bool(self.config.persist_docs.get('columns'))
return bool(self.config.persist_docs.get("columns"))
def _persist_relation_docs(self) -> bool:
return bool(self.config.persist_docs.get('relation'))
return bool(self.config.persist_docs.get("relation"))
def same_body(self: T, other: T) -> bool:
return self.raw_sql == other.raw_sql
@@ -223,9 +221,7 @@ class ParsedNode(ParsedNodeDefaults, ParsedNodeMixins):
if self._persist_column_docs():
# assert other._persist_column_docs()
column_descriptions = {
k: v.description for k, v in self.columns.items()
}
column_descriptions = {k: v.description for k, v in self.columns.items()}
other_column_descriptions = {
k: v.description for k, v in other.columns.items()
}
@@ -239,7 +235,7 @@ class ParsedNode(ParsedNodeDefaults, ParsedNodeMixins):
# compares the configured value, rather than the ultimate value (so
# generate_*_name and unset values derived from the target are
# ignored)
keys = ('database', 'schema', 'alias')
keys = ("database", "schema", "alias")
for key in keys:
mine = self.unrendered_config.get(key)
others = other.unrendered_config.get(key)
@@ -258,36 +254,34 @@ class ParsedNode(ParsedNodeDefaults, ParsedNodeMixins):
return False
return (
self.same_body(old) and
self.same_config(old) and
self.same_persisted_description(old) and
self.same_fqn(old) and
self.same_database_representation(old) and
True
self.same_body(old)
and self.same_config(old)
and self.same_persisted_description(old)
and self.same_fqn(old)
and self.same_database_representation(old)
and True
)
@dataclass
class ParsedAnalysisNode(ParsedNode):
resource_type: NodeType = field(metadata={'restrict': [NodeType.Analysis]})
resource_type: NodeType = field(metadata={"restrict": [NodeType.Analysis]})
@dataclass
class ParsedHookNode(ParsedNode):
resource_type: NodeType = field(
metadata={'restrict': [NodeType.Operation]}
)
resource_type: NodeType = field(metadata={"restrict": [NodeType.Operation]})
index: Optional[int] = None
@dataclass
class ParsedModelNode(ParsedNode):
resource_type: NodeType = field(metadata={'restrict': [NodeType.Model]})
resource_type: NodeType = field(metadata={"restrict": [NodeType.Model]})
@dataclass
class ParsedRPCNode(ParsedNode):
resource_type: NodeType = field(metadata={'restrict': [NodeType.RPCCall]})
resource_type: NodeType = field(metadata={"restrict": [NodeType.RPCCall]})
def same_seeds(first: ParsedNode, second: ParsedNode) -> bool:
@@ -297,31 +291,31 @@ def same_seeds(first: ParsedNode, second: ParsedNode) -> bool:
# if the current checksum is a path, we want to log a warning.
result = first.checksum == second.checksum
if first.checksum.name == 'path':
if first.checksum.name == "path":
msg: str
if second.checksum.name != 'path':
if second.checksum.name != "path":
msg = (
f'Found a seed ({first.package_name}.{first.name}) '
f'>{MAXIMUM_SEED_SIZE_NAME} in size. The previous file was '
f'<={MAXIMUM_SEED_SIZE_NAME}, so it has changed'
f"Found a seed ({first.package_name}.{first.name}) "
f">{MAXIMUM_SEED_SIZE_NAME} in size. The previous file was "
f"<={MAXIMUM_SEED_SIZE_NAME}, so it has changed"
)
elif result:
msg = (
f'Found a seed ({first.package_name}.{first.name}) '
f'>{MAXIMUM_SEED_SIZE_NAME} in size at the same path, dbt '
f'cannot tell if it has changed: assuming they are the same'
f"Found a seed ({first.package_name}.{first.name}) "
f">{MAXIMUM_SEED_SIZE_NAME} in size at the same path, dbt "
f"cannot tell if it has changed: assuming they are the same"
)
elif not result:
msg = (
f'Found a seed ({first.package_name}.{first.name}) '
f'>{MAXIMUM_SEED_SIZE_NAME} in size. The previous file was in '
f'a different location, assuming it has changed'
f"Found a seed ({first.package_name}.{first.name}) "
f">{MAXIMUM_SEED_SIZE_NAME} in size. The previous file was in "
f"a different location, assuming it has changed"
)
else:
msg = (
f'Found a seed ({first.package_name}.{first.name}) '
f'>{MAXIMUM_SEED_SIZE_NAME} in size. The previous file had a '
f'checksum type of {second.checksum.name}, so it has changed'
f"Found a seed ({first.package_name}.{first.name}) "
f">{MAXIMUM_SEED_SIZE_NAME} in size. The previous file had a "
f"checksum type of {second.checksum.name}, so it has changed"
)
warn_or_error(msg, node=first)
@@ -331,7 +325,7 @@ def same_seeds(first: ParsedNode, second: ParsedNode) -> bool:
@dataclass
class ParsedSeedNode(ParsedNode):
# keep this in sync with CompiledSeedNode!
resource_type: NodeType = field(metadata={'restrict': [NodeType.Seed]})
resource_type: NodeType = field(metadata={"restrict": [NodeType.Seed]})
config: SeedConfig = field(default_factory=SeedConfig)
@property
@@ -357,21 +351,20 @@ class HasTestMetadata(dbtClassMixin):
@dataclass
class ParsedDataTestNode(ParsedNode):
resource_type: NodeType = field(metadata={'restrict': [NodeType.Test]})
resource_type: NodeType = field(metadata={"restrict": [NodeType.Test]})
config: TestConfig = field(default_factory=TestConfig)
@dataclass
class ParsedSchemaTestNode(ParsedNode, HasTestMetadata):
# keep this in sync with CompiledSchemaTestNode!
resource_type: NodeType = field(metadata={'restrict': [NodeType.Test]})
resource_type: NodeType = field(metadata={"restrict": [NodeType.Test]})
column_name: Optional[str] = None
config: TestConfig = field(default_factory=TestConfig)
def same_config(self, other) -> bool:
return (
self.unrendered_config.get('severity') ==
other.unrendered_config.get('severity')
return self.unrendered_config.get("severity") == other.unrendered_config.get(
"severity"
)
def same_column_name(self, other) -> bool:
@@ -381,11 +374,7 @@ class ParsedSchemaTestNode(ParsedNode, HasTestMetadata):
if other is None:
return False
return (
self.same_config(other) and
self.same_fqn(other) and
True
)
return self.same_config(other) and self.same_fqn(other) and True
@dataclass
@@ -396,13 +385,13 @@ class IntermediateSnapshotNode(ParsedNode):
# defined in config blocks. To fix that, we have an intermediate type that
# uses a regular node config, which the snapshot parser will then convert
# into a full ParsedSnapshotNode after rendering.
resource_type: NodeType = field(metadata={'restrict': [NodeType.Snapshot]})
resource_type: NodeType = field(metadata={"restrict": [NodeType.Snapshot]})
config: EmptySnapshotConfig = field(default_factory=EmptySnapshotConfig)
@dataclass
class ParsedSnapshotNode(ParsedNode):
resource_type: NodeType = field(metadata={'restrict': [NodeType.Snapshot]})
resource_type: NodeType = field(metadata={"restrict": [NodeType.Snapshot]})
config: SnapshotConfig
@@ -431,12 +420,12 @@ class ParsedMacroPatch(ParsedPatch):
class ParsedMacro(UnparsedBaseNode, HasUniqueID):
name: str
macro_sql: str
resource_type: NodeType = field(metadata={'restrict': [NodeType.Macro]})
resource_type: NodeType = field(metadata={"restrict": [NodeType.Macro]})
# TODO: can macros even have tags?
tags: List[str] = field(default_factory=list)
# TODO: is this ever populated?
depends_on: MacroDependsOn = field(default_factory=MacroDependsOn)
description: str = ''
description: str = ""
meta: Dict[str, Any] = field(default_factory=dict)
docs: Docs = field(default_factory=Docs)
patch_path: Optional[str] = None
@@ -457,7 +446,7 @@ class ParsedMacro(UnparsedBaseNode, HasUniqueID):
dct = self.to_dict(omit_none=False)
self.validate(dct)
def same_contents(self, other: Optional['ParsedMacro']) -> bool:
def same_contents(self, other: Optional["ParsedMacro"]) -> bool:
if other is None:
return False
# the only thing that makes one macro different from another with the
@@ -474,7 +463,7 @@ class ParsedDocumentation(UnparsedDocumentation, HasUniqueID):
def search_name(self):
return self.name
def same_contents(self, other: Optional['ParsedDocumentation']) -> bool:
def same_contents(self, other: Optional["ParsedDocumentation"]) -> bool:
if other is None:
return False
# the only thing that makes one doc different from another with the
@@ -493,11 +482,11 @@ def normalize_test(testdef: TestDef) -> Dict[str, Any]:
class UnpatchedSourceDefinition(UnparsedBaseNode, HasUniqueID, HasFqn):
source: UnparsedSourceDefinition
table: UnparsedSourceTableDefinition
resource_type: NodeType = field(metadata={'restrict': [NodeType.Source]})
resource_type: NodeType = field(metadata={"restrict": [NodeType.Source]})
patch_path: Optional[Path] = None
def get_full_source_name(self):
return f'{self.source.name}_{self.table.name}'
return f"{self.source.name}_{self.table.name}"
def get_source_representation(self):
return f'source("{self.source.name}", "{self.table.name}")'
@@ -522,9 +511,7 @@ class UnpatchedSourceDefinition(UnparsedBaseNode, HasUniqueID, HasFqn):
else:
return self.table.columns
def get_tests(
self
) -> Iterator[Tuple[Dict[str, Any], Optional[UnparsedColumn]]]:
def get_tests(self) -> Iterator[Tuple[Dict[str, Any], Optional[UnparsedColumn]]]:
for test in self.tests:
yield normalize_test(test), None
@@ -543,22 +530,19 @@ class UnpatchedSourceDefinition(UnparsedBaseNode, HasUniqueID, HasFqn):
@dataclass
class ParsedSourceDefinition(
UnparsedBaseNode,
HasUniqueID,
HasRelationMetadata,
HasFqn
UnparsedBaseNode, HasUniqueID, HasRelationMetadata, HasFqn
):
name: str
source_name: str
source_description: str
loader: str
identifier: str
resource_type: NodeType = field(metadata={'restrict': [NodeType.Source]})
resource_type: NodeType = field(metadata={"restrict": [NodeType.Source]})
quoting: Quoting = field(default_factory=Quoting)
loaded_at_field: Optional[str] = None
freshness: Optional[FreshnessThreshold] = None
external: Optional[ExternalTable] = None
description: str = ''
description: str = ""
columns: Dict[str, ColumnInfo] = field(default_factory=dict)
meta: Dict[str, Any] = field(default_factory=dict)
source_meta: Dict[str, Any] = field(default_factory=dict)
@@ -568,36 +552,34 @@ class ParsedSourceDefinition(
unrendered_config: Dict[str, Any] = field(default_factory=dict)
relation_name: Optional[str] = None
def same_database_representation(
self, other: 'ParsedSourceDefinition'
) -> bool:
def same_database_representation(self, other: "ParsedSourceDefinition") -> bool:
return (
self.database == other.database and
self.schema == other.schema and
self.identifier == other.identifier and
True
self.database == other.database
and self.schema == other.schema
and self.identifier == other.identifier
and True
)
def same_quoting(self, other: 'ParsedSourceDefinition') -> bool:
def same_quoting(self, other: "ParsedSourceDefinition") -> bool:
return self.quoting == other.quoting
def same_freshness(self, other: 'ParsedSourceDefinition') -> bool:
def same_freshness(self, other: "ParsedSourceDefinition") -> bool:
return (
self.freshness == other.freshness and
self.loaded_at_field == other.loaded_at_field and
True
self.freshness == other.freshness
and self.loaded_at_field == other.loaded_at_field
and True
)
def same_external(self, other: 'ParsedSourceDefinition') -> bool:
def same_external(self, other: "ParsedSourceDefinition") -> bool:
return self.external == other.external
def same_config(self, old: 'ParsedSourceDefinition') -> bool:
def same_config(self, old: "ParsedSourceDefinition") -> bool:
return self.config.same_contents(
self.unrendered_config,
old.unrendered_config,
)
def same_contents(self, old: Optional['ParsedSourceDefinition']) -> bool:
def same_contents(self, old: Optional["ParsedSourceDefinition"]) -> bool:
# existing when it didn't before is a change!
if old is None:
return True
@@ -611,17 +593,17 @@ class ParsedSourceDefinition(
# metadata/tags changes are not "changes"
# patching/description changes are not "changes"
return (
self.same_database_representation(old) and
self.same_fqn(old) and
self.same_config(old) and
self.same_quoting(old) and
self.same_freshness(old) and
self.same_external(old) and
True
self.same_database_representation(old)
and self.same_fqn(old)
and self.same_config(old)
and self.same_quoting(old)
and self.same_freshness(old)
and self.same_external(old)
and True
)
def get_full_source_name(self):
return f'{self.source_name}_{self.name}'
return f"{self.source_name}_{self.name}"
def get_source_representation(self):
return f'source("{self.source.name}", "{self.table.name}")'
@@ -656,7 +638,7 @@ class ParsedSourceDefinition(
@property
def search_name(self):
return f'{self.source_name}.{self.name}'
return f"{self.source_name}.{self.name}"
@dataclass
@@ -665,7 +647,7 @@ class ParsedExposure(UnparsedBaseNode, HasUniqueID, HasFqn):
type: ExposureType
owner: ExposureOwner
resource_type: NodeType = NodeType.Exposure
description: str = ''
description: str = ""
maturity: Optional[MaturityType] = None
url: Optional[str] = None
depends_on: DependsOn = field(default_factory=DependsOn)
@@ -685,38 +667,38 @@ class ParsedExposure(UnparsedBaseNode, HasUniqueID, HasFqn):
def tags(self):
return []
def same_depends_on(self, old: 'ParsedExposure') -> bool:
def same_depends_on(self, old: "ParsedExposure") -> bool:
return set(self.depends_on.nodes) == set(old.depends_on.nodes)
def same_description(self, old: 'ParsedExposure') -> bool:
def same_description(self, old: "ParsedExposure") -> bool:
return self.description == old.description
def same_maturity(self, old: 'ParsedExposure') -> bool:
def same_maturity(self, old: "ParsedExposure") -> bool:
return self.maturity == old.maturity
def same_owner(self, old: 'ParsedExposure') -> bool:
def same_owner(self, old: "ParsedExposure") -> bool:
return self.owner == old.owner
def same_exposure_type(self, old: 'ParsedExposure') -> bool:
def same_exposure_type(self, old: "ParsedExposure") -> bool:
return self.type == old.type
def same_url(self, old: 'ParsedExposure') -> bool:
def same_url(self, old: "ParsedExposure") -> bool:
return self.url == old.url
def same_contents(self, old: Optional['ParsedExposure']) -> bool:
def same_contents(self, old: Optional["ParsedExposure"]) -> bool:
# existing when it didn't before is a change!
if old is None:
return True
return (
self.same_fqn(old) and
self.same_exposure_type(old) and
self.same_owner(old) and
self.same_maturity(old) and
self.same_url(old) and
self.same_description(old) and
self.same_depends_on(old) and
True
self.same_fqn(old)
and self.same_exposure_type(old)
and self.same_owner(old)
and self.same_maturity(old)
and self.same_url(old)
and self.same_description(old)
and self.same_depends_on(old)
and True
)

View File

@@ -4,13 +4,12 @@ from dbt.contracts.util import (
Mergeable,
Replaceable,
)
# trigger the PathEncoder
import dbt.helper_types # noqa:F401
from dbt.exceptions import CompilationException
from dbt.dataclass_schema import (
dbtClassMixin, StrEnum, ExtensibleDbtClassMixin
)
from dbt.dataclass_schema import dbtClassMixin, StrEnum, ExtensibleDbtClassMixin
from dataclasses import dataclass, field
from datetime import timedelta
@@ -37,21 +36,25 @@ class HasSQL:
@dataclass
class UnparsedMacro(UnparsedBaseNode, HasSQL):
resource_type: NodeType = field(metadata={'restrict': [NodeType.Macro]})
resource_type: NodeType = field(metadata={"restrict": [NodeType.Macro]})
@dataclass
class UnparsedNode(UnparsedBaseNode, HasSQL):
name: str
resource_type: NodeType = field(metadata={'restrict': [
NodeType.Model,
NodeType.Analysis,
NodeType.Test,
NodeType.Snapshot,
NodeType.Operation,
NodeType.Seed,
NodeType.RPCCall,
]})
resource_type: NodeType = field(
metadata={
"restrict": [
NodeType.Model,
NodeType.Analysis,
NodeType.Test,
NodeType.Snapshot,
NodeType.Operation,
NodeType.Seed,
NodeType.RPCCall,
]
}
)
@property
def search_name(self):
@@ -60,9 +63,7 @@ class UnparsedNode(UnparsedBaseNode, HasSQL):
@dataclass
class UnparsedRunHook(UnparsedNode):
resource_type: NodeType = field(
metadata={'restrict': [NodeType.Operation]}
)
resource_type: NodeType = field(metadata={"restrict": [NodeType.Operation]})
index: Optional[int] = None
@@ -72,10 +73,9 @@ class Docs(dbtClassMixin, Replaceable):
@dataclass
class HasDocs(AdditionalPropertiesMixin, ExtensibleDbtClassMixin,
Replaceable):
class HasDocs(AdditionalPropertiesMixin, ExtensibleDbtClassMixin, Replaceable):
name: str
description: str = ''
description: str = ""
meta: Dict[str, Any] = field(default_factory=dict)
data_type: Optional[str] = None
docs: Docs = field(default_factory=Docs)
@@ -131,7 +131,7 @@ class UnparsedNodeUpdate(HasColumnTests, HasTests, HasYamlMetadata):
class MacroArgument(dbtClassMixin):
name: str
type: Optional[str] = None
description: str = ''
description: str = ""
@dataclass
@@ -140,12 +140,12 @@ class UnparsedMacroUpdate(HasDocs, HasYamlMetadata):
class TimePeriod(StrEnum):
minute = 'minute'
hour = 'hour'
day = 'day'
minute = "minute"
hour = "hour"
day = "day"
def plural(self) -> str:
return str(self) + 's'
return str(self) + "s"
@dataclass
@@ -167,6 +167,7 @@ class FreshnessThreshold(dbtClassMixin, Mergeable):
def status(self, age: float) -> "dbt.contracts.results.FreshnessStatus":
from dbt.contracts.results import FreshnessStatus
if self.error_after and self.error_after.exceeded(age):
return FreshnessStatus.Error
elif self.warn_after and self.warn_after.exceeded(age):
@@ -179,24 +180,21 @@ class FreshnessThreshold(dbtClassMixin, Mergeable):
@dataclass
class AdditionalPropertiesAllowed(
AdditionalPropertiesMixin,
ExtensibleDbtClassMixin
):
class AdditionalPropertiesAllowed(AdditionalPropertiesMixin, ExtensibleDbtClassMixin):
_extra: Dict[str, Any] = field(default_factory=dict)
@dataclass
class ExternalPartition(AdditionalPropertiesAllowed, Replaceable):
name: str = ''
description: str = ''
data_type: str = ''
name: str = ""
description: str = ""
data_type: str = ""
meta: Dict[str, Any] = field(default_factory=dict)
def __post_init__(self):
if self.name == '' or self.data_type == '':
if self.name == "" or self.data_type == "":
raise CompilationException(
'External partition columns must have names and data types'
"External partition columns must have names and data types"
)
@@ -225,43 +223,39 @@ class UnparsedSourceTableDefinition(HasColumnTests, HasTests):
loaded_at_field: Optional[str] = None
identifier: Optional[str] = None
quoting: Quoting = field(default_factory=Quoting)
freshness: Optional[FreshnessThreshold] = field(
default_factory=FreshnessThreshold
)
freshness: Optional[FreshnessThreshold] = field(default_factory=FreshnessThreshold)
external: Optional[ExternalTable] = None
tags: List[str] = field(default_factory=list)
def __post_serialize__(self, dct):
dct = super().__post_serialize__(dct)
if 'freshness' not in dct and self.freshness is None:
dct['freshness'] = None
if "freshness" not in dct and self.freshness is None:
dct["freshness"] = None
return dct
@dataclass
class UnparsedSourceDefinition(dbtClassMixin, Replaceable):
name: str
description: str = ''
description: str = ""
meta: Dict[str, Any] = field(default_factory=dict)
database: Optional[str] = None
schema: Optional[str] = None
loader: str = ''
loader: str = ""
quoting: Quoting = field(default_factory=Quoting)
freshness: Optional[FreshnessThreshold] = field(
default_factory=FreshnessThreshold
)
freshness: Optional[FreshnessThreshold] = field(default_factory=FreshnessThreshold)
loaded_at_field: Optional[str] = None
tables: List[UnparsedSourceTableDefinition] = field(default_factory=list)
tags: List[str] = field(default_factory=list)
@property
def yaml_key(self) -> 'str':
return 'sources'
def yaml_key(self) -> "str":
return "sources"
def __post_serialize__(self, dct):
dct = super().__post_serialize__(dct)
if 'freshnewss' not in dct and self.freshness is None:
dct['freshness'] = None
if "freshnewss" not in dct and self.freshness is None:
dct["freshness"] = None
return dct
@@ -275,9 +269,7 @@ class SourceTablePatch(dbtClassMixin):
loaded_at_field: Optional[str] = None
identifier: Optional[str] = None
quoting: Quoting = field(default_factory=Quoting)
freshness: Optional[FreshnessThreshold] = field(
default_factory=FreshnessThreshold
)
freshness: Optional[FreshnessThreshold] = field(default_factory=FreshnessThreshold)
external: Optional[ExternalTable] = None
tags: Optional[List[str]] = None
tests: Optional[List[TestDef]] = None
@@ -285,13 +277,13 @@ class SourceTablePatch(dbtClassMixin):
def to_patch_dict(self) -> Dict[str, Any]:
dct = self.to_dict(omit_none=True)
remove_keys = ('name')
remove_keys = "name"
for key in remove_keys:
if key in dct:
del dct[key]
if self.freshness is None:
dct['freshness'] = None
dct["freshness"] = None
return dct
@@ -299,13 +291,13 @@ class SourceTablePatch(dbtClassMixin):
@dataclass
class SourcePatch(dbtClassMixin, Replaceable):
name: str = field(
metadata=dict(description='The name of the source to override'),
metadata=dict(description="The name of the source to override"),
)
overrides: str = field(
metadata=dict(description='The package of the source to override'),
metadata=dict(description="The package of the source to override"),
)
path: Path = field(
metadata=dict(description='The path to the patch-defining yml file'),
metadata=dict(description="The path to the patch-defining yml file"),
)
description: Optional[str] = None
meta: Optional[Dict[str, Any]] = None
@@ -322,13 +314,13 @@ class SourcePatch(dbtClassMixin, Replaceable):
def to_patch_dict(self) -> Dict[str, Any]:
dct = self.to_dict(omit_none=True)
remove_keys = ('name', 'overrides', 'tables', 'path')
remove_keys = ("name", "overrides", "tables", "path")
for key in remove_keys:
if key in dct:
del dct[key]
if self.freshness is None:
dct['freshness'] = None
dct["freshness"] = None
return dct
@@ -360,9 +352,9 @@ class UnparsedDocumentationFile(UnparsedDocumentation):
# can't use total_ordering decorator here, as str provides an ordering already
# and it's not the one we want.
class Maturity(StrEnum):
low = 'low'
medium = 'medium'
high = 'high'
low = "low"
medium = "medium"
high = "high"
def __lt__(self, other):
if not isinstance(other, Maturity):
@@ -387,17 +379,17 @@ class Maturity(StrEnum):
class ExposureType(StrEnum):
Dashboard = 'dashboard'
Notebook = 'notebook'
Analysis = 'analysis'
ML = 'ml'
Application = 'application'
Dashboard = "dashboard"
Notebook = "notebook"
Analysis = "analysis"
ML = "ml"
Application = "application"
class MaturityType(StrEnum):
Low = 'low'
Medium = 'medium'
High = 'high'
Low = "low"
Medium = "medium"
High = "high"
@dataclass
@@ -411,7 +403,7 @@ class UnparsedExposure(dbtClassMixin, Replaceable):
name: str
type: ExposureType
owner: ExposureOwner
description: str = ''
description: str = ""
maturity: Optional[MaturityType] = None
url: Optional[str] = None
depends_on: List[str] = field(default_factory=list)

View File

@@ -5,24 +5,26 @@ from dbt.logger import GLOBAL_LOGGER as logger # noqa
from dbt import tracking
from dbt import ui
from dbt.dataclass_schema import (
dbtClassMixin, ValidationError,
dbtClassMixin,
ValidationError,
HyphenatedDbtClassMixin,
ExtensibleDbtClassMixin,
register_pattern, ValidatedStringMixin
register_pattern,
ValidatedStringMixin,
)
from dataclasses import dataclass, field
from typing import Optional, List, Dict, Union, Any
from mashumaro.types import SerializableType
PIN_PACKAGE_URL = 'https://docs.getdbt.com/docs/package-management#section-specifying-package-versions' # noqa
PIN_PACKAGE_URL = "https://docs.getdbt.com/docs/package-management#section-specifying-package-versions" # noqa
DEFAULT_SEND_ANONYMOUS_USAGE_STATS = True
class Name(ValidatedStringMixin):
ValidationRegex = r'^[^\d\W]\w*$'
ValidationRegex = r"^[^\d\W]\w*$"
register_pattern(Name, r'^[^\d\W]\w*$')
register_pattern(Name, r"^[^\d\W]\w*$")
class SemverString(str, SerializableType):
@@ -30,7 +32,7 @@ class SemverString(str, SerializableType):
return self
@classmethod
def _deserialize(cls, value: str) -> 'SemverString':
def _deserialize(cls, value: str) -> "SemverString":
return SemverString(value)
@@ -39,7 +41,7 @@ class SemverString(str, SerializableType):
# 'semver lite'.
register_pattern(
SemverString,
r'^(?:0|[1-9]\d*)\.(?:0|[1-9]\d*)(\.(?:0|[1-9]\d*))?$',
r"^(?:0|[1-9]\d*)\.(?:0|[1-9]\d*)(\.(?:0|[1-9]\d*))?$",
)
@@ -105,8 +107,7 @@ class ProjectPackageMetadata:
@classmethod
def from_project(cls, project):
return cls(name=project.project_name,
packages=project.packages.packages)
return cls(name=project.project_name, packages=project.packages.packages)
@dataclass
@@ -124,46 +125,46 @@ class RegistryPackageMetadata(
# A list of all the reserved words that packages may not have as names.
BANNED_PROJECT_NAMES = {
'_sql_results',
'adapter',
'api',
'column',
'config',
'context',
'database',
'env',
'env_var',
'exceptions',
'execute',
'flags',
'fromjson',
'fromyaml',
'graph',
'invocation_id',
'load_agate_table',
'load_result',
'log',
'model',
'modules',
'post_hooks',
'pre_hooks',
'ref',
'render',
'return',
'run_started_at',
'schema',
'source',
'sql',
'sql_now',
'store_result',
'store_raw_result',
'target',
'this',
'tojson',
'toyaml',
'try_or_compiler_error',
'var',
'write',
"_sql_results",
"adapter",
"api",
"column",
"config",
"context",
"database",
"env",
"env_var",
"exceptions",
"execute",
"flags",
"fromjson",
"fromyaml",
"graph",
"invocation_id",
"load_agate_table",
"load_result",
"log",
"model",
"modules",
"post_hooks",
"pre_hooks",
"ref",
"render",
"return",
"run_started_at",
"schema",
"source",
"sql",
"sql_now",
"store_result",
"store_raw_result",
"target",
"this",
"tojson",
"toyaml",
"try_or_compiler_error",
"var",
"write",
}
@@ -198,7 +199,7 @@ class Project(HyphenatedDbtClassMixin, Replaceable):
vars: Optional[Dict[str, Any]] = field(
default=None,
metadata=dict(
description='map project names to their vars override dicts',
description="map project names to their vars override dicts",
),
)
packages: List[PackageSpec] = field(default_factory=list)
@@ -207,7 +208,7 @@ class Project(HyphenatedDbtClassMixin, Replaceable):
@classmethod
def validate(cls, data):
super().validate(data)
if data['name'] in BANNED_PROJECT_NAMES:
if data["name"] in BANNED_PROJECT_NAMES:
raise ValidationError(
f"Invalid project name: {data['name']} is a reserved word"
)
@@ -235,8 +236,8 @@ class UserConfig(ExtensibleDbtClassMixin, Replaceable, UserConfigContract):
@dataclass
class ProfileConfig(HyphenatedDbtClassMixin, Replaceable):
profile_name: str = field(metadata={'preserve_underscore': True})
target_name: str = field(metadata={'preserve_underscore': True})
profile_name: str = field(metadata={"preserve_underscore": True})
target_name: str = field(metadata={"preserve_underscore": True})
config: UserConfig
threads: int
# TODO: make this a dynamic union of some kind?
@@ -255,7 +256,7 @@ class ConfiguredQuoting(Quoting, Replaceable):
class Configuration(Project, ProfileConfig):
cli_vars: Dict[str, Any] = field(
default_factory=dict,
metadata={'preserve_underscore': True},
metadata={"preserve_underscore": True},
)
quoting: Optional[ConfiguredQuoting] = None

View File

@@ -1,7 +1,8 @@
from collections.abc import Mapping
from dataclasses import dataclass, fields
from typing import (
Optional, Dict,
Optional,
Dict,
)
from typing_extensions import Protocol
@@ -14,17 +15,17 @@ from dbt.utils import deep_merge
class RelationType(StrEnum):
Table = 'table'
View = 'view'
CTE = 'cte'
MaterializedView = 'materializedview'
External = 'external'
Table = "table"
View = "view"
CTE = "cte"
MaterializedView = "materializedview"
External = "external"
class ComponentName(StrEnum):
Database = 'database'
Schema = 'schema'
Identifier = 'identifier'
Database = "database"
Schema = "schema"
Identifier = "identifier"
class HasQuoting(Protocol):
@@ -43,12 +44,12 @@ class FakeAPIObject(dbtClassMixin, Replaceable, Mapping):
raise KeyError(key) from None
def __iter__(self):
deprecations.warn('not-a-dictionary', obj=self)
deprecations.warn("not-a-dictionary", obj=self)
for _, name in self._get_fields():
yield name
def __len__(self):
deprecations.warn('not-a-dictionary', obj=self)
deprecations.warn("not-a-dictionary", obj=self)
return len(fields(self.__class__))
def incorporate(self, **kwargs):
@@ -72,8 +73,7 @@ class Policy(FakeAPIObject):
return self.identifier
else:
raise ValueError(
'Got a key of {}, expected one of {}'
.format(key, list(ComponentName))
"Got a key of {}, expected one of {}".format(key, list(ComponentName))
)
def replace_dict(self, dct: Dict[ComponentName, bool]):
@@ -93,15 +93,15 @@ class Path(FakeAPIObject):
# handle pesky jinja2.Undefined sneaking in here and messing up render
if not isinstance(self.database, (type(None), str)):
raise CompilationException(
'Got an invalid path database: {}'.format(self.database)
"Got an invalid path database: {}".format(self.database)
)
if not isinstance(self.schema, (type(None), str)):
raise CompilationException(
'Got an invalid path schema: {}'.format(self.schema)
"Got an invalid path schema: {}".format(self.schema)
)
if not isinstance(self.identifier, (type(None), str)):
raise CompilationException(
'Got an invalid path identifier: {}'.format(self.identifier)
"Got an invalid path identifier: {}".format(self.identifier)
)
def get_lowered_part(self, key: ComponentName) -> Optional[str]:
@@ -119,8 +119,7 @@ class Path(FakeAPIObject):
return self.identifier
else:
raise ValueError(
'Got a key of {}, expected one of {}'
.format(key, list(ComponentName))
"Got a key of {}, expected one of {}".format(key, list(ComponentName))
)
def replace_dict(self, dct: Dict[ComponentName, str]):

View File

@@ -1,7 +1,5 @@
from dbt.contracts.graph.manifest import CompileResultNode
from dbt.contracts.graph.unparsed import (
FreshnessThreshold
)
from dbt.contracts.graph.unparsed import FreshnessThreshold
from dbt.contracts.graph.parsed import ParsedSourceDefinition
from dbt.contracts.util import (
BaseArtifactMetadata,
@@ -24,7 +22,13 @@ import agate
from dataclasses import dataclass, field
from datetime import datetime
from typing import (
Union, Dict, List, Optional, Any, NamedTuple, Sequence,
Union,
Dict,
List,
Optional,
Any,
NamedTuple,
Sequence,
)
from dbt.clients.system import write_json
@@ -54,7 +58,7 @@ class collect_timing_info:
def __exit__(self, exc_type, exc_value, traceback):
self.timing_info.end()
with JsonOnly(), TimingProcessor(self.timing_info):
logger.debug('finished collecting timing info')
logger.debug("finished collecting timing info")
class NodeStatus(StrEnum):
@@ -99,8 +103,8 @@ class BaseResult(dbtClassMixin):
@classmethod
def __pre_deserialize__(cls, data):
data = super().__pre_deserialize__(data)
if 'message' not in data:
data['message'] = None
if "message" not in data:
data["message"] = None
return data
@@ -112,9 +116,8 @@ class NodeResult(BaseResult):
@dataclass
class RunResult(NodeResult):
agate_table: Optional[agate.Table] = field(
default=None, metadata={
'serialize': lambda x: None, 'deserialize': lambda x: None
}
default=None,
metadata={"serialize": lambda x: None, "deserialize": lambda x: None},
)
@property
@@ -157,7 +160,7 @@ def process_run_result(result: RunResult) -> RunResultOutput:
thread_id=result.thread_id,
execution_time=result.execution_time,
message=result.message,
adapter_response=result.adapter_response
adapter_response=result.adapter_response,
)
@@ -180,7 +183,7 @@ class RunExecutionResult(
@dataclass
@schema_version('run-results', 1)
@schema_version("run-results", 1)
class RunResultsArtifact(ExecutionResult, ArtifactMixin):
results: Sequence[RunResultOutput]
args: Dict[str, Any] = field(default_factory=dict)
@@ -202,7 +205,7 @@ class RunResultsArtifact(ExecutionResult, ArtifactMixin):
metadata=meta,
results=processed_results,
elapsed_time=elapsed_time,
args=args
args=args,
)
def write(self, path: str):
@@ -216,15 +219,14 @@ class RunOperationResult(ExecutionResult):
@dataclass
class RunOperationResultMetadata(BaseArtifactMetadata):
dbt_schema_version: str = field(default_factory=lambda: str(
RunOperationResultsArtifact.dbt_schema_version
))
dbt_schema_version: str = field(
default_factory=lambda: str(RunOperationResultsArtifact.dbt_schema_version)
)
@dataclass
@schema_version('run-operation-result', 1)
@schema_version("run-operation-result", 1)
class RunOperationResultsArtifact(RunOperationResult, ArtifactMixin):
@classmethod
def from_success(
cls,
@@ -243,6 +245,7 @@ class RunOperationResultsArtifact(RunOperationResult, ArtifactMixin):
success=success,
)
# due to issues with typing.Union collapsing subclasses, this can't subclass
# PartialResult
@@ -261,7 +264,7 @@ class SourceFreshnessResult(NodeResult):
class FreshnessErrorEnum(StrEnum):
runtime_error = 'runtime error'
runtime_error = "runtime error"
@dataclass
@@ -291,14 +294,11 @@ class PartialSourceFreshnessResult(NodeResult):
return False
FreshnessNodeResult = Union[PartialSourceFreshnessResult,
SourceFreshnessResult]
FreshnessNodeResult = Union[PartialSourceFreshnessResult, SourceFreshnessResult]
FreshnessNodeOutput = Union[SourceFreshnessRuntimeError, SourceFreshnessOutput]
def process_freshness_result(
result: FreshnessNodeResult
) -> FreshnessNodeOutput:
def process_freshness_result(result: FreshnessNodeResult) -> FreshnessNodeOutput:
unique_id = result.node.unique_id
if result.status == FreshnessStatus.RuntimeErr:
return SourceFreshnessRuntimeError(
@@ -310,16 +310,15 @@ def process_freshness_result(
# we know that this must be a SourceFreshnessResult
if not isinstance(result, SourceFreshnessResult):
raise InternalException(
'Got {} instead of a SourceFreshnessResult for a '
'non-error result in freshness execution!'
.format(type(result))
"Got {} instead of a SourceFreshnessResult for a "
"non-error result in freshness execution!".format(type(result))
)
# if we're here, we must have a non-None freshness threshold
criteria = result.node.freshness
if criteria is None:
raise InternalException(
'Somehow evaluated a freshness result for a source '
'that has no freshness criteria!'
"Somehow evaluated a freshness result for a source "
"that has no freshness criteria!"
)
return SourceFreshnessOutput(
unique_id=unique_id,
@@ -328,16 +327,14 @@ def process_freshness_result(
max_loaded_at_time_ago_in_s=result.age,
status=result.status,
criteria=criteria,
adapter_response=result.adapter_response
adapter_response=result.adapter_response,
)
@dataclass
class FreshnessMetadata(BaseArtifactMetadata):
dbt_schema_version: str = field(
default_factory=lambda: str(
FreshnessExecutionResultArtifact.dbt_schema_version
)
default_factory=lambda: str(FreshnessExecutionResultArtifact.dbt_schema_version)
)
@@ -358,7 +355,7 @@ class FreshnessResult(ExecutionResult):
@dataclass
@schema_version('sources', 1)
@schema_version("sources", 1)
class FreshnessExecutionResultArtifact(
ArtifactMixin,
VersionedSchema,
@@ -380,8 +377,7 @@ class FreshnessExecutionResultArtifact(
Primitive = Union[bool, str, float, None]
CatalogKey = NamedTuple(
'CatalogKey',
[('database', Optional[str]), ('schema', str), ('name', str)]
"CatalogKey", [("database", Optional[str]), ("schema", str), ("name", str)]
)
@@ -450,13 +446,13 @@ class CatalogResults(dbtClassMixin):
def __post_serialize__(self, dct):
dct = super().__post_serialize__(dct)
if '_compile_results' in dct:
del dct['_compile_results']
if "_compile_results" in dct:
del dct["_compile_results"]
return dct
@dataclass
@schema_version('catalog', 1)
@schema_version("catalog", 1)
class CatalogArtifact(CatalogResults, ArtifactMixin):
metadata: CatalogMetadata
@@ -467,8 +463,8 @@ class CatalogArtifact(CatalogResults, ArtifactMixin):
nodes: Dict[str, CatalogTable],
sources: Dict[str, CatalogTable],
compile_results: Optional[Any],
errors: Optional[List[str]]
) -> 'CatalogArtifact':
errors: Optional[List[str]],
) -> "CatalogArtifact":
meta = CatalogMetadata(generated_at=generated_at)
return cls(
metadata=meta,

View File

@@ -10,7 +10,9 @@ from dbt.dataclass_schema import dbtClassMixin, StrEnum
from dbt.contracts.graph.compiled import CompileResultNode
from dbt.contracts.graph.manifest import WritableManifest
from dbt.contracts.results import (
RunResult, RunResultsArtifact, TimingInfo,
RunResult,
RunResultsArtifact,
TimingInfo,
CatalogArtifact,
CatalogResults,
ExecutionResult,
@@ -40,10 +42,10 @@ class RPCParameters(dbtClassMixin):
@classmethod
def __pre_deserialize__(cls, data, omit_none=True):
data = super().__pre_deserialize__(data)
if 'timeout' not in data:
data['timeout'] = None
if 'task_tags' not in data:
data['task_tags'] = None
if "timeout" not in data:
data["timeout"] = None
if "task_tags" not in data:
data["task_tags"] = None
return data
@@ -161,6 +163,7 @@ class GCParameters(RPCParameters):
will be applied to the task manager before GC starts. By default the
existing gc settings remain.
"""
task_ids: Optional[List[TaskID]] = None
before: Optional[datetime] = None
settings: Optional[GCSettings] = None
@@ -182,6 +185,7 @@ class RPCSourceFreshnessParameters(RPCParameters):
class GetManifestParameters(RPCParameters):
pass
# Outputs
@@ -191,13 +195,13 @@ class RemoteResult(VersionedSchema):
@dataclass
@schema_version('remote-deps-result', 1)
@schema_version("remote-deps-result", 1)
class RemoteDepsResult(RemoteResult):
generated_at: datetime = field(default_factory=datetime.utcnow)
@dataclass
@schema_version('remote-catalog-result', 1)
@schema_version("remote-catalog-result", 1)
class RemoteCatalogResults(CatalogResults, RemoteResult):
generated_at: datetime = field(default_factory=datetime.utcnow)
@@ -221,7 +225,7 @@ class RemoteCompileResultMixin(RemoteResult):
@dataclass
@schema_version('remote-compile-result', 1)
@schema_version("remote-compile-result", 1)
class RemoteCompileResult(RemoteCompileResultMixin):
generated_at: datetime = field(default_factory=datetime.utcnow)
@@ -231,7 +235,7 @@ class RemoteCompileResult(RemoteCompileResultMixin):
@dataclass
@schema_version('remote-execution-result', 1)
@schema_version("remote-execution-result", 1)
class RemoteExecutionResult(ExecutionResult, RemoteResult):
results: Sequence[RunResult]
args: Dict[str, Any] = field(default_factory=dict)
@@ -251,7 +255,7 @@ class RemoteExecutionResult(ExecutionResult, RemoteResult):
cls,
base: RunExecutionResult,
logs: List[LogMessage],
) -> 'RemoteExecutionResult':
) -> "RemoteExecutionResult":
return cls(
generated_at=base.generated_at,
results=base.results,
@@ -268,7 +272,7 @@ class ResultTable(dbtClassMixin):
@dataclass
@schema_version('remote-run-operation-result', 1)
@schema_version("remote-run-operation-result", 1)
class RemoteRunOperationResult(RunOperationResult, RemoteResult):
generated_at: datetime = field(default_factory=datetime.utcnow)
@@ -277,7 +281,7 @@ class RemoteRunOperationResult(RunOperationResult, RemoteResult):
cls,
base: RunOperationResultsArtifact,
logs: List[LogMessage],
) -> 'RemoteRunOperationResult':
) -> "RemoteRunOperationResult":
return cls(
generated_at=base.metadata.generated_at,
results=base.results,
@@ -296,15 +300,14 @@ class RemoteRunOperationResult(RunOperationResult, RemoteResult):
@dataclass
@schema_version('remote-freshness-result', 1)
@schema_version("remote-freshness-result", 1)
class RemoteFreshnessResult(FreshnessResult, RemoteResult):
@classmethod
def from_local_result(
cls,
base: FreshnessResult,
logs: List[LogMessage],
) -> 'RemoteFreshnessResult':
) -> "RemoteFreshnessResult":
return cls(
metadata=base.metadata,
results=base.results,
@@ -318,7 +321,7 @@ class RemoteFreshnessResult(FreshnessResult, RemoteResult):
@dataclass
@schema_version('remote-run-result', 1)
@schema_version("remote-run-result", 1)
class RemoteRunResult(RemoteCompileResultMixin):
table: ResultTable
generated_at: datetime = field(default_factory=datetime.utcnow)
@@ -336,14 +339,15 @@ RPCResult = Union[
# GC types
class GCResultState(StrEnum):
Deleted = 'deleted' # successful GC
Missing = 'missing' # nothing to GC
Running = 'running' # can't GC
Deleted = "deleted" # successful GC
Missing = "missing" # nothing to GC
Running = "running" # can't GC
@dataclass
@schema_version('remote-gc-result', 1)
@schema_version("remote-gc-result", 1)
class GCResult(RemoteResult):
logs: List[LogMessage] = field(default_factory=list)
deleted: List[TaskID] = field(default_factory=list)
@@ -358,21 +362,20 @@ class GCResult(RemoteResult):
elif state == GCResultState.Deleted:
self.deleted.append(task_id)
else:
raise InternalException(
f'Got invalid state in add_result: {state}'
)
raise InternalException(f"Got invalid state in add_result: {state}")
# Task management types
class TaskHandlerState(StrEnum):
NotStarted = 'not started'
Initializing = 'initializing'
Running = 'running'
Success = 'success'
Error = 'error'
Killed = 'killed'
Failed = 'failed'
NotStarted = "not started"
Initializing = "initializing"
Running = "running"
Success = "success"
Error = "error"
Killed = "killed"
Failed = "failed"
def __lt__(self, other) -> bool:
"""A logical ordering for TaskHandlerState:
@@ -380,7 +383,7 @@ class TaskHandlerState(StrEnum):
NotStarted < Initializing < Running < (Success, Error, Killed, Failed)
"""
if not isinstance(other, TaskHandlerState):
raise TypeError('cannot compare to non-TaskHandlerState')
raise TypeError("cannot compare to non-TaskHandlerState")
order = (self.NotStarted, self.Initializing, self.Running)
smaller = set()
for value in order:
@@ -392,13 +395,11 @@ class TaskHandlerState(StrEnum):
def __le__(self, other) -> bool:
# so that ((Success <= Error) is True)
return ((self < other) or
(self == other) or
(self.finished and other.finished))
return (self < other) or (self == other) or (self.finished and other.finished)
def __gt__(self, other) -> bool:
if not isinstance(other, TaskHandlerState):
raise TypeError('cannot compare to non-TaskHandlerState')
raise TypeError("cannot compare to non-TaskHandlerState")
order = (self.NotStarted, self.Initializing, self.Running)
smaller = set()
for value in order:
@@ -409,9 +410,7 @@ class TaskHandlerState(StrEnum):
def __ge__(self, other) -> bool:
# so that ((Success <= Error) is True)
return ((self > other) or
(self == other) or
(self.finished and other.finished))
return (self > other) or (self == other) or (self.finished and other.finished)
@property
def finished(self) -> bool:
@@ -430,7 +429,7 @@ class TaskTiming(dbtClassMixin):
@classmethod
def __pre_deserialize__(cls, data):
data = super().__pre_deserialize__(data)
for field_name in ('start', 'end', 'elapsed'):
for field_name in ("start", "end", "elapsed"):
if field_name not in data:
data[field_name] = None
return data
@@ -447,27 +446,27 @@ class TaskRow(TaskTiming):
@dataclass
@schema_version('remote-ps-result', 1)
@schema_version("remote-ps-result", 1)
class PSResult(RemoteResult):
rows: List[TaskRow]
class KillResultStatus(StrEnum):
Missing = 'missing'
NotStarted = 'not_started'
Killed = 'killed'
Finished = 'finished'
Missing = "missing"
NotStarted = "not_started"
Killed = "killed"
Finished = "finished"
@dataclass
@schema_version('remote-kill-result', 1)
@schema_version("remote-kill-result", 1)
class KillResult(RemoteResult):
state: KillResultStatus = KillResultStatus.Missing
logs: List[LogMessage] = field(default_factory=list)
@dataclass
@schema_version('remote-manifest-result', 1)
@schema_version("remote-manifest-result", 1)
class GetManifestResult(RemoteResult):
manifest: Optional[WritableManifest] = None
@@ -498,29 +497,28 @@ class PollResult(RemoteResult, TaskTiming):
@classmethod
def __pre_deserialize__(cls, data):
data = super().__pre_deserialize__(data)
for field_name in ('start', 'end', 'elapsed'):
for field_name in ("start", "end", "elapsed"):
if field_name not in data:
data[field_name] = None
return data
@dataclass
@schema_version('poll-remote-deps-result', 1)
@schema_version("poll-remote-deps-result", 1)
class PollRemoteEmptyCompleteResult(PollResult, RemoteResult):
state: TaskHandlerState = field(
metadata=restrict_to(TaskHandlerState.Success,
TaskHandlerState.Failed),
metadata=restrict_to(TaskHandlerState.Success, TaskHandlerState.Failed),
)
generated_at: datetime = field(default_factory=datetime.utcnow)
@classmethod
def from_result(
cls: Type['PollRemoteEmptyCompleteResult'],
cls: Type["PollRemoteEmptyCompleteResult"],
base: RemoteDepsResult,
tags: TaskTags,
timing: TaskTiming,
logs: List[LogMessage],
) -> 'PollRemoteEmptyCompleteResult':
) -> "PollRemoteEmptyCompleteResult":
return cls(
logs=logs,
tags=tags,
@@ -528,12 +526,12 @@ class PollRemoteEmptyCompleteResult(PollResult, RemoteResult):
start=timing.start,
end=timing.end,
elapsed=timing.elapsed,
generated_at=base.generated_at
generated_at=base.generated_at,
)
@dataclass
@schema_version('poll-remote-killed-result', 1)
@schema_version("poll-remote-killed-result", 1)
class PollKilledResult(PollResult):
state: TaskHandlerState = field(
metadata=restrict_to(TaskHandlerState.Killed),
@@ -541,24 +539,23 @@ class PollKilledResult(PollResult):
@dataclass
@schema_version('poll-remote-execution-result', 1)
@schema_version("poll-remote-execution-result", 1)
class PollExecuteCompleteResult(
RemoteExecutionResult,
PollResult,
):
state: TaskHandlerState = field(
metadata=restrict_to(TaskHandlerState.Success,
TaskHandlerState.Failed),
metadata=restrict_to(TaskHandlerState.Success, TaskHandlerState.Failed),
)
@classmethod
def from_result(
cls: Type['PollExecuteCompleteResult'],
cls: Type["PollExecuteCompleteResult"],
base: RemoteExecutionResult,
tags: TaskTags,
timing: TaskTiming,
logs: List[LogMessage],
) -> 'PollExecuteCompleteResult':
) -> "PollExecuteCompleteResult":
return cls(
results=base.results,
elapsed_time=base.elapsed_time,
@@ -573,24 +570,23 @@ class PollExecuteCompleteResult(
@dataclass
@schema_version('poll-remote-compile-result', 1)
@schema_version("poll-remote-compile-result", 1)
class PollCompileCompleteResult(
RemoteCompileResult,
PollResult,
):
state: TaskHandlerState = field(
metadata=restrict_to(TaskHandlerState.Success,
TaskHandlerState.Failed),
metadata=restrict_to(TaskHandlerState.Success, TaskHandlerState.Failed),
)
@classmethod
def from_result(
cls: Type['PollCompileCompleteResult'],
cls: Type["PollCompileCompleteResult"],
base: RemoteCompileResult,
tags: TaskTags,
timing: TaskTiming,
logs: List[LogMessage],
) -> 'PollCompileCompleteResult':
) -> "PollCompileCompleteResult":
return cls(
raw_sql=base.raw_sql,
compiled_sql=base.compiled_sql,
@@ -602,29 +598,28 @@ class PollCompileCompleteResult(
start=timing.start,
end=timing.end,
elapsed=timing.elapsed,
generated_at=base.generated_at
generated_at=base.generated_at,
)
@dataclass
@schema_version('poll-remote-run-result', 1)
@schema_version("poll-remote-run-result", 1)
class PollRunCompleteResult(
RemoteRunResult,
PollResult,
):
state: TaskHandlerState = field(
metadata=restrict_to(TaskHandlerState.Success,
TaskHandlerState.Failed),
metadata=restrict_to(TaskHandlerState.Success, TaskHandlerState.Failed),
)
@classmethod
def from_result(
cls: Type['PollRunCompleteResult'],
cls: Type["PollRunCompleteResult"],
base: RemoteRunResult,
tags: TaskTags,
timing: TaskTiming,
logs: List[LogMessage],
) -> 'PollRunCompleteResult':
) -> "PollRunCompleteResult":
return cls(
raw_sql=base.raw_sql,
compiled_sql=base.compiled_sql,
@@ -637,29 +632,28 @@ class PollRunCompleteResult(
start=timing.start,
end=timing.end,
elapsed=timing.elapsed,
generated_at=base.generated_at
generated_at=base.generated_at,
)
@dataclass
@schema_version('poll-remote-run-operation-result', 1)
@schema_version("poll-remote-run-operation-result", 1)
class PollRunOperationCompleteResult(
RemoteRunOperationResult,
PollResult,
):
state: TaskHandlerState = field(
metadata=restrict_to(TaskHandlerState.Success,
TaskHandlerState.Failed),
metadata=restrict_to(TaskHandlerState.Success, TaskHandlerState.Failed),
)
@classmethod
def from_result(
cls: Type['PollRunOperationCompleteResult'],
cls: Type["PollRunOperationCompleteResult"],
base: RemoteRunOperationResult,
tags: TaskTags,
timing: TaskTiming,
logs: List[LogMessage],
) -> 'PollRunOperationCompleteResult':
) -> "PollRunOperationCompleteResult":
return cls(
success=base.success,
results=base.results,
@@ -675,21 +669,20 @@ class PollRunOperationCompleteResult(
@dataclass
@schema_version('poll-remote-catalog-result', 1)
@schema_version("poll-remote-catalog-result", 1)
class PollCatalogCompleteResult(RemoteCatalogResults, PollResult):
state: TaskHandlerState = field(
metadata=restrict_to(TaskHandlerState.Success,
TaskHandlerState.Failed),
metadata=restrict_to(TaskHandlerState.Success, TaskHandlerState.Failed),
)
@classmethod
def from_result(
cls: Type['PollCatalogCompleteResult'],
cls: Type["PollCatalogCompleteResult"],
base: RemoteCatalogResults,
tags: TaskTags,
timing: TaskTiming,
logs: List[LogMessage],
) -> 'PollCatalogCompleteResult':
) -> "PollCatalogCompleteResult":
return cls(
nodes=base.nodes,
sources=base.sources,
@@ -706,27 +699,26 @@ class PollCatalogCompleteResult(RemoteCatalogResults, PollResult):
@dataclass
@schema_version('poll-remote-in-progress-result', 1)
@schema_version("poll-remote-in-progress-result", 1)
class PollInProgressResult(PollResult):
pass
@dataclass
@schema_version('poll-remote-get-manifest-result', 1)
@schema_version("poll-remote-get-manifest-result", 1)
class PollGetManifestResult(GetManifestResult, PollResult):
state: TaskHandlerState = field(
metadata=restrict_to(TaskHandlerState.Success,
TaskHandlerState.Failed),
metadata=restrict_to(TaskHandlerState.Success, TaskHandlerState.Failed),
)
@classmethod
def from_result(
cls: Type['PollGetManifestResult'],
cls: Type["PollGetManifestResult"],
base: GetManifestResult,
tags: TaskTags,
timing: TaskTiming,
logs: List[LogMessage],
) -> 'PollGetManifestResult':
) -> "PollGetManifestResult":
return cls(
manifest=base.manifest,
logs=logs,
@@ -739,21 +731,20 @@ class PollGetManifestResult(GetManifestResult, PollResult):
@dataclass
@schema_version('poll-remote-freshness-result', 1)
@schema_version("poll-remote-freshness-result", 1)
class PollFreshnessResult(RemoteFreshnessResult, PollResult):
state: TaskHandlerState = field(
metadata=restrict_to(TaskHandlerState.Success,
TaskHandlerState.Failed),
metadata=restrict_to(TaskHandlerState.Success, TaskHandlerState.Failed),
)
@classmethod
def from_result(
cls: Type['PollFreshnessResult'],
cls: Type["PollFreshnessResult"],
base: RemoteFreshnessResult,
tags: TaskTags,
timing: TaskTiming,
logs: List[LogMessage],
) -> 'PollFreshnessResult':
) -> "PollFreshnessResult":
return cls(
logs=logs,
tags=tags,
@@ -766,18 +757,19 @@ class PollFreshnessResult(RemoteFreshnessResult, PollResult):
elapsed_time=base.elapsed_time,
)
# Manifest parsing types
class ManifestStatus(StrEnum):
Init = 'init'
Compiling = 'compiling'
Ready = 'ready'
Error = 'error'
Init = "init"
Compiling = "compiling"
Ready = "ready"
Error = "error"
@dataclass
@schema_version('remote-status-result', 1)
@schema_version("remote-status-result", 1)
class LastParse(RemoteResult):
state: ManifestStatus = ManifestStatus.Init
logs: List[LogMessage] = field(default_factory=list)

View File

@@ -8,7 +8,7 @@ from typing import List, Dict, Any, Union
class SelectorDefinition(dbtClassMixin):
name: str
definition: Union[str, Dict[str, Any]]
description: str = ''
description: str = ""
@dataclass

View File

@@ -9,7 +9,7 @@ class PreviousState:
self.path: Path = path
self.manifest: Optional[WritableManifest] = None
manifest_path = self.path / 'manifest.json'
manifest_path = self.path / "manifest.json"
if manifest_path.exists() and manifest_path.is_file():
try:
self.manifest = WritableManifest.read(str(manifest_path))

View File

@@ -1,9 +1,7 @@
import dataclasses
import os
from datetime import datetime
from typing import (
List, Tuple, ClassVar, Type, TypeVar, Dict, Any, Optional
)
from typing import List, Tuple, ClassVar, Type, TypeVar, Dict, Any, Optional
from dbt.clients.system import write_json, read_json
from dbt.exceptions import (
@@ -57,9 +55,7 @@ class Mergeable(Replaceable):
class Writable:
def write(self, path: str):
write_json(
path, self.to_dict(omit_none=False) # type: ignore
)
write_json(path, self.to_dict(omit_none=False)) # type: ignore
class AdditionalPropertiesMixin:
@@ -68,6 +64,7 @@ class AdditionalPropertiesMixin:
The underlying class definition must include a type definition for a field
named '_extra' that is of type `Dict[str, Any]`.
"""
ADDITIONAL_PROPERTIES = True
# This takes attributes in the dictionary that are
@@ -86,10 +83,10 @@ class AdditionalPropertiesMixin:
cls_keys = cls._get_field_names()
new_dict = {}
for key, value in data.items():
if key not in cls_keys and key != '_extra':
if '_extra' not in new_dict:
new_dict['_extra'] = {}
new_dict['_extra'][key] = value
if key not in cls_keys and key != "_extra":
if "_extra" not in new_dict:
new_dict["_extra"] = {}
new_dict["_extra"][key] = value
else:
new_dict[key] = value
data = new_dict
@@ -99,8 +96,8 @@ class AdditionalPropertiesMixin:
def __post_serialize__(self, dct):
data = super().__post_serialize__(dct)
data.update(self.extra)
if '_extra' in data:
del data['_extra']
if "_extra" in data:
del data["_extra"]
return data
def replace(self, **kwargs):
@@ -126,8 +123,8 @@ class Readable:
return cls.from_dict(data) # type: ignore
BASE_SCHEMAS_URL = 'https://schemas.getdbt.com/'
SCHEMA_PATH = 'dbt/{name}/v{version}.json'
BASE_SCHEMAS_URL = "https://schemas.getdbt.com/"
SCHEMA_PATH = "dbt/{name}/v{version}.json"
@dataclasses.dataclass
@@ -137,24 +134,22 @@ class SchemaVersion:
@property
def path(self) -> str:
return SCHEMA_PATH.format(
name=self.name,
version=self.version
)
return SCHEMA_PATH.format(name=self.name, version=self.version)
def __str__(self) -> str:
return BASE_SCHEMAS_URL + self.path
SCHEMA_VERSION_KEY = 'dbt_schema_version'
SCHEMA_VERSION_KEY = "dbt_schema_version"
METADATA_ENV_PREFIX = 'DBT_ENV_CUSTOM_ENV_'
METADATA_ENV_PREFIX = "DBT_ENV_CUSTOM_ENV_"
def get_metadata_env() -> Dict[str, str]:
return {
k[len(METADATA_ENV_PREFIX):]: v for k, v in os.environ.items()
k[len(METADATA_ENV_PREFIX) :]: v
for k, v in os.environ.items()
if k.startswith(METADATA_ENV_PREFIX)
}
@@ -163,12 +158,8 @@ def get_metadata_env() -> Dict[str, str]:
class BaseArtifactMetadata(dbtClassMixin):
dbt_schema_version: str
dbt_version: str = __version__
generated_at: datetime = dataclasses.field(
default_factory=datetime.utcnow
)
invocation_id: Optional[str] = dataclasses.field(
default_factory=get_invocation_id
)
generated_at: datetime = dataclasses.field(default_factory=datetime.utcnow)
invocation_id: Optional[str] = dataclasses.field(default_factory=get_invocation_id)
env: Dict[str, str] = dataclasses.field(default_factory=get_metadata_env)
@@ -179,6 +170,7 @@ def schema_version(name: str, version: int):
version=version,
)
return cls
return inner
@@ -190,11 +182,11 @@ class VersionedSchema(dbtClassMixin):
def json_schema(cls, embeddable: bool = False) -> Dict[str, Any]:
result = super().json_schema(embeddable=embeddable)
if not embeddable:
result['$id'] = str(cls.dbt_schema_version)
result["$id"] = str(cls.dbt_schema_version)
return result
T = TypeVar('T', bound='ArtifactMixin')
T = TypeVar("T", bound="ArtifactMixin")
# metadata should really be a Generic[T_M] where T_M is a TypeVar bound to
@@ -208,6 +200,4 @@ class ArtifactMixin(VersionedSchema, Writable, Readable):
def validate(cls, data):
super().validate(data)
if cls.dbt_schema_version is None:
raise InternalException(
'Cannot call from_dict with no schema version!'
)
raise InternalException("Cannot call from_dict with no schema version!")

View File

@@ -1,5 +1,7 @@
from typing import (
Type, ClassVar, cast,
Type,
ClassVar,
cast,
)
import re
from dataclasses import fields
@@ -11,9 +13,7 @@ from hologram import JsonSchemaMixin, FieldEncoder, ValidationError
# type: ignore
from mashumaro import DataClassDictMixin
from mashumaro.config import (
TO_DICT_ADD_OMIT_NONE_FLAG, BaseConfig as MashBaseConfig
)
from mashumaro.config import TO_DICT_ADD_OMIT_NONE_FLAG, BaseConfig as MashBaseConfig
from mashumaro.types import SerializableType, SerializationStrategy
@@ -26,9 +26,7 @@ class DateTimeSerialization(SerializationStrategy):
return out
def deserialize(self, value):
return (
value if isinstance(value, datetime) else parse(cast(str, value))
)
return value if isinstance(value, datetime) else parse(cast(str, value))
# This class pulls in both JsonSchemaMixin from Hologram and
@@ -38,8 +36,8 @@ class DateTimeSerialization(SerializationStrategy):
# come from Hologram.
class dbtClassMixin(DataClassDictMixin, JsonSchemaMixin):
"""Mixin which adds methods to generate a JSON schema and
convert to and from JSON encodable dicts with validation
against the schema
convert to and from JSON encodable dicts with validation
against the schema
"""
class Config(MashBaseConfig):
@@ -60,8 +58,8 @@ class dbtClassMixin(DataClassDictMixin, JsonSchemaMixin):
if self._hyphenated:
new_dict = {}
for key in dct:
if '_' in key:
new_key = key.replace('_', '-')
if "_" in key:
new_key = key.replace("_", "-")
new_dict[new_key] = dct[key]
else:
new_dict[key] = dct[key]
@@ -76,8 +74,8 @@ class dbtClassMixin(DataClassDictMixin, JsonSchemaMixin):
if cls._hyphenated:
new_dict = {}
for key in data:
if '-' in key:
new_key = key.replace('-', '_')
if "-" in key:
new_key = key.replace("-", "_")
new_dict[new_key] = data[key]
else:
new_dict[key] = data[key]
@@ -89,16 +87,16 @@ class dbtClassMixin(DataClassDictMixin, JsonSchemaMixin):
# hologram and in mashumaro.
def _local_to_dict(self, **kwargs):
args = {}
if 'omit_none' in kwargs:
args['omit_none'] = kwargs['omit_none']
if "omit_none" in kwargs:
args["omit_none"] = kwargs["omit_none"]
return self.to_dict(**args)
class ValidatedStringMixin(str, SerializableType):
ValidationRegex = ''
ValidationRegex = ""
@classmethod
def _deserialize(cls, value: str) -> 'ValidatedStringMixin':
def _deserialize(cls, value: str) -> "ValidatedStringMixin":
cls.validate(value)
return ValidatedStringMixin(value)

View File

@@ -14,39 +14,31 @@ class DBTDeprecation:
def name(self) -> str:
if self._name is not None:
return self._name
raise NotImplementedError(
'name not implemented for {}'.format(self)
)
raise NotImplementedError("name not implemented for {}".format(self))
def track_deprecation_warn(self) -> None:
if dbt.tracking.active_user is not None:
dbt.tracking.track_deprecation_warn({
"deprecation_name": self.name
})
dbt.tracking.track_deprecation_warn({"deprecation_name": self.name})
@property
def description(self) -> str:
if self._description is not None:
return self._description
raise NotImplementedError(
'description not implemented for {}'.format(self)
)
raise NotImplementedError("description not implemented for {}".format(self))
def show(self, *args, **kwargs) -> None:
if self.name not in active_deprecations:
desc = self.description.format(**kwargs)
msg = ui.line_wrap_message(
desc, prefix='* Deprecation Warning: '
)
msg = ui.line_wrap_message(desc, prefix="* Deprecation Warning: ")
dbt.exceptions.warn_or_error(msg)
self.track_deprecation_warn()
active_deprecations.add(self.name)
class MaterializationReturnDeprecation(DBTDeprecation):
_name = 'materialization-return'
_name = "materialization-return"
_description = '''\
_description = """\
The materialization ("{materialization}") did not explicitly return a list
of relations to add to the cache. By default the target relation will be
added, but this behavior will be removed in a future version of dbt.
@@ -56,22 +48,22 @@ class MaterializationReturnDeprecation(DBTDeprecation):
For more information, see:
https://docs.getdbt.com/v0.15/docs/creating-new-materializations#section-6-returning-relations
'''
"""
class NotADictionaryDeprecation(DBTDeprecation):
_name = 'not-a-dictionary'
_name = "not-a-dictionary"
_description = '''\
_description = """\
The object ("{obj}") was used as a dictionary. In a future version of dbt
this capability will be removed from objects of this type.
'''
"""
class ColumnQuotingDeprecation(DBTDeprecation):
_name = 'column-quoting-unset'
_name = "column-quoting-unset"
_description = '''\
_description = """\
The quote_columns parameter was not set for seeds, so the default value of
False was chosen. The default will change to True in a future release.
@@ -80,13 +72,13 @@ class ColumnQuotingDeprecation(DBTDeprecation):
For more information, see:
https://docs.getdbt.com/v0.15/docs/seeds#section-specify-column-quoting
'''
"""
class ModelsKeyNonModelDeprecation(DBTDeprecation):
_name = 'models-key-mismatch'
_name = "models-key-mismatch"
_description = '''\
_description = """\
"{node.name}" is a {node.resource_type} node, but it is specified in
the {patch.yaml_key} section of {patch.original_file_path}.
@@ -96,25 +88,25 @@ class ModelsKeyNonModelDeprecation(DBTDeprecation):
the {expected_key} key instead.
This warning will become an error in a future release.
'''
"""
class ExecuteMacrosReleaseDeprecation(DBTDeprecation):
_name = 'execute-macro-release'
_description = '''\
_name = "execute-macro-release"
_description = """\
The "release" argument to execute_macro is now ignored, and will be removed
in a future release of dbt. At that time, providing a `release` argument
will result in an error.
'''
"""
class AdapterMacroDeprecation(DBTDeprecation):
_name = 'adapter-macro'
_description = '''\
_name = "adapter-macro"
_description = """\
The "adapter_macro" macro has been deprecated. Instead, use the
`adapter.dispatch` method to find a macro and call the result.
adapter_macro was called for: {macro_name}
'''
"""
_adapter_renamed_description = """\
@@ -128,11 +120,11 @@ Documentation for {new_name} can be found here:
def renamed_method(old_name: str, new_name: str):
class AdapterDeprecationWarning(DBTDeprecation):
_name = 'adapter:{}'.format(old_name)
_description = _adapter_renamed_description.format(old_name=old_name,
new_name=new_name)
_name = "adapter:{}".format(old_name)
_description = _adapter_renamed_description.format(
old_name=old_name, new_name=new_name
)
dep = AdapterDeprecationWarning()
deprecations_list.append(dep)
@@ -142,9 +134,7 @@ def renamed_method(old_name: str, new_name: str):
def warn(name, *args, **kwargs):
if name not in deprecations:
# this should (hopefully) never happen
raise RuntimeError(
"Error showing deprecation warning: {}".format(name)
)
raise RuntimeError("Error showing deprecation warning: {}".format(name))
deprecations[name].show(*args, **kwargs)
@@ -163,9 +153,7 @@ deprecations_list: List[DBTDeprecation] = [
AdapterMacroDeprecation(),
]
deprecations: Dict[str, DBTDeprecation] = {
d.name: d for d in deprecations_list
}
deprecations: Dict[str, DBTDeprecation] = {d.name: d for d in deprecations_list}
def reset_deprecations():

View File

@@ -22,12 +22,12 @@ def downloads_directory():
# the user might have set an environment variable. Set it to that, and do
# not remove it when finished.
if DOWNLOADS_PATH is None:
DOWNLOADS_PATH = os.getenv('DBT_DOWNLOADS_DIR')
DOWNLOADS_PATH = os.getenv("DBT_DOWNLOADS_DIR")
remove_downloads = False
# if we are making a per-run temp directory, remove it at the end of
# successful runs
if DOWNLOADS_PATH is None:
DOWNLOADS_PATH = tempfile.mkdtemp(prefix='dbt-downloads-')
DOWNLOADS_PATH = tempfile.mkdtemp(prefix="dbt-downloads-")
remove_downloads = True
system.make_directory(DOWNLOADS_PATH)
@@ -62,7 +62,7 @@ class PinnedPackage(BasePackage):
if not version:
return self.name
return '{}@{}'.format(self.name, version)
return "{}@{}".format(self.name, version)
@abc.abstractmethod
def get_version(self) -> Optional[str]:
@@ -94,8 +94,8 @@ class PinnedPackage(BasePackage):
return os.path.join(project.modules_path, dest_dirname)
SomePinned = TypeVar('SomePinned', bound=PinnedPackage)
SomeUnpinned = TypeVar('SomeUnpinned', bound='UnpinnedPackage')
SomePinned = TypeVar("SomePinned", bound=PinnedPackage)
SomeUnpinned = TypeVar("SomeUnpinned", bound="UnpinnedPackage")
class UnpinnedPackage(Generic[SomePinned], BasePackage):

View File

@@ -8,18 +8,16 @@ from dbt.contracts.project import (
ProjectPackageMetadata,
GitPackage,
)
from dbt.deps.base import PinnedPackage, UnpinnedPackage, get_downloads_path
from dbt.exceptions import (
ExecutableError, warn_or_error, raise_dependency_error
)
from dbt.deps import PinnedPackage, UnpinnedPackage, get_downloads_path
from dbt.exceptions import ExecutableError, warn_or_error, raise_dependency_error
from dbt.logger import GLOBAL_LOGGER as logger
from dbt import ui
PIN_PACKAGE_URL = 'https://docs.getdbt.com/docs/package-management#section-specifying-package-versions' # noqa
PIN_PACKAGE_URL = "https://docs.getdbt.com/docs/package-management#section-specifying-package-versions" # noqa
def md5sum(s: str):
return hashlib.md5(s.encode('latin-1')).hexdigest()
return hashlib.md5(s.encode("latin-1")).hexdigest()
class GitPackageMixin:
@@ -32,13 +30,11 @@ class GitPackageMixin:
return self.git
def source_type(self) -> str:
return 'git'
return "git"
class GitPinnedPackage(GitPackageMixin, PinnedPackage):
def __init__(
self, git: str, revision: str, warn_unpinned: bool = True
) -> None:
def __init__(self, git: str, revision: str, warn_unpinned: bool = True) -> None:
super().__init__(git)
self.revision = revision
self.warn_unpinned = warn_unpinned
@@ -48,15 +44,15 @@ class GitPinnedPackage(GitPackageMixin, PinnedPackage):
return self.revision
def nice_version_name(self):
if self.revision == 'HEAD':
return 'HEAD (default branch)'
if self.revision == "HEAD":
return "HEAD (default branch)"
else:
return 'revision {}'.format(self.revision)
return "revision {}".format(self.revision)
def unpinned_msg(self):
if self.revision == 'HEAD':
return 'not pinned, using HEAD (default branch)'
elif self.revision in ('main', 'master'):
if self.revision == "HEAD":
return "not pinned, using HEAD (default branch)"
elif self.revision in ("main", "master"):
return f'pinned to the "{self.revision}" branch'
else:
return None
@@ -68,15 +64,17 @@ class GitPinnedPackage(GitPackageMixin, PinnedPackage):
the path to the checked out directory."""
try:
dir_ = git.clone_and_checkout(
self.git, get_downloads_path(), branch=self.revision,
dirname=self._checkout_name
self.git,
get_downloads_path(),
branch=self.revision,
dirname=self._checkout_name,
)
except ExecutableError as exc:
if exc.cmd and exc.cmd[0] == 'git':
if exc.cmd and exc.cmd[0] == "git":
logger.error(
'Make sure git is installed on your machine. More '
'information: '
'https://docs.getdbt.com/docs/package-management'
"Make sure git is installed on your machine. More "
"information: "
"https://docs.getdbt.com/docs/package-management"
)
raise
return os.path.join(get_downloads_path(), dir_)
@@ -87,9 +85,10 @@ class GitPinnedPackage(GitPackageMixin, PinnedPackage):
if self.unpinned_msg() and self.warn_unpinned:
warn_or_error(
'The git package "{}" \n\tis {}.\n\tThis can introduce '
'breaking changes into your project without warning!\n\nSee {}'
.format(self.git, self.unpinned_msg(), PIN_PACKAGE_URL),
log_fmt=ui.yellow('WARNING: {}')
"breaking changes into your project without warning!\n\nSee {}".format(
self.git, self.unpinned_msg(), PIN_PACKAGE_URL
),
log_fmt=ui.yellow("WARNING: {}"),
)
loaded = Project.from_project_root(path, renderer)
return ProjectPackageMetadata.from_project(loaded)
@@ -114,26 +113,21 @@ class GitUnpinnedPackage(GitPackageMixin, UnpinnedPackage[GitPinnedPackage]):
self.warn_unpinned = warn_unpinned
@classmethod
def from_contract(
cls, contract: GitPackage
) -> 'GitUnpinnedPackage':
def from_contract(cls, contract: GitPackage) -> "GitUnpinnedPackage":
revisions = contract.get_revisions()
# we want to map None -> True
warn_unpinned = contract.warn_unpinned is not False
return cls(git=contract.git, revisions=revisions,
warn_unpinned=warn_unpinned)
return cls(git=contract.git, revisions=revisions, warn_unpinned=warn_unpinned)
def all_names(self) -> List[str]:
if self.git.endswith('.git'):
if self.git.endswith(".git"):
other = self.git[:-4]
else:
other = self.git + '.git'
other = self.git + ".git"
return [self.git, other]
def incorporate(
self, other: 'GitUnpinnedPackage'
) -> 'GitUnpinnedPackage':
def incorporate(self, other: "GitUnpinnedPackage") -> "GitUnpinnedPackage":
warn_unpinned = self.warn_unpinned and other.warn_unpinned
return GitUnpinnedPackage(
@@ -145,13 +139,13 @@ class GitUnpinnedPackage(GitPackageMixin, UnpinnedPackage[GitPinnedPackage]):
def resolved(self) -> GitPinnedPackage:
requested = set(self.revisions)
if len(requested) == 0:
requested = {'HEAD'}
requested = {"HEAD"}
elif len(requested) > 1:
raise_dependency_error(
'git dependencies should contain exactly one version. '
'{} contains: {}'.format(self.git, requested))
"git dependencies should contain exactly one version. "
"{} contains: {}".format(self.git, requested)
)
return GitPinnedPackage(
git=self.git, revision=requested.pop(),
warn_unpinned=self.warn_unpinned
git=self.git, revision=requested.pop(), warn_unpinned=self.warn_unpinned
)

View File

@@ -1,7 +1,7 @@
import shutil
from dbt.clients import system
from dbt.deps.base import PinnedPackage, UnpinnedPackage
from dbt.deps import PinnedPackage, UnpinnedPackage
from dbt.contracts.project import (
ProjectPackageMetadata,
LocalPackage,
@@ -19,7 +19,7 @@ class LocalPackageMixin:
return self.local
def source_type(self):
return 'local'
return "local"
class LocalPinnedPackage(LocalPackageMixin, PinnedPackage):
@@ -30,7 +30,7 @@ class LocalPinnedPackage(LocalPackageMixin, PinnedPackage):
return None
def nice_version_name(self):
return '<local @ {}>'.format(self.local)
return "<local @ {}>".format(self.local)
def resolve_path(self, project):
return system.resolve_path_from_base(
@@ -39,9 +39,7 @@ class LocalPinnedPackage(LocalPackageMixin, PinnedPackage):
)
def _fetch_metadata(self, project, renderer):
loaded = project.from_project_root(
self.resolve_path(project), renderer
)
loaded = project.from_project_root(self.resolve_path(project), renderer)
return ProjectPackageMetadata.from_project(loaded)
def install(self, project, renderer):
@@ -57,27 +55,22 @@ class LocalPinnedPackage(LocalPackageMixin, PinnedPackage):
system.remove_file(dest_path)
if can_create_symlink:
logger.debug(' Creating symlink to local dependency.')
logger.debug(" Creating symlink to local dependency.")
system.make_symlink(src_path, dest_path)
else:
logger.debug(' Symlinks are not available on this '
'OS, copying dependency.')
logger.debug(
" Symlinks are not available on this " "OS, copying dependency."
)
shutil.copytree(src_path, dest_path)
class LocalUnpinnedPackage(
LocalPackageMixin, UnpinnedPackage[LocalPinnedPackage]
):
class LocalUnpinnedPackage(LocalPackageMixin, UnpinnedPackage[LocalPinnedPackage]):
@classmethod
def from_contract(
cls, contract: LocalPackage
) -> 'LocalUnpinnedPackage':
def from_contract(cls, contract: LocalPackage) -> "LocalUnpinnedPackage":
return cls(local=contract.local)
def incorporate(
self, other: 'LocalUnpinnedPackage'
) -> 'LocalUnpinnedPackage':
def incorporate(self, other: "LocalUnpinnedPackage") -> "LocalUnpinnedPackage":
return LocalUnpinnedPackage(local=self.local)
def resolved(self) -> LocalPinnedPackage:

View File

@@ -7,7 +7,7 @@ from dbt.contracts.project import (
RegistryPackageMetadata,
RegistryPackage,
)
from dbt.deps.base import PinnedPackage, UnpinnedPackage, get_downloads_path
from dbt.deps import PinnedPackage, UnpinnedPackage, get_downloads_path
from dbt.exceptions import (
package_version_not_found,
VersionsNotCompatibleException,
@@ -26,7 +26,7 @@ class RegistryPackageMixin:
return self.package
def source_type(self) -> str:
return 'hub'
return "hub"
class RegistryPinnedPackage(RegistryPackageMixin, PinnedPackage):
@@ -39,13 +39,13 @@ class RegistryPinnedPackage(RegistryPackageMixin, PinnedPackage):
return self.package
def source_type(self):
return 'hub'
return "hub"
def get_version(self):
return self.version
def nice_version_name(self):
return 'version {}'.format(self.version)
return "version {}".format(self.version)
def _fetch_metadata(self, project, renderer) -> RegistryPackageMetadata:
dct = registry.package_version(self.package, self.version)
@@ -54,10 +54,8 @@ class RegistryPinnedPackage(RegistryPackageMixin, PinnedPackage):
def install(self, project, renderer):
metadata = self.fetch_metadata(project, renderer)
tar_name = '{}.{}.tar.gz'.format(self.package, self.version)
tar_path = os.path.realpath(
os.path.join(get_downloads_path(), tar_name)
)
tar_name = "{}.{}.tar.gz".format(self.package, self.version)
tar_path = os.path.realpath(os.path.join(get_downloads_path(), tar_name))
system.make_directory(os.path.dirname(tar_path))
download_url = metadata.downloads.tarball
@@ -70,9 +68,7 @@ class RegistryPinnedPackage(RegistryPackageMixin, PinnedPackage):
class RegistryUnpinnedPackage(
RegistryPackageMixin, UnpinnedPackage[RegistryPinnedPackage]
):
def __init__(
self, package: str, versions: List[semver.VersionSpecifier]
) -> None:
def __init__(self, package: str, versions: List[semver.VersionSpecifier]) -> None:
super().__init__(package)
self.versions = versions
@@ -82,20 +78,15 @@ class RegistryUnpinnedPackage(
package_not_found(self.package)
@classmethod
def from_contract(
cls, contract: RegistryPackage
) -> 'RegistryUnpinnedPackage':
def from_contract(cls, contract: RegistryPackage) -> "RegistryUnpinnedPackage":
raw_version = contract.get_versions()
versions = [
semver.VersionSpecifier.from_version_string(v)
for v in raw_version
]
versions = [semver.VersionSpecifier.from_version_string(v) for v in raw_version]
return cls(package=contract.package, versions=versions)
def incorporate(
self, other: 'RegistryUnpinnedPackage'
) -> 'RegistryUnpinnedPackage':
self, other: "RegistryUnpinnedPackage"
) -> "RegistryUnpinnedPackage":
return RegistryUnpinnedPackage(
package=self.package,
versions=self.versions + other.versions,
@@ -106,8 +97,7 @@ class RegistryUnpinnedPackage(
try:
range_ = semver.reduce_versions(*self.versions)
except VersionsNotCompatibleException as e:
new_msg = ('Version error for package {}: {}'
.format(self.name, e))
new_msg = "Version error for package {}: {}".format(self.name, e)
raise DependencyException(new_msg) from e
available = registry.get_available_versions(self.package)

View File

@@ -6,7 +6,7 @@ from dbt.exceptions import raise_dependency_error, InternalException
from dbt.context.target import generate_target_context
from dbt.config import Project, RuntimeConfig
from dbt.config.renderer import DbtProjectYamlRenderer
from dbt.deps.base import BasePackage, PinnedPackage, UnpinnedPackage
from dbt.deps import BasePackage, PinnedPackage, UnpinnedPackage
from dbt.deps.local import LocalUnpinnedPackage
from dbt.deps.git import GitUnpinnedPackage
from dbt.deps.registry import RegistryUnpinnedPackage
@@ -49,12 +49,10 @@ class PackageListing:
key_str: str = self._pick_key(key)
self.packages[key_str] = value
def _mismatched_types(
self, old: UnpinnedPackage, new: UnpinnedPackage
) -> NoReturn:
def _mismatched_types(self, old: UnpinnedPackage, new: UnpinnedPackage) -> NoReturn:
raise_dependency_error(
f'Cannot incorporate {new} ({new.__class__.__name__}) in {old} '
f'({old.__class__.__name__}): mismatched types'
f"Cannot incorporate {new} ({new.__class__.__name__}) in {old} "
f"({old.__class__.__name__}): mismatched types"
)
def incorporate(self, package: UnpinnedPackage):
@@ -78,14 +76,14 @@ class PackageListing:
pkg = RegistryUnpinnedPackage.from_contract(contract)
else:
raise InternalException(
'Invalid package type {}'.format(type(contract))
"Invalid package type {}".format(type(contract))
)
self.incorporate(pkg)
@classmethod
def from_contracts(
cls: Type['PackageListing'], src: List[PackageContract]
) -> 'PackageListing':
cls: Type["PackageListing"], src: List[PackageContract]
) -> "PackageListing":
self = cls({})
self.update_from(src)
return self
@@ -108,14 +106,14 @@ def _check_for_duplicate_project_names(
if project_name in seen:
raise_dependency_error(
f'Found duplicate project "{project_name}". This occurs when '
'a dependency has the same project name as some other '
'dependency.'
"a dependency has the same project name as some other "
"dependency."
)
elif project_name == config.project_name:
raise_dependency_error(
'Found a dependency with the same name as the root project '
"Found a dependency with the same name as the root project "
f'"{project_name}". Package names must be unique in a project.'
' Please rename one of these packages.'
" Please rename one of these packages."
)
seen.add(project_name)

File diff suppressed because it is too large

View File

@@ -1,6 +1,7 @@
import os
import multiprocessing
if os.name != 'nt':
if os.name != "nt":
# https://bugs.python.org/issue41567
import multiprocessing.popen_spawn_posix # type: ignore
from pathlib import Path
@@ -23,7 +24,7 @@ def env_set_truthy(key: str) -> Optional[str]:
otherwise.
"""
value = os.getenv(key)
if not value or value.lower() in ('0', 'false', 'f'):
if not value or value.lower() in ("0", "false", "f"):
return None
return value
@@ -36,24 +37,23 @@ def env_set_path(key: str) -> Optional[Path]:
return Path(value)
SINGLE_THREADED_WEBSERVER = env_set_truthy('DBT_SINGLE_THREADED_WEBSERVER')
SINGLE_THREADED_HANDLER = env_set_truthy('DBT_SINGLE_THREADED_HANDLER')
MACRO_DEBUGGING = env_set_truthy('DBT_MACRO_DEBUGGING')
DEFER_MODE = env_set_truthy('DBT_DEFER_TO_STATE')
ARTIFACT_STATE_PATH = env_set_path('DBT_ARTIFACT_STATE_PATH')
SINGLE_THREADED_WEBSERVER = env_set_truthy("DBT_SINGLE_THREADED_WEBSERVER")
SINGLE_THREADED_HANDLER = env_set_truthy("DBT_SINGLE_THREADED_HANDLER")
MACRO_DEBUGGING = env_set_truthy("DBT_MACRO_DEBUGGING")
DEFER_MODE = env_set_truthy("DBT_DEFER_TO_STATE")
ARTIFACT_STATE_PATH = env_set_path("DBT_ARTIFACT_STATE_PATH")
def _get_context():
# TODO: change this back to use fork() on linux when we have made that safe
return multiprocessing.get_context('spawn')
return multiprocessing.get_context("spawn")
MP_CONTEXT = _get_context()
def reset():
global STRICT_MODE, FULL_REFRESH, USE_CACHE, WARN_ERROR, TEST_NEW_PARSER, \
WRITE_JSON, PARTIAL_PARSE, MP_CONTEXT, USE_COLORS
global STRICT_MODE, FULL_REFRESH, USE_CACHE, WARN_ERROR, TEST_NEW_PARSER, WRITE_JSON, PARTIAL_PARSE, MP_CONTEXT, USE_COLORS
STRICT_MODE = False
FULL_REFRESH = False
@@ -67,26 +67,22 @@ def reset():
def set_from_args(args):
global STRICT_MODE, FULL_REFRESH, USE_CACHE, WARN_ERROR, TEST_NEW_PARSER, \
WRITE_JSON, PARTIAL_PARSE, MP_CONTEXT, USE_COLORS
global STRICT_MODE, FULL_REFRESH, USE_CACHE, WARN_ERROR, TEST_NEW_PARSER, WRITE_JSON, PARTIAL_PARSE, MP_CONTEXT, USE_COLORS
USE_CACHE = getattr(args, 'use_cache', USE_CACHE)
USE_CACHE = getattr(args, "use_cache", USE_CACHE)
FULL_REFRESH = getattr(args, 'full_refresh', FULL_REFRESH)
STRICT_MODE = getattr(args, 'strict', STRICT_MODE)
WARN_ERROR = (
STRICT_MODE or
getattr(args, 'warn_error', STRICT_MODE or WARN_ERROR)
)
FULL_REFRESH = getattr(args, "full_refresh", FULL_REFRESH)
STRICT_MODE = getattr(args, "strict", STRICT_MODE)
WARN_ERROR = STRICT_MODE or getattr(args, "warn_error", STRICT_MODE or WARN_ERROR)
TEST_NEW_PARSER = getattr(args, 'test_new_parser', TEST_NEW_PARSER)
WRITE_JSON = getattr(args, 'write_json', WRITE_JSON)
PARTIAL_PARSE = getattr(args, 'partial_parse', None)
TEST_NEW_PARSER = getattr(args, "test_new_parser", TEST_NEW_PARSER)
WRITE_JSON = getattr(args, "write_json", WRITE_JSON)
PARTIAL_PARSE = getattr(args, "partial_parse", None)
MP_CONTEXT = _get_context()
# The use_colors attribute will always have a value because it is assigned
# None by default from the add_mutually_exclusive_group function
use_colors_override = getattr(args, 'use_colors')
use_colors_override = getattr(args, "use_colors")
if use_colors_override is not None:
USE_COLORS = use_colors_override
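
For context on the flag handling above: env_set_truthy treats an empty value, "0", "false", or "f" (any case) as unset and returns None, and anything else is returned verbatim. A minimal standalone sketch of that behavior, with illustrative assertions that are not part of the diff:

```python
import os
from typing import Optional


def env_set_truthy(key: str) -> Optional[str]:
    # copied from the hunk above: empty, "0", "false", and "f" all count as unset
    value = os.getenv(key)
    if not value or value.lower() in ("0", "false", "f"):
        return None
    return value


os.environ["DBT_MACRO_DEBUGGING"] = "False"
assert env_set_truthy("DBT_MACRO_DEBUGGING") is None
os.environ["DBT_MACRO_DEBUGGING"] = "1"
assert env_set_truthy("DBT_MACRO_DEBUGGING") == "1"
```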

View File

@@ -2,9 +2,7 @@
import itertools
from dbt.clients.yaml_helper import yaml, Loader, Dumper # noqa: F401
from typing import (
Dict, List, Optional, Tuple, Any, Union
)
from typing import Dict, List, Optional, Tuple, Any, Union
from dbt.contracts.selection import SelectorDefinition, SelectorFile
from dbt.exceptions import InternalException, ValidationException
@@ -17,21 +15,17 @@ from .selector_spec import (
SelectionCriteria,
)
INTERSECTION_DELIMITER = ','
INTERSECTION_DELIMITER = ","
DEFAULT_INCLUDES: List[str] = ['fqn:*', 'source:*', 'exposure:*']
DEFAULT_INCLUDES: List[str] = ["fqn:*", "source:*", "exposure:*"]
DEFAULT_EXCLUDES: List[str] = []
DATA_TEST_SELECTOR: str = 'test_type:data'
SCHEMA_TEST_SELECTOR: str = 'test_type:schema'
DATA_TEST_SELECTOR: str = "test_type:data"
SCHEMA_TEST_SELECTOR: str = "test_type:schema"
def parse_union(
components: List[str], expect_exists: bool
) -> SelectionUnion:
def parse_union(components: List[str], expect_exists: bool) -> SelectionUnion:
# turn ['a b', 'c'] -> ['a', 'b', 'c']
raw_specs = itertools.chain.from_iterable(
r.split(' ') for r in components
)
raw_specs = itertools.chain.from_iterable(r.split(" ") for r in components)
union_components: List[SelectionSpec] = []
# ['a', 'b', 'c,d'] -> union('a', 'b', intersection('c', 'd'))
@@ -40,11 +34,13 @@ def parse_union(
SelectionCriteria.from_single_spec(part)
for part in raw_spec.split(INTERSECTION_DELIMITER)
]
union_components.append(SelectionIntersection(
components=intersection_components,
expect_exists=expect_exists,
raw=raw_spec,
))
union_components.append(
SelectionIntersection(
components=intersection_components,
expect_exists=expect_exists,
raw=raw_spec,
)
)
return SelectionUnion(
components=union_components,
@@ -78,9 +74,7 @@ def parse_test_selectors(
union_components = []
if data:
union_components.append(
SelectionCriteria.from_single_spec(DATA_TEST_SELECTOR)
)
union_components.append(SelectionCriteria.from_single_spec(DATA_TEST_SELECTOR))
if schema:
union_components.append(
SelectionCriteria.from_single_spec(SCHEMA_TEST_SELECTOR)
@@ -98,27 +92,21 @@ def parse_test_selectors(
raw=[DATA_TEST_SELECTOR, SCHEMA_TEST_SELECTOR],
)
return SelectionIntersection(
components=[base, intersect_with], expect_exists=True
)
return SelectionIntersection(components=[base, intersect_with], expect_exists=True)
RawDefinition = Union[str, Dict[str, Any]]
def _get_list_dicts(
dct: Dict[str, Any], key: str
) -> List[RawDefinition]:
def _get_list_dicts(dct: Dict[str, Any], key: str) -> List[RawDefinition]:
result: List[RawDefinition] = []
if key not in dct:
raise InternalException(
f'Expected to find key {key} in dict, only found {list(dct)}'
f"Expected to find key {key} in dict, only found {list(dct)}"
)
values = dct[key]
if not isinstance(values, list):
raise ValidationException(
f'Invalid value for key "{key}". Expected a list.'
)
raise ValidationException(f'Invalid value for key "{key}". Expected a list.')
for value in values:
if isinstance(value, dict):
for value_key in value:
@@ -133,36 +121,31 @@ def _get_list_dicts(
else:
raise ValidationException(
f'Invalid value type {type(value)} in key "{key}", expected '
f'dict or str (value: {value}).'
f"dict or str (value: {value})."
)
return result
def _parse_exclusions(definition) -> Optional[SelectionSpec]:
exclusions = _get_list_dicts(definition, 'exclude')
parsed_exclusions = [
parse_from_definition(excl) for excl in exclusions
]
exclusions = _get_list_dicts(definition, "exclude")
parsed_exclusions = [parse_from_definition(excl) for excl in exclusions]
if len(parsed_exclusions) == 1:
return parsed_exclusions[0]
elif len(parsed_exclusions) > 1:
return SelectionUnion(
components=parsed_exclusions,
raw=exclusions
)
return SelectionUnion(components=parsed_exclusions, raw=exclusions)
else:
return None
def _parse_include_exclude_subdefs(
definitions: List[RawDefinition]
definitions: List[RawDefinition],
) -> Tuple[List[SelectionSpec], Optional[SelectionSpec]]:
include_parts: List[SelectionSpec] = []
diff_arg: Optional[SelectionSpec] = None
for definition in definitions:
if isinstance(definition, dict) and 'exclude' in definition:
if isinstance(definition, dict) and "exclude" in definition:
# do not allow multiple exclude: defs at the same level
if diff_arg is not None:
yaml_sel_cfg = yaml.dump(definition)
@@ -178,7 +161,7 @@ def _parse_include_exclude_subdefs(
def parse_union_definition(definition: Dict[str, Any]) -> SelectionSpec:
union_def_parts = _get_list_dicts(definition, 'union')
union_def_parts = _get_list_dicts(definition, "union")
include, exclude = _parse_include_exclude_subdefs(union_def_parts)
union = SelectionUnion(components=include)
@@ -187,16 +170,11 @@ def parse_union_definition(definition: Dict[str, Any]) -> SelectionSpec:
union.raw = definition
return union
else:
return SelectionDifference(
components=[union, exclude],
raw=definition
)
return SelectionDifference(components=[union, exclude], raw=definition)
def parse_intersection_definition(
definition: Dict[str, Any]
) -> SelectionSpec:
intersection_def_parts = _get_list_dicts(definition, 'intersection')
def parse_intersection_definition(definition: Dict[str, Any]) -> SelectionSpec:
intersection_def_parts = _get_list_dicts(definition, "intersection")
include, exclude = _parse_include_exclude_subdefs(intersection_def_parts)
intersection = SelectionIntersection(components=include)
@@ -204,10 +182,7 @@ def parse_intersection_definition(
intersection.raw = definition
return intersection
else:
return SelectionDifference(
components=[intersection, exclude],
raw=definition
)
return SelectionDifference(components=[intersection, exclude], raw=definition)
def parse_dict_definition(definition: Dict[str, Any]) -> SelectionSpec:
@@ -221,14 +196,14 @@ def parse_dict_definition(definition: Dict[str, Any]) -> SelectionSpec:
f'"{type(key)}" ({key})'
)
dct = {
'method': key,
'value': value,
"method": key,
"value": value,
}
elif 'method' in definition and 'value' in definition:
elif "method" in definition and "value" in definition:
dct = definition
if 'exclude' in definition:
if "exclude" in definition:
diff_arg = _parse_exclusions(definition)
dct = {k: v for k, v in dct.items() if k != 'exclude'}
dct = {k: v for k, v in dct.items() if k != "exclude"}
else:
raise ValidationException(
f'Expected either 1 key or else "method" '
@@ -243,13 +218,14 @@ def parse_dict_definition(definition: Dict[str, Any]) -> SelectionSpec:
return SelectionDifference(components=[base, diff_arg])
def parse_from_definition(
definition: RawDefinition, rootlevel=False
) -> SelectionSpec:
def parse_from_definition(definition: RawDefinition, rootlevel=False) -> SelectionSpec:
if (isinstance(definition, dict) and
('union' in definition or 'intersection' in definition) and
rootlevel and len(definition) > 1):
if (
isinstance(definition, dict)
and ("union" in definition or "intersection" in definition)
and rootlevel
and len(definition) > 1
):
keys = ",".join(definition.keys())
raise ValidationException(
f"Only a single 'union' or 'intersection' key is allowed "
@@ -257,25 +233,24 @@ def parse_from_definition(
)
if isinstance(definition, str):
return SelectionCriteria.from_single_spec(definition)
elif 'union' in definition:
elif "union" in definition:
return parse_union_definition(definition)
elif 'intersection' in definition:
elif "intersection" in definition:
return parse_intersection_definition(definition)
elif isinstance(definition, dict):
return parse_dict_definition(definition)
else:
raise ValidationException(
f'Expected to find union, intersection, str or dict, instead '
f'found {type(definition)}: {definition}'
f"Expected to find union, intersection, str or dict, instead "
f"found {type(definition)}: {definition}"
)
def parse_from_selectors_definition(
source: SelectorFile
) -> Dict[str, SelectionSpec]:
def parse_from_selectors_definition(source: SelectorFile) -> Dict[str, SelectionSpec]:
result: Dict[str, SelectionSpec] = {}
selector: SelectorDefinition
for selector in source.selectors:
result[selector.name] = parse_from_definition(selector.definition,
rootlevel=True)
result[selector.name] = parse_from_definition(
selector.definition, rootlevel=True
)
return result
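
The grammar parse_union handles above is compact but easy to miss: space-separated components are unioned, and a comma inside a component intersects its parts (per the inline comments in the hunk). A toy, dbt-free sketch of just that splitting step; the function name and sample selectors here are illustrative:

```python
import itertools


def split_specs(components):
    # ["a b", "c,d"] -> [("a",), ("b",), ("c", "d")]
    # spaces separate union members; commas inside a member mean intersection
    raw_specs = itertools.chain.from_iterable(r.split(" ") for r in components)
    return [tuple(raw_spec.split(",")) for raw_spec in raw_specs]


print(split_specs(["tag:nightly my_model", "fqn:foo,tag:prod"]))
# [('tag:nightly',), ('my_model',), ('fqn:foo', 'tag:prod')]
```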

View File

@@ -1,17 +1,16 @@
from typing import (
Set, Iterable, Iterator, Optional, NewType
)
from typing import Set, Iterable, Iterator, Optional, NewType
import networkx as nx # type: ignore
from dbt.exceptions import InternalException
UniqueId = NewType('UniqueId', str)
UniqueId = NewType("UniqueId", str)
class Graph:
"""A wrapper around the networkx graph that understands SelectionCriteria
and how they interact with the graph.
"""
def __init__(self, graph):
self.graph = graph
@@ -29,12 +28,11 @@ class Graph:
) -> Set[UniqueId]:
"""Returns all nodes having a path to `node` in `graph`"""
if not self.graph.has_node(node):
raise InternalException(f'Node {node} not found in the graph!')
raise InternalException(f"Node {node} not found in the graph!")
with nx.utils.reversed(self.graph):
anc = nx.single_source_shortest_path_length(G=self.graph,
source=node,
cutoff=max_depth)\
.keys()
anc = nx.single_source_shortest_path_length(
G=self.graph, source=node, cutoff=max_depth
).keys()
return anc - {node}
def descendants(
@@ -42,16 +40,13 @@ class Graph:
) -> Set[UniqueId]:
"""Returns all nodes reachable from `node` in `graph`"""
if not self.graph.has_node(node):
raise InternalException(f'Node {node} not found in the graph!')
des = nx.single_source_shortest_path_length(G=self.graph,
source=node,
cutoff=max_depth)\
.keys()
raise InternalException(f"Node {node} not found in the graph!")
des = nx.single_source_shortest_path_length(
G=self.graph, source=node, cutoff=max_depth
).keys()
return des - {node}
def select_childrens_parents(
self, selected: Set[UniqueId]
) -> Set[UniqueId]:
def select_childrens_parents(self, selected: Set[UniqueId]) -> Set[UniqueId]:
ancestors_for = self.select_children(selected) | selected
return self.select_parents(ancestors_for) | ancestors_for
@@ -77,7 +72,7 @@ class Graph:
successors.update(self.graph.successors(node))
return successors
def get_subset_graph(self, selected: Iterable[UniqueId]) -> 'Graph':
def get_subset_graph(self, selected: Iterable[UniqueId]) -> "Graph":
"""Create and return a new graph that is a shallow copy of the graph,
but with only the nodes in include_nodes. Transitive edges across
removed nodes are preserved as explicit new edges.
@@ -98,7 +93,7 @@ class Graph:
)
return Graph(new_graph)
def subgraph(self, nodes: Iterable[UniqueId]) -> 'Graph':
def subgraph(self, nodes: Iterable[UniqueId]) -> "Graph":
return Graph(self.graph.subgraph(nodes))
def get_dependent_nodes(self, node: UniqueId):
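
The ancestors/descendants methods above rely on networkx BFS path lengths with an optional depth cutoff, then drop the starting node itself. A self-contained illustration of the same idea on a toy graph; it uses DiGraph.reverse() rather than the nx.utils.reversed context manager shown in the diff:

```python
import networkx as nx

g = nx.DiGraph()
g.add_edges_from([("a", "b"), ("b", "c"), ("c", "d")])


def descendants(graph, node, max_depth=None):
    # nodes reachable from `node`, optionally limited to max_depth hops
    reached = nx.single_source_shortest_path_length(G=graph, source=node, cutoff=max_depth)
    return set(reached) - {node}


def ancestors(graph, node, max_depth=None):
    # the same search on the reversed graph yields the nodes that can reach `node`
    return descendants(graph.reverse(copy=False), node, max_depth)


print(descendants(g, "a", max_depth=2))  # {'b', 'c'}
print(ancestors(g, "d"))                 # {'a', 'b', 'c'}
```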

View File

@@ -1,8 +1,6 @@
import threading
from queue import PriorityQueue
from typing import (
Dict, Set, Optional
)
from typing import Dict, Set, Optional
import networkx as nx # type: ignore
@@ -21,9 +19,8 @@ class GraphQueue:
that separate threads do not call `.empty()` or `__len__()` and `.get()` at
the same time, as there is an unlocked race!
"""
def __init__(
self, graph: nx.DiGraph, manifest: Manifest, selected: Set[UniqueId]
):
def __init__(self, graph: nx.DiGraph, manifest: Manifest, selected: Set[UniqueId]):
self.graph = graph
self.manifest = manifest
self._selected = selected
@@ -75,10 +72,13 @@ class GraphQueue:
"""
scores = {}
for node in self.graph.nodes():
score = -1 * len([
d for d in nx.descendants(self.graph, node)
if self._include_in_cost(d)
])
score = -1 * len(
[
d
for d in nx.descendants(self.graph, node)
if self._include_in_cost(d)
]
)
scores[node] = score
return scores
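
The scoring loop above assigns each node a priority of minus its number of cost-counted descendants, so the nodes that unblock the most downstream work come out of the PriorityQueue first. A rough standalone illustration; the node names and the plain descendant count are assumptions, not dbt's exact cost filter:

```python
from queue import PriorityQueue

import networkx as nx

g = nx.DiGraph()
g.add_edges_from([("staging", "orders"), ("staging", "customers"), ("orders", "report")])

q = PriorityQueue()
for node in g.nodes():
    # more descendants -> more negative score -> pulled from the queue earlier
    score = -len(nx.descendants(g, node))
    q.put((score, node))

while not q.empty():
    print(q.get())
# (-3, 'staging') comes out first; leaf nodes like 'report' score 0 and come out last
```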

View File

@@ -1,4 +1,3 @@
from typing import Set, List, Optional
from .graph import Graph, UniqueId
@@ -25,14 +24,13 @@ def get_package_names(nodes):
def alert_non_existence(raw_spec, nodes):
if len(nodes) == 0:
warn_or_error(
f"The selection criterion '{str(raw_spec)}' does not match"
f" any nodes"
f"The selection criterion '{str(raw_spec)}' does not match" f" any nodes"
)
class NodeSelector(MethodManager):
"""The node selector is aware of the graph and manifest,
"""
"""The node selector is aware of the graph and manifest,"""
def __init__(
self,
graph: Graph,
@@ -45,13 +43,16 @@ class NodeSelector(MethodManager):
# build a subgraph containing only non-empty, enabled nodes and enabled
# sources.
graph_members = {
unique_id for unique_id in self.full_graph.nodes()
unique_id
for unique_id in self.full_graph.nodes()
if self._is_graph_member(unique_id)
}
self.graph = self.full_graph.subgraph(graph_members)
def select_included(
self, included_nodes: Set[UniqueId], spec: SelectionCriteria,
self,
included_nodes: Set[UniqueId],
spec: SelectionCriteria,
) -> Set[UniqueId]:
"""Select the explicitly included nodes, using the given spec. Return
the selected set of unique IDs.
@@ -116,10 +117,7 @@ class NodeSelector(MethodManager):
if isinstance(spec, SelectionCriteria):
result = self.get_nodes_from_criteria(spec)
else:
node_selections = [
self.select_nodes(component)
for component in spec
]
node_selections = [self.select_nodes(component) for component in spec]
result = spec.combined(node_selections)
if spec.expect_exists:
alert_non_existence(spec.raw, result)
@@ -149,18 +147,14 @@ class NodeSelector(MethodManager):
elif unique_id in self.manifest.exposures:
node = self.manifest.exposures[unique_id]
else:
raise InternalException(
f'Node {unique_id} not found in the manifest!'
)
raise InternalException(f"Node {unique_id} not found in the manifest!")
return self.node_is_match(node)
def filter_selection(self, selected: Set[UniqueId]) -> Set[UniqueId]:
"""Return the subset of selected nodes that is a match for this
selector.
"""
return {
unique_id for unique_id in selected if self._is_match(unique_id)
}
return {unique_id for unique_id in selected if self._is_match(unique_id)}
def expand_selection(self, selected: Set[UniqueId]) -> Set[UniqueId]:
"""Perform selector-specific expansion."""
@@ -169,14 +163,14 @@ class NodeSelector(MethodManager):
def get_selected(self, spec: SelectionSpec) -> Set[UniqueId]:
"""get_selected runs trhough the node selection process:
- node selection. Based on the include/exclude sets, the set
of matched unique IDs is returned
- expand the graph at each leaf node, before combination
- selectors might override this. for example, this is where
tests are added
- filtering:
- selectors can filter the nodes after all of them have been
selected
- node selection. Based on the include/exclude sets, the set
of matched unique IDs is returned
- expand the graph at each leaf node, before combination
- selectors might override this. for example, this is where
tests are added
- filtering:
- selectors can filter the nodes after all of them have been
selected
"""
selected_nodes = self.select_nodes(spec)
filtered_nodes = self.filter_selection(selected_nodes)

View File

@@ -31,28 +31,28 @@ from dbt.node_types import NodeType
from dbt.ui import warning_tag
SELECTOR_GLOB = '*'
SELECTOR_DELIMITER = ':'
SELECTOR_GLOB = "*"
SELECTOR_DELIMITER = ":"
class MethodName(StrEnum):
FQN = 'fqn'
Tag = 'tag'
Source = 'source'
Path = 'path'
Package = 'package'
Config = 'config'
TestName = 'test_name'
TestType = 'test_type'
ResourceType = 'resource_type'
State = 'state'
Exposure = 'exposure'
FQN = "fqn"
Tag = "tag"
Source = "source"
Path = "path"
Package = "package"
Config = "config"
TestName = "test_name"
TestType = "test_type"
ResourceType = "resource_type"
State = "state"
Exposure = "exposure"
def is_selected_node(real_node, node_selector):
for i, selector_part in enumerate(node_selector):
is_last = (i == len(node_selector) - 1)
is_last = i == len(node_selector) - 1
# if we hit a GLOB, then this node is selected
if selector_part == SELECTOR_GLOB:
@@ -83,15 +83,14 @@ class SelectorMethod(metaclass=abc.ABCMeta):
self,
manifest: Manifest,
previous_state: Optional[PreviousState],
arguments: List[str]
arguments: List[str],
):
self.manifest: Manifest = manifest
self.previous_state = previous_state
self.arguments: List[str] = arguments
def parsed_nodes(
self,
included_nodes: Set[UniqueId]
self, included_nodes: Set[UniqueId]
) -> Iterator[Tuple[UniqueId, ManifestNode]]:
for key, node in self.manifest.nodes.items():
@@ -101,8 +100,7 @@ class SelectorMethod(metaclass=abc.ABCMeta):
yield unique_id, node
def source_nodes(
self,
included_nodes: Set[UniqueId]
self, included_nodes: Set[UniqueId]
) -> Iterator[Tuple[UniqueId, ParsedSourceDefinition]]:
for key, source in self.manifest.sources.items():
@@ -112,8 +110,7 @@ class SelectorMethod(metaclass=abc.ABCMeta):
yield unique_id, source
def exposure_nodes(
self,
included_nodes: Set[UniqueId]
self, included_nodes: Set[UniqueId]
) -> Iterator[Tuple[UniqueId, ParsedExposure]]:
for key, exposure in self.manifest.exposures.items():
@@ -123,26 +120,28 @@ class SelectorMethod(metaclass=abc.ABCMeta):
yield unique_id, exposure
def all_nodes(
self,
included_nodes: Set[UniqueId]
self, included_nodes: Set[UniqueId]
) -> Iterator[Tuple[UniqueId, SelectorTarget]]:
yield from chain(self.parsed_nodes(included_nodes),
self.source_nodes(included_nodes),
self.exposure_nodes(included_nodes))
yield from chain(
self.parsed_nodes(included_nodes),
self.source_nodes(included_nodes),
self.exposure_nodes(included_nodes),
)
def configurable_nodes(
self,
included_nodes: Set[UniqueId]
self, included_nodes: Set[UniqueId]
) -> Iterator[Tuple[UniqueId, CompileResultNode]]:
yield from chain(self.parsed_nodes(included_nodes),
self.source_nodes(included_nodes))
yield from chain(
self.parsed_nodes(included_nodes), self.source_nodes(included_nodes)
)
def non_source_nodes(
self,
included_nodes: Set[UniqueId],
) -> Iterator[Tuple[UniqueId, Union[ParsedExposure, ManifestNode]]]:
yield from chain(self.parsed_nodes(included_nodes),
self.exposure_nodes(included_nodes))
yield from chain(
self.parsed_nodes(included_nodes), self.exposure_nodes(included_nodes)
)
@abc.abstractmethod
def search(
@@ -150,7 +149,7 @@ class SelectorMethod(metaclass=abc.ABCMeta):
included_nodes: Set[UniqueId],
selector: str,
) -> Iterator[UniqueId]:
raise NotImplementedError('subclasses should implement this')
raise NotImplementedError("subclasses should implement this")
class QualifiedNameSelectorMethod(SelectorMethod):
@@ -216,7 +215,7 @@ class SourceSelectorMethod(SelectorMethod):
self, included_nodes: Set[UniqueId], selector: str
) -> Iterator[UniqueId]:
"""yields nodes from included are the specified source."""
parts = selector.split('.')
parts = selector.split(".")
target_package = SELECTOR_GLOB
if len(parts) == 1:
target_source, target_table = parts[0], None
@@ -227,9 +226,9 @@ class SourceSelectorMethod(SelectorMethod):
else: # len(parts) > 3 or len(parts) == 0
msg = (
'Invalid source selector value "{}". Sources must be of the '
'form `${{source_name}}`, '
'`${{source_name}}.${{target_name}}`, or '
'`${{package_name}}.${{source_name}}.${{target_name}}'
"form `${{source_name}}`, "
"`${{source_name}}.${{target_name}}`, or "
"`${{package_name}}.${{source_name}}.${{target_name}}"
).format(selector)
raise RuntimeException(msg)
@@ -248,7 +247,7 @@ class ExposureSelectorMethod(SelectorMethod):
def search(
self, included_nodes: Set[UniqueId], selector: str
) -> Iterator[UniqueId]:
parts = selector.split('.')
parts = selector.split(".")
target_package = SELECTOR_GLOB
if len(parts) == 1:
target_name = parts[0]
@@ -257,8 +256,8 @@ class ExposureSelectorMethod(SelectorMethod):
else:
msg = (
'Invalid exposure selector value "{}". Exposures must be of '
'the form ${{exposure_name}} or '
'${{exposure_package.exposure_name}}'
"the form ${{exposure_name}} or "
"${{exposure_package.exposure_name}}"
).format(selector)
raise RuntimeException(msg)
@@ -275,9 +274,7 @@ class PathSelectorMethod(SelectorMethod):
def search(
self, included_nodes: Set[UniqueId], selector: str
) -> Iterator[UniqueId]:
"""Yields nodes from inclucded that match the given path.
"""
"""Yields nodes from inclucded that match the given path."""
# use '.' and not 'root' for easy comparison
root = Path.cwd()
paths = set(p.relative_to(root) for p in root.glob(selector))
@@ -336,7 +333,7 @@ class ConfigSelectorMethod(SelectorMethod):
parts = self.arguments
# special case: if the user wanted to compare test severity,
# make the comparison case-insensitive
if parts == ['severity']:
if parts == ["severity"]:
selector = CaseInsensitive(selector)
# search sources is kind of useless now source configs only have
@@ -382,14 +379,13 @@ class TestTypeSelectorMethod(SelectorMethod):
self, included_nodes: Set[UniqueId], selector: str
) -> Iterator[UniqueId]:
search_types: Tuple[Type, ...]
if selector == 'schema':
if selector == "schema":
search_types = (ParsedSchemaTestNode, CompiledSchemaTestNode)
elif selector == 'data':
elif selector == "data":
search_types = (ParsedDataTestNode, CompiledDataTestNode)
else:
raise RuntimeException(
f'Invalid test type selector {selector}: expected "data" or '
'"schema"'
f'Invalid test type selector {selector}: expected "data" or ' '"schema"'
)
for node, real_node in self.parsed_nodes(included_nodes):
@@ -405,25 +401,23 @@ class StateSelectorMethod(SelectorMethod):
def _macros_modified(self) -> List[str]:
# we checked in the caller!
if self.previous_state is None or self.previous_state.manifest is None:
raise InternalException(
'No comparison manifest in _macros_modified'
)
raise InternalException("No comparison manifest in _macros_modified")
old_macros = self.previous_state.manifest.macros
new_macros = self.manifest.macros
modified = []
for uid, macro in new_macros.items():
name = f'{macro.package_name}.{macro.name}'
name = f"{macro.package_name}.{macro.name}"
if uid in old_macros:
old_macro = old_macros[uid]
if macro.macro_sql != old_macro.macro_sql:
modified.append(f'{name} changed')
modified.append(f"{name} changed")
else:
modified.append(f'{name} added')
modified.append(f"{name} added")
for uid, macro in old_macros.items():
if uid not in new_macros:
modified.append(f'{macro.package_name}.{macro.name} removed')
modified.append(f"{macro.package_name}.{macro.name} removed")
return modified[:3]
@@ -437,12 +431,14 @@ class StateSelectorMethod(SelectorMethod):
if self.macros_were_modified is None:
self.macros_were_modified = self._macros_modified()
if self.macros_were_modified:
log_str = ', '.join(self.macros_were_modified)
logger.warning(warning_tag(
f'During a state comparison, dbt detected a change in '
f'macros. This will not be marked as a modification. Some '
f'macros: {log_str}'
))
log_str = ", ".join(self.macros_were_modified)
logger.warning(
warning_tag(
f"During a state comparison, dbt detected a change in "
f"macros. This will not be marked as a modification. Some "
f"macros: {log_str}"
)
)
return not new.same_contents(old) # type: ignore
@@ -458,12 +454,12 @@ class StateSelectorMethod(SelectorMethod):
) -> Iterator[UniqueId]:
if self.previous_state is None or self.previous_state.manifest is None:
raise RuntimeException(
'Got a state selector method, but no comparison manifest'
"Got a state selector method, but no comparison manifest"
)
state_checks = {
'modified': self.check_modified,
'new': self.check_new,
"modified": self.check_modified,
"new": self.check_new,
}
if selector in state_checks:
checker = state_checks[selector]
@@ -517,7 +513,7 @@ class MethodManager:
if method not in self.SELECTOR_METHODS:
raise InternalException(
f'Method name "{method}" is a valid node selection '
f'method name, but it is not handled'
f"method name, but it is not handled"
)
cls: Type[SelectorMethod] = self.SELECTOR_METHODS[method]
return cls(self.manifest, self.previous_state, method_arguments)

View File

@@ -3,23 +3,21 @@ import re
from abc import ABCMeta, abstractmethod
from dataclasses import dataclass
from typing import (
Set, Iterator, List, Optional, Dict, Union, Any, Iterable, Tuple
)
from typing import Set, Iterator, List, Optional, Dict, Union, Any, Iterable, Tuple
from .graph import UniqueId
from .selector_methods import MethodName
from dbt.exceptions import RuntimeException, InvalidSelectorException
RAW_SELECTOR_PATTERN = re.compile(
r'\A'
r'(?P<childrens_parents>(\@))?'
r'(?P<parents>((?P<parents_depth>(\d*))\+))?'
r'((?P<method>([\w.]+)):)?(?P<value>(.*?))'
r'(?P<children>(\+(?P<children_depth>(\d*))))?'
r'\Z'
r"\A"
r"(?P<childrens_parents>(\@))?"
r"(?P<parents>((?P<parents_depth>(\d*))\+))?"
r"((?P<method>([\w.]+)):)?(?P<value>(.*?))"
r"(?P<children>(\+(?P<children_depth>(\d*))))?"
r"\Z"
)
SELECTOR_METHOD_SEPARATOR = '.'
SELECTOR_METHOD_SEPARATOR = "."
def _probably_path(value: str):
@@ -43,15 +41,15 @@ def _match_to_int(match: Dict[str, str], key: str) -> Optional[int]:
return int(raw)
except ValueError as exc:
raise RuntimeException(
f'Invalid node spec - could not handle parent depth {raw}'
f"Invalid node spec - could not handle parent depth {raw}"
) from exc
SelectionSpec = Union[
'SelectionCriteria',
'SelectionIntersection',
'SelectionDifference',
'SelectionUnion',
"SelectionCriteria",
"SelectionIntersection",
"SelectionDifference",
"SelectionUnion",
]
@@ -71,7 +69,7 @@ class SelectionCriteria:
if self.children and self.childrens_parents:
raise RuntimeException(
f'Invalid node spec {self.raw} - "@" prefix and "+" suffix '
'are incompatible'
"are incompatible"
)
@classmethod
@@ -82,12 +80,10 @@ class SelectionCriteria:
return MethodName.FQN
@classmethod
def parse_method(
cls, groupdict: Dict[str, Any]
) -> Tuple[MethodName, List[str]]:
raw_method = groupdict.get('method')
def parse_method(cls, groupdict: Dict[str, Any]) -> Tuple[MethodName, List[str]]:
raw_method = groupdict.get("method")
if raw_method is None:
return cls.default_method(groupdict['value']), []
return cls.default_method(groupdict["value"]), []
method_parts: List[str] = raw_method.split(SELECTOR_METHOD_SEPARATOR)
try:
@@ -104,24 +100,22 @@ class SelectionCriteria:
@classmethod
def selection_criteria_from_dict(
cls, raw: Any, dct: Dict[str, Any]
) -> 'SelectionCriteria':
if 'value' not in dct:
raise RuntimeException(
f'Invalid node spec "{raw}" - no search value!'
)
) -> "SelectionCriteria":
if "value" not in dct:
raise RuntimeException(f'Invalid node spec "{raw}" - no search value!')
method_name, method_arguments = cls.parse_method(dct)
parents_depth = _match_to_int(dct, 'parents_depth')
children_depth = _match_to_int(dct, 'children_depth')
parents_depth = _match_to_int(dct, "parents_depth")
children_depth = _match_to_int(dct, "children_depth")
return cls(
raw=raw,
method=method_name,
method_arguments=method_arguments,
value=dct['value'],
childrens_parents=bool(dct.get('childrens_parents')),
parents=bool(dct.get('parents')),
value=dct["value"],
childrens_parents=bool(dct.get("childrens_parents")),
parents=bool(dct.get("parents")),
parents_depth=parents_depth,
children=bool(dct.get('children')),
children=bool(dct.get("children")),
children_depth=children_depth,
)
@@ -129,24 +123,24 @@ class SelectionCriteria:
def dict_from_single_spec(cls, raw: str):
result = RAW_SELECTOR_PATTERN.match(raw)
if result is None:
return {'error': 'Invalid selector spec'}
return {"error": "Invalid selector spec"}
dct: Dict[str, Any] = result.groupdict()
method_name, method_arguments = cls.parse_method(dct)
meth_name = str(method_name)
if method_arguments:
meth_name = meth_name + '.' + '.'.join(method_arguments)
dct['method'] = meth_name
dct = {k: v for k, v in dct.items() if (v is not None and v != '')}
if 'childrens_parents' in dct:
dct['childrens_parents'] = bool(dct.get('childrens_parents'))
if 'parents' in dct:
dct['parents'] = bool(dct.get('parents'))
if 'children' in dct:
dct['children'] = bool(dct.get('children'))
meth_name = meth_name + "." + ".".join(method_arguments)
dct["method"] = meth_name
dct = {k: v for k, v in dct.items() if (v is not None and v != "")}
if "childrens_parents" in dct:
dct["childrens_parents"] = bool(dct.get("childrens_parents"))
if "parents" in dct:
dct["parents"] = bool(dct.get("parents"))
if "children" in dct:
dct["children"] = bool(dct.get("children"))
return dct
@classmethod
def from_single_spec(cls, raw: str) -> 'SelectionCriteria':
def from_single_spec(cls, raw: str) -> "SelectionCriteria":
result = RAW_SELECTOR_PATTERN.match(raw)
if result is None:
# bad spec!
@@ -175,9 +169,7 @@ class BaseSelectionGroup(Iterable[SelectionSpec], metaclass=ABCMeta):
self,
selections: List[Set[UniqueId]],
) -> Set[UniqueId]:
raise NotImplementedError(
'_combine_selections not implemented!'
)
raise NotImplementedError("_combine_selections not implemented!")
def combined(self, selections: List[Set[UniqueId]]) -> Set[UniqueId]:
if not selections:
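
For readers skimming RAW_SELECTOR_PATTERN above, this is how a raw spec decomposes into its named groups. The pattern is copied from the hunk; the sample spec "2+tag:nightly+3" is made up for illustration:

```python
import re

RAW_SELECTOR_PATTERN = re.compile(
    r"\A"
    r"(?P<childrens_parents>(\@))?"
    r"(?P<parents>((?P<parents_depth>(\d*))\+))?"
    r"((?P<method>([\w.]+)):)?(?P<value>(.*?))"
    r"(?P<children>(\+(?P<children_depth>(\d*))))?"
    r"\Z"
)

match = RAW_SELECTOR_PATTERN.match("2+tag:nightly+3")
print(match.groupdict())
# {'childrens_parents': None, 'parents': '2+', 'parents_depth': '2',
#  'method': 'tag', 'value': 'nightly', 'children': '+3', 'children_depth': '3'}
```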

View File

@@ -5,7 +5,9 @@ from pathlib import Path
from typing import Tuple, AbstractSet, Union
from dbt.dataclass_schema import (
dbtClassMixin, ValidationError, StrEnum,
dbtClassMixin,
ValidationError,
StrEnum,
)
from hologram import FieldEncoder, JsonDict
from mashumaro.types import SerializableType
@@ -13,11 +15,11 @@ from mashumaro.types import SerializableType
class Port(int, SerializableType):
@classmethod
def _deserialize(cls, value: Union[int, str]) -> 'Port':
def _deserialize(cls, value: Union[int, str]) -> "Port":
try:
value = int(value)
except ValueError:
raise ValidationError(f'Cannot encode {value} into port number')
raise ValidationError(f"Cannot encode {value} into port number")
return Port(value)
@@ -28,7 +30,7 @@ class Port(int, SerializableType):
class PortEncoder(FieldEncoder):
@property
def json_schema(self):
return {'type': 'integer', 'minimum': 0, 'maximum': 65535}
return {"type": "integer", "minimum": 0, "maximum": 65535}
class TimeDeltaFieldEncoder(FieldEncoder[timedelta]):
@@ -44,12 +46,12 @@ class TimeDeltaFieldEncoder(FieldEncoder[timedelta]):
return timedelta(seconds=value)
except TypeError:
raise ValidationError(
'cannot encode {} into timedelta'.format(value)
"cannot encode {} into timedelta".format(value)
) from None
@property
def json_schema(self) -> JsonDict:
return {'type': 'number'}
return {"type": "number"}
class PathEncoder(FieldEncoder):
@@ -63,16 +65,16 @@ class PathEncoder(FieldEncoder):
return Path(value)
except TypeError:
raise ValidationError(
'cannot encode {} into timedelta'.format(value)
"cannot encode {} into timedelta".format(value)
) from None
@property
def json_schema(self) -> JsonDict:
return {'type': 'string'}
return {"type": "string"}
class NVEnum(StrEnum):
novalue = 'novalue'
novalue = "novalue"
def __eq__(self, other):
return isinstance(other, NVEnum)
@@ -81,14 +83,17 @@ class NVEnum(StrEnum):
@dataclass
class NoValue(dbtClassMixin):
"""Sometimes, you want a way to say none that isn't None"""
novalue: NVEnum = NVEnum.novalue
dbtClassMixin.register_field_encoders({
Port: PortEncoder(),
timedelta: TimeDeltaFieldEncoder(),
Path: PathEncoder(),
})
dbtClassMixin.register_field_encoders(
{
Port: PortEncoder(),
timedelta: TimeDeltaFieldEncoder(),
Path: PathEncoder(),
}
)
FQNPath = Tuple[str, ...]

View File

@@ -5,8 +5,8 @@ from typing import Union, Dict, Any
class ModelHookType(StrEnum):
PreHook = 'pre-hook'
PostHook = 'post-hook'
PreHook = "pre-hook"
PostHook = "post-hook"
def get_hook_dict(source: Union[str, Dict[str, Any]]) -> Dict[str, Any]:
@@ -18,4 +18,4 @@ def get_hook_dict(source: Union[str, Dict[str, Any]]) -> Dict[str, Any]:
try:
return json.loads(source)
except ValueError:
return {'sql': source}
return {"sql": source}

View File

@@ -1,7 +1,6 @@
import os
PACKAGE_PATH = os.path.dirname(__file__)
PROJECT_NAME = 'dbt'
PROJECT_NAME = "dbt"
DOCS_INDEX_FILE_PATH = os.path.normpath(
os.path.join(PACKAGE_PATH, '..', "index.html"))
DOCS_INDEX_FILE_PATH = os.path.normpath(os.path.join(PACKAGE_PATH, "..", "index.html"))

View File

@@ -287,4 +287,3 @@
{% macro set_sql_header(config) -%}
{{ config.set('sql_header', caller()) }}
{%- endmacro %}

View File

@@ -70,7 +70,7 @@
deletes_source_data as (
select
select
*,
{{ strategy.unique_key }} as dbt_unique_key
from snapshot_query
@@ -113,7 +113,7 @@
,
deletes as (
select
'delete' as dbt_change_type,
source_data.*,
@@ -121,7 +121,7 @@
{{ snapshot_get_time() }} as dbt_updated_at,
{{ snapshot_get_time() }} as dbt_valid_to,
snapshotted_data.dbt_scd_id
from snapshotted_data
left join deletes_source_data as source_data on snapshotted_data.dbt_unique_key = source_data.dbt_unique_key
where source_data.dbt_unique_key is null

View File

@@ -23,5 +23,3 @@
values ({{ insert_cols_csv }})
;
{% endmacro %}

View File

@@ -134,7 +134,7 @@
{% set check_cols_config = config['check_cols'] %}
{% set primary_key = config['unique_key'] %}
{% set invalidate_hard_deletes = config.get('invalidate_hard_deletes', false) %}
{% set select_current_time -%}
select {{ snapshot_get_time() }} as snapshot_start
{%- endset %}

View File

@@ -1,4 +1,4 @@
ProfileConfigDocs = 'https://docs.getdbt.com/docs/configure-your-profile'
SnowflakeQuotingDocs = 'https://docs.getdbt.com/v0.10/docs/configuring-quoting'
IncrementalDocs = 'https://docs.getdbt.com/docs/configuring-incremental-models'
BigQueryNewPartitionBy = 'https://docs.getdbt.com/docs/upgrading-to-0-16-0'
ProfileConfigDocs = "https://docs.getdbt.com/docs/configure-your-profile"
SnowflakeQuotingDocs = "https://docs.getdbt.com/v0.10/docs/configuring-quoting"
IncrementalDocs = "https://docs.getdbt.com/docs/configuring-incremental-models"
BigQueryNewPartitionBy = "https://docs.getdbt.com/docs/upgrading-to-0-16-0"

View File

@@ -26,21 +26,21 @@ colorama_wrap = True
colorama.init(wrap=colorama_wrap)
if sys.platform == 'win32' and not os.getenv('TERM'):
if sys.platform == "win32" and not os.getenv("TERM"):
colorama_wrap = False
colorama_stdout = colorama.AnsiToWin32(sys.stdout).stream
elif sys.platform == 'win32':
elif sys.platform == "win32":
colorama_wrap = False
colorama.init(wrap=colorama_wrap)
STDOUT_LOG_FORMAT = '{record.message}'
STDOUT_LOG_FORMAT = "{record.message}"
DEBUG_LOG_FORMAT = (
'{record.time:%Y-%m-%d %H:%M:%S.%f%z} '
'({record.thread_name}): '
'{record.message}'
"{record.time:%Y-%m-%d %H:%M:%S.%f%z} "
"({record.thread_name}): "
"{record.message}"
)
@@ -94,6 +94,7 @@ class JsonFormatter(LogMessageFormatter):
"""Return a the record converted to LogMessage's JSON form"""
# utils imports exceptions which imports logger...
import dbt.utils
log_message = super().__call__(record, handler)
dct = log_message.to_dict(omit_none=True)
return json.dumps(dct, cls=dbt.utils.JSONEncoder)
@@ -117,9 +118,7 @@ class FormatterMixin:
self.format_string = self._text_format_string
def reset(self):
raise NotImplementedError(
'reset() not implemented in FormatterMixin subclass'
)
raise NotImplementedError("reset() not implemented in FormatterMixin subclass")
class OutputHandler(logbook.StreamHandler, FormatterMixin):
@@ -164,9 +163,9 @@ class OutputHandler(logbook.StreamHandler, FormatterMixin):
if record.level < self.level:
return False
text_mode = self.formatter_class is logbook.StringFormatter
if text_mode and record.extra.get('json_only', False):
if text_mode and record.extra.get("json_only", False):
return False
elif not text_mode and record.extra.get('text_only', False):
elif not text_mode and record.extra.get("text_only", False):
return False
else:
return True
@@ -177,7 +176,7 @@ def _redirect_std_logging():
def _root_channel(record: logbook.LogRecord) -> str:
return record.channel.split('.')[0]
return record.channel.split(".")[0]
class Relevel(logbook.Processor):
@@ -195,7 +194,7 @@ class Relevel(logbook.Processor):
def process(self, record):
if _root_channel(record) in self.allowed:
return
record.extra['old_level'] = record.level
record.extra["old_level"] = record.level
# suppress logs at/below our min level by lowering them to NOTSET
if record.level < self.min_level:
record.level = logbook.NOTSET
@@ -207,12 +206,12 @@ class Relevel(logbook.Processor):
class JsonOnly(logbook.Processor):
def process(self, record):
record.extra['json_only'] = True
record.extra["json_only"] = True
class TextOnly(logbook.Processor):
def process(self, record):
record.extra['text_only'] = True
record.extra["text_only"] = True
class TimingProcessor(logbook.Processor):
@@ -222,8 +221,7 @@ class TimingProcessor(logbook.Processor):
def process(self, record):
if self.timing_info is not None:
record.extra['timing_info'] = self.timing_info.to_dict(
omit_none=True)
record.extra["timing_info"] = self.timing_info.to_dict(omit_none=True)
class DbtProcessState(logbook.Processor):
@@ -233,11 +231,10 @@ class DbtProcessState(logbook.Processor):
def process(self, record):
overwrite = (
'run_state' not in record.extra or
record.extra['run_state'] == 'internal'
"run_state" not in record.extra or record.extra["run_state"] == "internal"
)
if overwrite:
record.extra['run_state'] = self.value
record.extra["run_state"] = self.value
class DbtModelState(logbook.Processor):
@@ -251,7 +248,7 @@ class DbtModelState(logbook.Processor):
class DbtStatusMessage(logbook.Processor):
def process(self, record):
record.extra['is_status_message'] = True
record.extra["is_status_message"] = True
class UniqueID(logbook.Processor):
@@ -260,7 +257,7 @@ class UniqueID(logbook.Processor):
super().__init__()
def process(self, record):
record.extra['unique_id'] = self.unique_id
record.extra["unique_id"] = self.unique_id
class NodeCount(logbook.Processor):
@@ -269,7 +266,7 @@ class NodeCount(logbook.Processor):
super().__init__()
def process(self, record):
record.extra['node_count'] = self.node_count
record.extra["node_count"] = self.node_count
class NodeMetadata(logbook.Processor):
@@ -289,26 +286,26 @@ class NodeMetadata(logbook.Processor):
def process(self, record):
self.process_keys(record)
record.extra['node_index'] = self.index
record.extra["node_index"] = self.index
class ModelMetadata(NodeMetadata):
def mapping_keys(self):
return [
('alias', 'node_alias'),
('schema', 'node_schema'),
('database', 'node_database'),
('original_file_path', 'node_path'),
('name', 'node_name'),
('resource_type', 'resource_type'),
('depends_on_nodes', 'depends_on'),
("alias", "node_alias"),
("schema", "node_schema"),
("database", "node_database"),
("original_file_path", "node_path"),
("name", "node_name"),
("resource_type", "resource_type"),
("depends_on_nodes", "depends_on"),
]
def process_config(self, record):
if hasattr(self.node, 'config'):
materialized = getattr(self.node.config, 'materialized', None)
if hasattr(self.node, "config"):
materialized = getattr(self.node.config, "materialized", None)
if materialized is not None:
record.extra['node_materialized'] = materialized
record.extra["node_materialized"] = materialized
def process(self, record):
super().process(record)
@@ -318,8 +315,8 @@ class ModelMetadata(NodeMetadata):
class HookMetadata(NodeMetadata):
def mapping_keys(self):
return [
('name', 'node_name'),
('resource_type', 'resource_type'),
("name", "node_name"),
("resource_type", "resource_type"),
]
@@ -333,30 +330,31 @@ class TimestampNamed(logbook.Processor):
record.extra[self.name] = datetime.utcnow().isoformat()
logger = logbook.Logger('dbt')
logger = logbook.Logger("dbt")
# provide this for the cache, disabled by default
CACHE_LOGGER = logbook.Logger('dbt.cache')
CACHE_LOGGER = logbook.Logger("dbt.cache")
CACHE_LOGGER.disable()
warnings.filterwarnings("ignore", category=ResourceWarning,
message="unclosed.*<socket.socket.*>")
warnings.filterwarnings(
"ignore", category=ResourceWarning, message="unclosed.*<socket.socket.*>"
)
initialized = False
def make_log_dir_if_missing(log_dir):
import dbt.clients.system
dbt.clients.system.make_directory(log_dir)
class DebugWarnings(logbook.compat.redirected_warnings):
"""Log warnings, except send them to 'debug' instead of 'warning' level.
"""
"""Log warnings, except send them to 'debug' instead of 'warning' level."""
def make_record(self, message, exception, filename, lineno):
rv = super().make_record(message, exception, filename, lineno)
rv.level = logbook.DEBUG
rv.extra['from_warnings'] = True
rv.extra["from_warnings"] = True
return rv
@@ -408,14 +406,14 @@ class DelayedFileHandler(logbook.RotatingFileHandler, FormatterMixin):
if self.disabled:
return
assert not self.initialized, 'set_path called after being set'
assert not self.initialized, "set_path called after being set"
if log_dir is None:
self.disabled = True
return
make_log_dir_if_missing(log_dir)
log_path = os.path.join(log_dir, 'dbt.log')
log_path = os.path.join(log_dir, "dbt.log")
self._super_init(log_path)
self._replay_buffered()
self._log_path = log_path
@@ -435,8 +433,9 @@ class DelayedFileHandler(logbook.RotatingFileHandler, FormatterMixin):
FormatterMixin.__init__(self, DEBUG_LOG_FORMAT)
def _replay_buffered(self):
assert self._msg_buffer is not None, \
'_msg_buffer should never be None in _replay_buffered'
assert (
self._msg_buffer is not None
), "_msg_buffer should never be None in _replay_buffered"
for record in self._msg_buffer:
super().emit(record)
self._msg_buffer = None
@@ -445,7 +444,7 @@ class DelayedFileHandler(logbook.RotatingFileHandler, FormatterMixin):
msg = super().format(record)
subbed = str(msg)
for escape_sequence in dbt.ui.COLORS.values():
subbed = subbed.replace(escape_sequence, '')
subbed = subbed.replace(escape_sequence, "")
return subbed
def emit(self, record: logbook.LogRecord):
@@ -457,11 +456,13 @@ class DelayedFileHandler(logbook.RotatingFileHandler, FormatterMixin):
elif self.initialized:
super().emit(record)
else:
assert self._msg_buffer is not None, \
'_msg_buffer should never be None if _log_path is set'
assert (
self._msg_buffer is not None
), "_msg_buffer should never be None if _log_path is set"
self._msg_buffer.append(record)
assert len(self._msg_buffer) < self._bufmax, \
'too many messages received before initialization!'
assert (
len(self._msg_buffer) < self._bufmax
), "too many messages received before initialization!"
class LogManager(logbook.NestedSetup):
@@ -471,19 +472,21 @@ class LogManager(logbook.NestedSetup):
self._null_handler = logbook.NullHandler()
self._output_handler = OutputHandler(self.stdout)
self._file_handler = DelayedFileHandler()
self._relevel_processor = Relevel(allowed=['dbt', 'werkzeug'])
self._state_processor = DbtProcessState('internal')
self._relevel_processor = Relevel(allowed=["dbt", "werkzeug"])
self._state_processor = DbtProcessState("internal")
# keep track of whether we've already entered to decide if we should
# be actually pushing. This allows us to log in main() and also
# support entering dbt execution via handle_and_check.
self._stack_depth = 0
super().__init__([
self._null_handler,
self._output_handler,
self._file_handler,
self._relevel_processor,
self._state_processor,
])
super().__init__(
[
self._null_handler,
self._output_handler,
self._file_handler,
self._relevel_processor,
self._state_processor,
]
)
def push_application(self):
self._stack_depth += 1
@@ -499,8 +502,7 @@ class LogManager(logbook.NestedSetup):
self.add_handler(logbook.NullHandler())
def add_handler(self, handler):
"""add an handler to the log manager that runs before the file handler.
"""
"""add an handler to the log manager that runs before the file handler."""
self.objects.append(handler)
# this is used by `dbt ls` to allow piping stdout to jq, etc
@@ -558,8 +560,7 @@ log_manager = LogManager()
def log_cache_events(flag):
"""Set the cache logger to propagate its messages based on the given flag.
"""
"""Set the cache logger to propagate its messages based on the given flag."""
# the flag is True if we should log, and False if we shouldn't, so disabled
# is the inverse.
CACHE_LOGGER.disabled = not flag
@@ -583,7 +584,7 @@ class ListLogHandler(LogMessageHandler):
level: int = logbook.NOTSET,
filter: Callable = None,
bubble: bool = False,
lst: Optional[List[LogMessage]] = None
lst: Optional[List[LogMessage]] = None,
) -> None:
super().__init__(level, filter, bubble)
if lst is None:
@@ -592,7 +593,7 @@ class ListLogHandler(LogMessageHandler):
def should_handle(self, record):
"""Only ever emit dbt-sourced log messages to the ListHandler."""
if _root_channel(record) != 'dbt':
if _root_channel(record) != "dbt":
return False
return super().should_handle(record)
@@ -609,28 +610,27 @@ def _env_log_level(var_name: str) -> int:
return logging.ERROR
LOG_LEVEL_GOOGLE = _env_log_level('DBT_GOOGLE_DEBUG_LOGGING')
LOG_LEVEL_SNOWFLAKE = _env_log_level('DBT_SNOWFLAKE_CONNECTOR_DEBUG_LOGGING')
LOG_LEVEL_BOTOCORE = _env_log_level('DBT_BOTOCORE_DEBUG_LOGGING')
LOG_LEVEL_HTTP = _env_log_level('DBT_HTTP_DEBUG_LOGGING')
LOG_LEVEL_WERKZEUG = _env_log_level('DBT_WERKZEUG_DEBUG_LOGGING')
LOG_LEVEL_GOOGLE = _env_log_level("DBT_GOOGLE_DEBUG_LOGGING")
LOG_LEVEL_SNOWFLAKE = _env_log_level("DBT_SNOWFLAKE_CONNECTOR_DEBUG_LOGGING")
LOG_LEVEL_BOTOCORE = _env_log_level("DBT_BOTOCORE_DEBUG_LOGGING")
LOG_LEVEL_HTTP = _env_log_level("DBT_HTTP_DEBUG_LOGGING")
LOG_LEVEL_WERKZEUG = _env_log_level("DBT_WERKZEUG_DEBUG_LOGGING")
logging.getLogger('botocore').setLevel(LOG_LEVEL_BOTOCORE)
logging.getLogger('requests').setLevel(LOG_LEVEL_HTTP)
logging.getLogger('urllib3').setLevel(LOG_LEVEL_HTTP)
logging.getLogger('google').setLevel(LOG_LEVEL_GOOGLE)
logging.getLogger('snowflake.connector').setLevel(LOG_LEVEL_SNOWFLAKE)
logging.getLogger("botocore").setLevel(LOG_LEVEL_BOTOCORE)
logging.getLogger("requests").setLevel(LOG_LEVEL_HTTP)
logging.getLogger("urllib3").setLevel(LOG_LEVEL_HTTP)
logging.getLogger("google").setLevel(LOG_LEVEL_GOOGLE)
logging.getLogger("snowflake.connector").setLevel(LOG_LEVEL_SNOWFLAKE)
logging.getLogger('parsedatetime').setLevel(logging.ERROR)
logging.getLogger('werkzeug').setLevel(LOG_LEVEL_WERKZEUG)
logging.getLogger("parsedatetime").setLevel(logging.ERROR)
logging.getLogger("werkzeug").setLevel(LOG_LEVEL_WERKZEUG)
def list_handler(
lst: Optional[List[LogMessage]],
level=logbook.NOTSET,
) -> ContextManager:
"""Return a context manager that temporarly attaches a list to the logger.
"""
"""Return a context manager that temporarly attaches a list to the logger."""
return ListLogHandler(lst=lst, level=level, bubble=True)
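The DelayedFileHandler changes earlier in this file keep dbt's buffer-then-replay behaviour: records emitted before a log directory is known are held in memory and flushed once set_path() is called. A minimal standalone sketch of that pattern, with illustrative names and a plain file writer instead of dbt's logbook-based handler:

class BufferedFileWriter:
    """Hold records in memory until a path is set, then replay them."""

    def __init__(self, bufmax: int = 500) -> None:
        self._buffer: list = []
        self._bufmax = bufmax
        self._file = None

    def emit(self, record: str) -> None:
        if self._file is not None:
            self._file.write(record + "\n")
            return
        # guard against unbounded buffering, as the assertions above do
        assert len(self._buffer) < self._bufmax, "too many messages before initialization"
        self._buffer.append(record)

    def set_path(self, path: str) -> None:
        self._file = open(path, "a")
        for record in self._buffer:
            self._file.write(record + "\n")
        self._buffer.clear()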

File diff suppressed because it is too large

View File

@@ -4,20 +4,20 @@ from dbt.dataclass_schema import StrEnum
class NodeType(StrEnum):
Model = 'model'
Analysis = 'analysis'
Test = 'test'
Snapshot = 'snapshot'
Operation = 'operation'
Seed = 'seed'
RPCCall = 'rpc'
Documentation = 'docs'
Source = 'source'
Macro = 'macro'
Exposure = 'exposure'
Model = "model"
Analysis = "analysis"
Test = "test"
Snapshot = "snapshot"
Operation = "operation"
Seed = "seed"
RPCCall = "rpc"
Documentation = "docs"
Source = "source"
Macro = "macro"
Exposure = "exposure"
@classmethod
def executable(cls) -> List['NodeType']:
def executable(cls) -> List["NodeType"]:
return [
cls.Model,
cls.Test,
@@ -30,7 +30,7 @@ class NodeType(StrEnum):
]
@classmethod
def refable(cls) -> List['NodeType']:
def refable(cls) -> List["NodeType"]:
return [
cls.Model,
cls.Seed,
@@ -38,7 +38,7 @@ class NodeType(StrEnum):
]
@classmethod
def documentable(cls) -> List['NodeType']:
def documentable(cls) -> List["NodeType"]:
return [
cls.Model,
cls.Seed,
@@ -46,16 +46,16 @@ class NodeType(StrEnum):
cls.Source,
cls.Macro,
cls.Analysis,
cls.Exposure
cls.Exposure,
]
def pluralize(self) -> str:
if self == 'analysis':
return 'analyses'
if self == "analysis":
return "analyses"
else:
return f'{self}s'
return f"{self}s"
class RunHookType(StrEnum):
Start = 'on-run-start'
End = 'on-run-end'
Start = "on-run-start"
End = "on-run-end"

View File

@@ -11,6 +11,14 @@ from .seeds import SeedParser # noqa
from .snapshots import SnapshotParser # noqa
from . import ( # noqa
analysis, base, data_test, docs, hooks, macros, models, results, schemas,
snapshots
analysis,
base,
data_test,
docs,
hooks,
macros,
models,
results,
schemas,
snapshots,
)

View File

@@ -8,9 +8,7 @@ from dbt.parser.search import FilesystemSearcher, FileBlock
class AnalysisParser(SimpleSQLParser[ParsedAnalysisNode]):
def get_paths(self):
return FilesystemSearcher(
self.project, self.project.analysis_paths, '.sql'
)
return FilesystemSearcher(self.project, self.project.analysis_paths, ".sql")
def parse_from_dict(self, dct, validate=True) -> ParsedAnalysisNode:
if validate:
@@ -23,4 +21,4 @@ class AnalysisParser(SimpleSQLParser[ParsedAnalysisNode]):
@classmethod
def get_compiled_path(cls, block: FileBlock):
return os.path.join('analysis', block.path.relative_path)
return os.path.join("analysis", block.path.relative_path)

View File

@@ -1,9 +1,7 @@
import abc
import itertools
import os
from typing import (
List, Dict, Any, Iterable, Generic, TypeVar
)
from typing import List, Dict, Any, Iterable, Generic, TypeVar
from dbt.dataclass_schema import ValidationError
@@ -17,17 +15,15 @@ from dbt.context.providers import (
from dbt.adapters.factory import get_adapter
from dbt.clients.jinja import get_rendered
from dbt.config import Project, RuntimeConfig
from dbt.context.context_config import (
ContextConfig
)
from dbt.contracts.files import (
SourceFile, FilePath, FileHash
)
from dbt.context.context_config import ContextConfig
from dbt.contracts.files import SourceFile, FilePath, FileHash
from dbt.contracts.graph.manifest import MacroManifest
from dbt.contracts.graph.parsed import HasUniqueID
from dbt.contracts.graph.unparsed import UnparsedNode
from dbt.exceptions import (
CompilationException, validator_error_message, InternalException
CompilationException,
validator_error_message,
InternalException,
)
from dbt import hooks
from dbt.node_types import NodeType
@@ -37,14 +33,14 @@ from dbt.parser.search import FileBlock
# internally, the parser may store a less-restrictive type that will be
# transformed into the final type. But it will have to be derived from
# ParsedNode to be operable.
FinalValue = TypeVar('FinalValue', bound=HasUniqueID)
IntermediateValue = TypeVar('IntermediateValue', bound=HasUniqueID)
FinalValue = TypeVar("FinalValue", bound=HasUniqueID)
IntermediateValue = TypeVar("IntermediateValue", bound=HasUniqueID)
IntermediateNode = TypeVar('IntermediateNode', bound=Any)
FinalNode = TypeVar('FinalNode', bound=ManifestNodes)
IntermediateNode = TypeVar("IntermediateNode", bound=Any)
FinalNode = TypeVar("FinalNode", bound=ManifestNodes)
ConfiguredBlockType = TypeVar('ConfiguredBlockType', bound=FileBlock)
ConfiguredBlockType = TypeVar("ConfiguredBlockType", bound=FileBlock)
class BaseParser(Generic[FinalValue]):
@@ -73,9 +69,9 @@ class BaseParser(Generic[FinalValue]):
def generate_unique_id(self, resource_name: str) -> str:
"""Returns a unique identifier for a resource"""
return "{}.{}.{}".format(self.resource_type,
self.project.project_name,
resource_name)
return "{}.{}.{}".format(
self.resource_type, self.project.project_name, resource_name
)
def load_file(
self,
@@ -89,7 +85,7 @@ class BaseParser(Generic[FinalValue]):
if set_contents:
source_file.contents = file_contents.strip()
else:
source_file.contents = ''
source_file.contents = ""
return source_file
@@ -108,8 +104,7 @@ class Parser(BaseParser[FinalValue], Generic[FinalValue]):
class RelationUpdate:
def __init__(
self, config: RuntimeConfig, macro_manifest: MacroManifest,
component: str
self, config: RuntimeConfig, macro_manifest: MacroManifest, component: str
) -> None:
macro = macro_manifest.find_generate_macro_by_name(
component=component,
@@ -117,7 +112,7 @@ class RelationUpdate:
)
if macro is None:
raise InternalException(
f'No macro with name generate_{component}_name found'
f"No macro with name generate_{component}_name found"
)
root_context = generate_generate_component_name_macro(
@@ -126,9 +121,7 @@ class RelationUpdate:
self.updater = MacroGenerator(macro, root_context)
self.component = component
def __call__(
self, parsed_node: Any, config_dict: Dict[str, Any]
) -> None:
def __call__(self, parsed_node: Any, config_dict: Dict[str, Any]) -> None:
override = config_dict.get(self.component)
new_value = self.updater(override, parsed_node)
if isinstance(new_value, str):
@@ -150,16 +143,13 @@ class ConfiguredParser(
super().__init__(results, project, root_project, macro_manifest)
self._update_node_database = RelationUpdate(
macro_manifest=macro_manifest, config=root_project,
component='database'
macro_manifest=macro_manifest, config=root_project, component="database"
)
self._update_node_schema = RelationUpdate(
macro_manifest=macro_manifest, config=root_project,
component='schema'
macro_manifest=macro_manifest, config=root_project, component="schema"
)
self._update_node_alias = RelationUpdate(
macro_manifest=macro_manifest, config=root_project,
component='alias'
macro_manifest=macro_manifest, config=root_project, component="alias"
)
@abc.abstractclassmethod
@@ -206,7 +196,11 @@ class ConfiguredParser(
config[key] = [hooks.get_hook_dict(h) for h in config[key]]
def _create_error_node(
self, name: str, path: str, original_file_path: str, raw_sql: str,
self,
name: str,
path: str,
original_file_path: str,
raw_sql: str,
) -> UnparsedNode:
"""If we hit an error before we've actually parsed a node, provide some
level of useful information by attaching this to the exception.
@@ -239,20 +233,20 @@ class ConfiguredParser(
if name is None:
name = block.name
dct = {
'alias': name,
'schema': self.default_schema,
'database': self.default_database,
'fqn': fqn,
'name': name,
'root_path': self.project.project_root,
'resource_type': self.resource_type,
'path': path,
'original_file_path': block.path.original_file_path,
'package_name': self.project.project_name,
'raw_sql': block.contents,
'unique_id': self.generate_unique_id(name),
'config': self.config_dict(config),
'checksum': block.file.checksum.to_dict(omit_none=True),
"alias": name,
"schema": self.default_schema,
"database": self.default_database,
"fqn": fqn,
"name": name,
"root_path": self.project.project_root,
"resource_type": self.resource_type,
"path": path,
"original_file_path": block.path.original_file_path,
"package_name": self.project.project_name,
"raw_sql": block.contents,
"unique_id": self.generate_unique_id(name),
"config": self.config_dict(config),
"checksum": block.file.checksum.to_dict(omit_none=True),
}
dct.update(kwargs)
try:
@@ -290,9 +284,7 @@ class ConfiguredParser(
# this goes through the process of rendering, but just throws away
# the rendered result. The "macro capture" is the point?
get_rendered(
parsed_node.raw_sql, context, parsed_node, capture_macros=True
)
get_rendered(parsed_node.raw_sql, context, parsed_node, capture_macros=True)
# This is taking the original config for the node, converting it to a dict,
# updating the config with new config passed in, then re-creating the
@@ -324,12 +316,10 @@ class ConfiguredParser(
config_dict = config.build_config_dict()
# Set tags on node provided in config blocks
model_tags = config_dict.get('tags', [])
model_tags = config_dict.get("tags", [])
parsed_node.tags.extend(model_tags)
parsed_node.unrendered_config = config.build_config_dict(
rendered=False
)
parsed_node.unrendered_config = config.build_config_dict(rendered=False)
# do this once before we parse the node database/schema/alias, so
# parsed_node.config is what it would be if they did nothing
@@ -338,8 +328,9 @@ class ConfiguredParser(
# at this point, we've collected our hooks. Use the node context to
# render each hook and collect refs/sources
hooks = list(itertools.chain(parsed_node.config.pre_hook,
parsed_node.config.post_hook))
hooks = list(
itertools.chain(parsed_node.config.pre_hook, parsed_node.config.post_hook)
)
# skip context rebuilding if there aren't any hooks
if not hooks:
return
@@ -362,20 +353,18 @@ class ConfiguredParser(
)
else:
raise InternalException(
f'Got an unexpected project version={config_version}, '
f'expected 2'
f"Got an unexpected project version={config_version}, " f"expected 2"
)
def config_dict(
self, config: ContextConfig,
self,
config: ContextConfig,
) -> Dict[str, Any]:
config_dict = config.build_config_dict(base=True)
self._mangle_hooks(config_dict)
return config_dict
def render_update(
self, node: IntermediateNode, config: ContextConfig
) -> None:
def render_update(self, node: IntermediateNode, config: ContextConfig) -> None:
try:
self.render_with_context(node, config)
self.update_parsed_node(node, config)
@@ -418,7 +407,7 @@ class ConfiguredParser(
class SimpleParser(
ConfiguredParser[ConfiguredBlockType, FinalNode, FinalNode],
Generic[ConfiguredBlockType, FinalNode]
Generic[ConfiguredBlockType, FinalNode],
):
def transform(self, node):
return node
@@ -426,14 +415,12 @@ class SimpleParser(
class SQLParser(
ConfiguredParser[FileBlock, IntermediateNode, FinalNode],
Generic[IntermediateNode, FinalNode]
Generic[IntermediateNode, FinalNode],
):
def parse_file(self, file_block: FileBlock) -> None:
self.parse_node(file_block)
class SimpleSQLParser(
SQLParser[FinalNode, FinalNode]
):
class SimpleSQLParser(SQLParser[FinalNode, FinalNode]):
def transform(self, node):
return node
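One convention worth noting from the parser base class above: generate_unique_id builds IDs from the resource type, the project name, and the resource name. A tiny illustration with made-up project and model names:

resource_type, project_name, resource_name = "model", "my_project", "customers"
unique_id = "{}.{}.{}".format(resource_type, project_name, resource_name)
assert unique_id == "model.my_project.customers"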

View File

@@ -7,9 +7,7 @@ from dbt.utils import get_pseudo_test_path
class DataTestParser(SimpleSQLParser[ParsedDataTestNode]):
def get_paths(self):
return FilesystemSearcher(
self.project, self.project.test_paths, '.sql'
)
return FilesystemSearcher(self.project, self.project.test_paths, ".sql")
def parse_from_dict(self, dct, validate=True) -> ParsedDataTestNode:
if validate:
@@ -21,11 +19,10 @@ class DataTestParser(SimpleSQLParser[ParsedDataTestNode]):
return NodeType.Test
def transform(self, node):
if 'data' not in node.tags:
node.tags.append('data')
if "data" not in node.tags:
node.tags.append("data")
return node
@classmethod
def get_compiled_path(cls, block: FileBlock):
return get_pseudo_test_path(block.name, block.path.relative_path,
'data_test')
return get_pseudo_test_path(block.name, block.path.relative_path, "data_test")

View File

@@ -7,11 +7,14 @@ from dbt.contracts.graph.parsed import ParsedDocumentation
from dbt.node_types import NodeType
from dbt.parser.base import Parser
from dbt.parser.search import (
BlockContents, FileBlock, FilesystemSearcher, BlockSearcher
BlockContents,
FileBlock,
FilesystemSearcher,
BlockSearcher,
)
SHOULD_PARSE_RE = re.compile(r'{[{%]')
SHOULD_PARSE_RE = re.compile(r"{[{%]")
class DocumentationParser(Parser[ParsedDocumentation]):
@@ -19,7 +22,7 @@ class DocumentationParser(Parser[ParsedDocumentation]):
return FilesystemSearcher(
project=self.project,
relative_dirs=self.project.docs_paths,
extension='.md',
extension=".md",
)
@property
@@ -33,11 +36,9 @@ class DocumentationParser(Parser[ParsedDocumentation]):
def generate_unique_id(self, resource_name: str) -> str:
# because docs are in their own graph namespace, node type doesn't
# need to be part of the unique ID.
return '{}.{}'.format(self.project.project_name, resource_name)
return "{}.{}".format(self.project.project_name, resource_name)
def parse_block(
self, block: BlockContents
) -> Iterable[ParsedDocumentation]:
def parse_block(self, block: BlockContents) -> Iterable[ParsedDocumentation]:
unique_id = self.generate_unique_id(block.name)
contents = get_rendered(block.contents, {}).strip()
@@ -55,7 +56,7 @@ class DocumentationParser(Parser[ParsedDocumentation]):
def parse_file(self, file_block: FileBlock):
searcher: Iterable[BlockContents] = BlockSearcher(
source=[file_block],
allowed_blocks={'docs'},
allowed_blocks={"docs"},
source_tag_factory=BlockContents,
)
for block in searcher:

View File

@@ -24,7 +24,7 @@ class HookBlock(FileBlock):
@property
def name(self):
return '{}-{!s}-{!s}'.format(self.project, self.hook_type, self.index)
return "{}-{!s}-{!s}".format(self.project, self.hook_type, self.index)
class HookSearcher(Iterable[HookBlock]):
@@ -33,9 +33,7 @@ class HookSearcher(Iterable[HookBlock]):
self.source_file = source_file
self.hook_type = hook_type
def _hook_list(
self, hooks: Union[str, List[str], Tuple[str, ...]]
) -> List[str]:
def _hook_list(self, hooks: Union[str, List[str], Tuple[str, ...]]) -> List[str]:
if isinstance(hooks, tuple):
hooks = list(hooks)
elif not isinstance(hooks, list):
@@ -49,8 +47,9 @@ class HookSearcher(Iterable[HookBlock]):
hooks = self.project.on_run_end
else:
raise InternalException(
'hook_type must be one of "{}" or "{}" (got {})'
.format(RunHookType.Start, RunHookType.End, self.hook_type)
'hook_type must be one of "{}" or "{}" (got {})'.format(
RunHookType.Start, RunHookType.End, self.hook_type
)
)
return self._hook_list(hooks)
@@ -73,8 +72,8 @@ class HookParser(SimpleParser[HookBlock, ParsedHookNode]):
def get_paths(self) -> List[FilePath]:
path = FilePath(
project_root=self.project.project_root,
searched_path='.',
relative_path='dbt_project.yml',
searched_path=".",
relative_path="dbt_project.yml",
)
return [path]
@@ -98,9 +97,13 @@ class HookParser(SimpleParser[HookBlock, ParsedHookNode]):
) -> ParsedHookNode:
return super()._create_parsetime_node(
block=block, path=path, config=config, fqn=fqn,
index=block.index, name=name,
tags=[str(block.hook_type)]
block=block,
path=path,
config=config,
fqn=fqn,
index=block.index,
name=name,
tags=[str(block.hook_type)],
)
@property

View File

@@ -18,7 +18,7 @@ class MacroParser(BaseParser[ParsedMacro]):
return FilesystemSearcher(
project=self.project,
relative_dirs=self.project.macro_paths,
extension='.sql',
extension=".sql",
)
@property
@@ -45,15 +45,13 @@ class MacroParser(BaseParser[ParsedMacro]):
unique_id=unique_id,
)
def parse_unparsed_macros(
self, base_node: UnparsedMacro
) -> Iterable[ParsedMacro]:
def parse_unparsed_macros(self, base_node: UnparsedMacro) -> Iterable[ParsedMacro]:
try:
blocks: List[jinja.BlockTag] = [
t for t in
jinja.extract_toplevel_blocks(
t
for t in jinja.extract_toplevel_blocks(
base_node.raw_sql,
allowed_blocks={'macro', 'materialization'},
allowed_blocks={"macro", "materialization"},
collect_raw_data=False,
)
if isinstance(t, jinja.BlockTag)
@@ -75,8 +73,8 @@ class MacroParser(BaseParser[ParsedMacro]):
# things have gone disastrously wrong, we thought we only
# parsed one block!
raise CompilationException(
f'Found multiple macros in {block.full_block}, expected 1',
node=base_node
f"Found multiple macros in {block.full_block}, expected 1",
node=base_node,
)
macro_name = macro_nodes[0].name
@@ -84,7 +82,7 @@ class MacroParser(BaseParser[ParsedMacro]):
if not macro_name.startswith(MACRO_PREFIX):
continue
name: str = macro_name.replace(MACRO_PREFIX, '')
name: str = macro_name.replace(MACRO_PREFIX, "")
node = self.parse_macro(block, base_node, name)
yield node

View File

@@ -3,7 +3,15 @@ from dataclasses import field
import os
import pickle
from typing import (
Dict, Optional, Mapping, Callable, Any, List, Type, Union, MutableMapping
Dict,
Optional,
Mapping,
Callable,
Any,
List,
Type,
Union,
MutableMapping,
)
import time
@@ -23,11 +31,13 @@ from dbt.config import Project, RuntimeConfig
from dbt.context.docs import generate_runtime_docs
from dbt.contracts.files import FilePath, FileHash
from dbt.contracts.graph.compiled import ManifestNode
from dbt.contracts.graph.manifest import (
Manifest, MacroManifest, AnyManifest, Disabled
)
from dbt.contracts.graph.manifest import Manifest, MacroManifest, AnyManifest, Disabled
from dbt.contracts.graph.parsed import (
ParsedSourceDefinition, ParsedNode, ParsedMacro, ColumnInfo, ParsedExposure
ParsedSourceDefinition,
ParsedNode,
ParsedMacro,
ColumnInfo,
ParsedExposure,
)
from dbt.contracts.util import Writable
from dbt.exceptions import (
@@ -55,8 +65,8 @@ from dbt.version import __version__
from dbt.dataclass_schema import dbtClassMixin
PARTIAL_PARSE_FILE_NAME = 'partial_parse.pickle'
PARSING_STATE = DbtProcessState('parsing')
PARTIAL_PARSE_FILE_NAME = "partial_parse.pickle"
PARSING_STATE = DbtProcessState("parsing")
DEFAULT_PARTIAL_PARSE = False
@@ -110,20 +120,22 @@ def make_parse_result(
"""Make a ParseResult from the project configuration and the profile."""
# if any of these change, we need to reject the parser
vars_hash = FileHash.from_contents(
'\x00'.join([
getattr(config.args, 'vars', '{}') or '{}',
getattr(config.args, 'profile', '') or '',
getattr(config.args, 'target', '') or '',
__version__
])
"\x00".join(
[
getattr(config.args, "vars", "{}") or "{}",
getattr(config.args, "profile", "") or "",
getattr(config.args, "target", "") or "",
__version__,
]
)
)
profile_path = os.path.join(config.args.profiles_dir, 'profiles.yml')
profile_path = os.path.join(config.args.profiles_dir, "profiles.yml")
with open(profile_path) as fp:
profile_hash = FileHash.from_contents(fp.read())
project_hashes = {}
for name, project in all_projects.items():
path = os.path.join(project.project_root, 'dbt_project.yml')
path = os.path.join(project.project_root, "dbt_project.yml")
with open(path) as fp:
project_hashes[name] = FileHash.from_contents(fp.read())
@@ -153,7 +165,8 @@ class ManifestLoader:
# in dictionaries: nodes, sources, docs, macros, exposures,
# macro_patches, patches, source_patches, files, etc
self.results: ParseResult = make_parse_result(
root_project, all_projects,
root_project,
all_projects,
)
self._loaded_file_cache: Dict[str, FileBlock] = {}
self._perf_info = ManifestLoaderInfo(
@@ -162,20 +175,18 @@ class ManifestLoader:
def track_project_load(self):
invocation_id = dbt.tracking.active_user.invocation_id
dbt.tracking.track_project_load({
"invocation_id": invocation_id,
"project_id": self.root_project.hashed_name(),
"path_count": self._perf_info.path_count,
"parse_project_elapsed": self._perf_info.parse_project_elapsed,
"patch_sources_elapsed": self._perf_info.patch_sources_elapsed,
"process_manifest_elapsed": (
self._perf_info.process_manifest_elapsed
),
"load_all_elapsed": self._perf_info.load_all_elapsed,
"is_partial_parse_enabled": (
self._perf_info.is_partial_parse_enabled
),
})
dbt.tracking.track_project_load(
{
"invocation_id": invocation_id,
"project_id": self.root_project.hashed_name(),
"path_count": self._perf_info.path_count,
"parse_project_elapsed": self._perf_info.parse_project_elapsed,
"patch_sources_elapsed": self._perf_info.patch_sources_elapsed,
"process_manifest_elapsed": (self._perf_info.process_manifest_elapsed),
"load_all_elapsed": self._perf_info.load_all_elapsed,
"is_partial_parse_enabled": (self._perf_info.is_partial_parse_enabled),
}
)
def parse_with_cache(
self,
@@ -220,8 +231,7 @@ class ManifestLoader:
) -> None:
parsers: List[Parser] = []
for cls in _parser_types:
parser = cls(self.results, project, self.root_project,
macro_manifest)
parser = cls(self.results, project, self.root_project, macro_manifest)
parsers.append(parser)
# per-project cache.
@@ -238,11 +248,13 @@ class ManifestLoader:
parser_path_count = parser_path_count + 1
if parser_path_count > 0:
project_parser_info.append(ParserInfo(
parser=parser.resource_type,
path_count=parser_path_count,
elapsed=time.perf_counter() - parser_start_timer
))
project_parser_info.append(
ParserInfo(
parser=parser.resource_type,
path_count=parser_path_count,
elapsed=time.perf_counter() - parser_start_timer,
)
)
total_path_count = total_path_count + parser_path_count
elapsed = time.perf_counter() - start_timer
@@ -250,12 +262,10 @@ class ManifestLoader:
project_name=project.project_name,
path_count=total_path_count,
elapsed=elapsed,
parsers=project_parser_info
parsers=project_parser_info,
)
self._perf_info.projects.append(project_info)
self._perf_info.path_count = (
self._perf_info.path_count + total_path_count
)
self._perf_info.path_count = self._perf_info.path_count + total_path_count
def load_only_macros(self) -> MacroManifest:
old_results = self.read_parse_results()
@@ -267,8 +277,7 @@ class ManifestLoader:
# make a manifest with just the macros to get the context
macro_manifest = MacroManifest(
macros=self.results.macros,
files=self.results.files
macros=self.results.macros, files=self.results.files
)
self.macro_hook(macro_manifest)
return macro_manifest
@@ -278,7 +287,7 @@ class ManifestLoader:
# if partial parse is enabled, load old results
old_results = self.read_parse_results()
if old_results is not None:
logger.debug('Got an acceptable cached parse result')
logger.debug("Got an acceptable cached parse result")
# store the macros & files from the adapter macro manifest
self.results.macros.update(macro_manifest.macros)
self.results.files.update(macro_manifest.files)
@@ -289,15 +298,12 @@ class ManifestLoader:
# parse a single project
self.parse_project(project, macro_manifest, old_results)
self._perf_info.parse_project_elapsed = (
time.perf_counter() - start_timer
)
self._perf_info.parse_project_elapsed = time.perf_counter() - start_timer
def write_parse_results(self):
path = os.path.join(self.root_project.target_path,
PARTIAL_PARSE_FILE_NAME)
path = os.path.join(self.root_project.target_path, PARTIAL_PARSE_FILE_NAME)
make_directory(self.root_project.target_path)
with open(path, 'wb') as fp:
with open(path, "wb") as fp:
pickle.dump(self.results, fp)
def matching_parse_results(self, result: ParseResult) -> bool:
@@ -307,31 +313,32 @@ class ManifestLoader:
try:
if result.dbt_version != __version__:
logger.debug(
'dbt version mismatch: {} != {}, cache invalidated'
.format(result.dbt_version, __version__)
"dbt version mismatch: {} != {}, cache invalidated".format(
result.dbt_version, __version__
)
)
return False
except AttributeError:
logger.debug('malformed result file, cache invalidated')
logger.debug("malformed result file, cache invalidated")
return False
valid = True
if self.results.vars_hash != result.vars_hash:
logger.debug('vars hash mismatch, cache invalidated')
logger.debug("vars hash mismatch, cache invalidated")
valid = False
if self.results.profile_hash != result.profile_hash:
logger.debug('profile hash mismatch, cache invalidated')
logger.debug("profile hash mismatch, cache invalidated")
valid = False
missing_keys = {
k for k in self.results.project_hashes
if k not in result.project_hashes
k for k in self.results.project_hashes if k not in result.project_hashes
}
if missing_keys:
logger.debug(
'project hash mismatch: values missing, cache invalidated: {}'
.format(missing_keys)
"project hash mismatch: values missing, cache invalidated: {}".format(
missing_keys
)
)
valid = False
@@ -340,9 +347,8 @@ class ManifestLoader:
old_value = result.project_hashes[key]
if new_value != old_value:
logger.debug(
'For key {}, hash mismatch ({} -> {}), cache '
'invalidated'
.format(key, old_value, new_value)
"For key {}, hash mismatch ({} -> {}), cache "
"invalidated".format(key, old_value, new_value)
)
valid = False
return valid
@@ -359,14 +365,13 @@ class ManifestLoader:
def read_parse_results(self) -> Optional[ParseResult]:
if not self._partial_parse_enabled():
logger.debug('Partial parsing not enabled')
logger.debug("Partial parsing not enabled")
return None
path = os.path.join(self.root_project.target_path,
PARTIAL_PARSE_FILE_NAME)
path = os.path.join(self.root_project.target_path, PARTIAL_PARSE_FILE_NAME)
if os.path.exists(path):
try:
with open(path, 'rb') as fp:
with open(path, "rb") as fp:
result: ParseResult = pickle.load(fp)
# keep this check inside the try/except in case something about
# the file has changed in weird ways, perhaps due to being a
@@ -375,9 +380,8 @@ class ManifestLoader:
return result
except Exception as exc:
logger.debug(
'Failed to load parsed file from disk at {}: {}'
.format(path, exc),
exc_info=True
"Failed to load parsed file from disk at {}: {}".format(path, exc),
exc_info=True,
)
return None
@@ -394,9 +398,7 @@ class ManifestLoader:
# list is created
start_patch = time.perf_counter()
sources = patch_sources(self.results, self.root_project)
self._perf_info.patch_sources_elapsed = (
time.perf_counter() - start_patch
)
self._perf_info.patch_sources_elapsed = time.perf_counter() - start_patch
disabled = []
for value in self.results.disabled.values():
disabled.extend(value)
@@ -421,9 +423,7 @@ class ManifestLoader:
start_process = time.perf_counter()
self.process_manifest(manifest)
self._perf_info.process_manifest_elapsed = (
time.perf_counter() - start_process
)
self._perf_info.process_manifest_elapsed = time.perf_counter() - start_process
return manifest
@@ -445,9 +445,7 @@ class ManifestLoader:
_check_manifest(manifest, root_config)
manifest.build_flat_graph()
loader._perf_info.load_all_elapsed = (
time.perf_counter() - start_load_all
)
loader._perf_info.load_all_elapsed = time.perf_counter() - start_load_all
loader.track_project_load()
@@ -465,8 +463,9 @@ class ManifestLoader:
return loader.load_only_macros()
def invalid_ref_fail_unless_test(node, target_model_name,
target_model_package, disabled):
def invalid_ref_fail_unless_test(
node, target_model_name, target_model_package, disabled
):
if node.resource_type == NodeType.Test:
msg = get_target_not_found_or_disabled_msg(
@@ -475,10 +474,7 @@ def invalid_ref_fail_unless_test(node, target_model_name,
if disabled:
logger.debug(warning_tag(msg))
else:
warn_or_error(
msg,
log_fmt=warning_tag('{}')
)
warn_or_error(msg, log_fmt=warning_tag("{}"))
else:
ref_target_not_found(
node,
@@ -488,9 +484,7 @@ def invalid_ref_fail_unless_test(node, target_model_name,
)
def invalid_source_fail_unless_test(
node, target_name, target_table_name, disabled
):
def invalid_source_fail_unless_test(node, target_name, target_table_name, disabled):
if node.resource_type == NodeType.Test:
msg = get_source_not_found_or_disabled_msg(
node, target_name, target_table_name, disabled
@@ -498,17 +492,9 @@ def invalid_source_fail_unless_test(
if disabled:
logger.debug(warning_tag(msg))
else:
warn_or_error(
msg,
log_fmt=warning_tag('{}')
)
warn_or_error(msg, log_fmt=warning_tag("{}"))
else:
source_target_not_found(
node,
target_name,
target_table_name,
disabled=disabled
)
source_target_not_found(node, target_name, target_table_name, disabled=disabled)
def _check_resource_uniqueness(
@@ -532,15 +518,11 @@ def _check_resource_uniqueness(
existing_node = names_resources.get(name)
if existing_node is not None:
dbt.exceptions.raise_duplicate_resource_name(
existing_node, node
)
dbt.exceptions.raise_duplicate_resource_name(existing_node, node)
existing_alias = alias_resources.get(full_node_name)
if existing_alias is not None:
dbt.exceptions.raise_ambiguous_alias(
existing_alias, node, full_node_name
)
dbt.exceptions.raise_ambiguous_alias(existing_alias, node, full_node_name)
names_resources[name] = node
alias_resources[full_node_name] = node
@@ -565,8 +547,7 @@ def _load_projects(config, paths):
project = config.new_project(path)
except dbt.exceptions.DbtProjectError as e:
raise dbt.exceptions.DbtProjectError(
'Failed to read package at {}: {}'
.format(path, e)
"Failed to read package at {}: {}".format(path, e)
)
else:
yield project.project_name, project
@@ -587,8 +568,7 @@ def _get_node_column(node, column_name):
DocsContextCallback = Callable[
[Union[ParsedNode, ParsedSourceDefinition]],
Dict[str, Any]
[Union[ParsedNode, ParsedSourceDefinition]], Dict[str, Any]
]
@@ -618,9 +598,7 @@ def _process_docs_for_source(
column.description = column_desc
def _process_docs_for_macro(
context: Dict[str, Any], macro: ParsedMacro
) -> None:
def _process_docs_for_macro(context: Dict[str, Any], macro: ParsedMacro) -> None:
macro.description = get_rendered(macro.description, context)
for arg in macro.arguments:
arg.description = get_rendered(arg.description, context)
@@ -682,7 +660,7 @@ def _process_refs_for_exposure(
target_model_package, target_model_name = ref
else:
raise dbt.exceptions.InternalException(
f'Refs should always be 1 or 2 arguments - got {len(ref)}'
f"Refs should always be 1 or 2 arguments - got {len(ref)}"
)
target_model = manifest.resolve_ref(
@@ -696,8 +674,10 @@ def _process_refs_for_exposure(
# This may raise. Even if it doesn't, we don't want to add
# this exposure to the graph b/c there is no destination exposure
invalid_ref_fail_unless_test(
exposure, target_model_name, target_model_package,
disabled=(isinstance(target_model, Disabled))
exposure,
target_model_name,
target_model_package,
disabled=(isinstance(target_model, Disabled)),
)
continue
@@ -723,7 +703,7 @@ def _process_refs_for_node(
target_model_package, target_model_name = ref
else:
raise dbt.exceptions.InternalException(
f'Refs should always be 1 or 2 arguments - got {len(ref)}'
f"Refs should always be 1 or 2 arguments - got {len(ref)}"
)
target_model = manifest.resolve_ref(
@@ -738,8 +718,10 @@ def _process_refs_for_node(
# this node to the graph b/c there is no destination node
node.config.enabled = False
invalid_ref_fail_unless_test(
node, target_model_name, target_model_package,
disabled=(isinstance(target_model, Disabled))
node,
target_model_name,
target_model_package,
disabled=(isinstance(target_model, Disabled)),
)
continue
@@ -777,7 +759,7 @@ def _process_sources_for_exposure(
exposure,
source_name,
table_name,
disabled=(isinstance(target_source, Disabled))
disabled=(isinstance(target_source, Disabled)),
)
continue
target_source_id = target_source.unique_id
@@ -804,7 +786,7 @@ def _process_sources_for_node(
node,
source_name,
table_name,
disabled=(isinstance(target_source, Disabled))
disabled=(isinstance(target_source, Disabled)),
)
continue
target_source_id = target_source.unique_id
@@ -835,13 +817,9 @@ def process_macro(
_process_docs_for_macro(ctx, macro)
def process_node(
config: RuntimeConfig, manifest: Manifest, node: ManifestNode
):
def process_node(config: RuntimeConfig, manifest: Manifest, node: ManifestNode):
_process_sources_for_node(
manifest, config.project_name, node
)
_process_sources_for_node(manifest, config.project_name, node)
_process_refs_for_node(manifest, config.project_name, node)
ctx = generate_runtime_docs(config, node, manifest, config.project_name)
_process_docs_for_node(ctx, node)
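The ManifestLoader above persists its ParseResult as a pickle in the target directory and treats any dbt version or config-hash mismatch as a cache miss. A simplified, self-contained sketch of that round trip; only the file name matches dbt, while the helper names and the getattr-based version check are illustrative:

import os
import pickle

PARTIAL_PARSE_FILE_NAME = "partial_parse.pickle"


def write_cache(target_path: str, result) -> None:
    os.makedirs(target_path, exist_ok=True)
    with open(os.path.join(target_path, PARTIAL_PARSE_FILE_NAME), "wb") as fp:
        pickle.dump(result, fp)


def read_cache(target_path: str, current_version: str):
    path = os.path.join(target_path, PARTIAL_PARSE_FILE_NAME)
    if not os.path.exists(path):
        return None
    try:
        with open(path, "rb") as fp:
            result = pickle.load(fp)
    except Exception:
        return None  # an unreadable cache is simply a miss
    if getattr(result, "dbt_version", None) != current_version:
        return None  # version mismatch invalidates the cache
    return result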

View File

@@ -6,9 +6,7 @@ from dbt.parser.search import FilesystemSearcher, FileBlock
class ModelParser(SimpleSQLParser[ParsedModelNode]):
def get_paths(self):
return FilesystemSearcher(
self.project, self.project.source_paths, '.sql'
)
return FilesystemSearcher(self.project, self.project.source_paths, ".sql")
def parse_from_dict(self, dct, validate=True) -> ParsedModelNode:
if validate:

View File

@@ -25,9 +25,13 @@ from dbt.contracts.graph.parsed import (
from dbt.contracts.graph.unparsed import SourcePatch
from dbt.contracts.util import Writable, Replaceable, MacroKey, SourceKey
from dbt.exceptions import (
raise_duplicate_resource_name, raise_duplicate_patch_name,
raise_duplicate_macro_patch_name, CompilationException, InternalException,
raise_compiler_error, raise_duplicate_source_patch_name
raise_duplicate_resource_name,
raise_duplicate_patch_name,
raise_duplicate_macro_patch_name,
CompilationException,
InternalException,
raise_compiler_error,
raise_duplicate_source_patch_name,
)
from dbt.node_types import NodeType
from dbt.ui import line_wrap_message
@@ -35,12 +39,10 @@ from dbt.version import __version__
# Parsers can return anything as long as it's a unique ID
ParsedValueType = TypeVar('ParsedValueType', bound=HasUniqueID)
ParsedValueType = TypeVar("ParsedValueType", bound=HasUniqueID)
def _check_duplicates(
value: HasUniqueID, src: Mapping[str, HasUniqueID]
):
def _check_duplicates(value: HasUniqueID, src: Mapping[str, HasUniqueID]):
if value.unique_id in src:
raise_duplicate_resource_name(value, src[value.unique_id])
@@ -86,9 +88,7 @@ class ParseResult(dbtClassMixin, Writable, Replaceable):
self.files[key] = source_file
return self.files[key]
def add_source(
self, source_file: SourceFile, source: UnpatchedSourceDefinition
):
def add_source(self, source_file: SourceFile, source: UnpatchedSourceDefinition):
# sources can't be overwritten!
_check_duplicates(source, self.sources)
self.sources[source.unique_id] = source
@@ -126,7 +126,7 @@ class ParseResult(dbtClassMixin, Writable, Replaceable):
# note that the line wrap eats newlines, so if you want newlines,
# this is the result :(
msg = line_wrap_message(
f'''\
f"""\
dbt found two macros named "{macro.name}" in the project
"{macro.package_name}".
@@ -137,8 +137,8 @@ class ParseResult(dbtClassMixin, Writable, Replaceable):
- {macro.original_file_path}
- {other_path}
''',
subtract=2
""",
subtract=2,
)
raise_compiler_error(msg)
@@ -150,18 +150,14 @@ class ParseResult(dbtClassMixin, Writable, Replaceable):
self.docs[doc.unique_id] = doc
self.get_file(source_file).docs.append(doc.unique_id)
def add_patch(
self, source_file: SourceFile, patch: ParsedNodePatch
) -> None:
def add_patch(self, source_file: SourceFile, patch: ParsedNodePatch) -> None:
# patches can't be overwritten
if patch.name in self.patches:
raise_duplicate_patch_name(patch, self.patches[patch.name])
self.patches[patch.name] = patch
self.get_file(source_file).patches.append(patch.name)
def add_macro_patch(
self, source_file: SourceFile, patch: ParsedMacroPatch
) -> None:
def add_macro_patch(self, source_file: SourceFile, patch: ParsedMacroPatch) -> None:
# macros are fully namespaced
key = (patch.package_name, patch.name)
if key in self.macro_patches:
@@ -169,9 +165,7 @@ class ParseResult(dbtClassMixin, Writable, Replaceable):
self.macro_patches[key] = patch
self.get_file(source_file).macro_patches.append(key)
def add_source_patch(
self, source_file: SourceFile, patch: SourcePatch
) -> None:
def add_source_patch(self, source_file: SourceFile, patch: SourcePatch) -> None:
# source patches must be unique
key = (patch.overrides, patch.name)
if key in self.source_patches:
@@ -186,11 +180,13 @@ class ParseResult(dbtClassMixin, Writable, Replaceable):
) -> List[CompileResultNode]:
if unique_id not in self.disabled:
raise InternalException(
'called _get_disabled with id={}, but it does not exist'
.format(unique_id)
"called _get_disabled with id={}, but it does not exist".format(
unique_id
)
)
return [
n for n in self.disabled[unique_id]
n
for n in self.disabled[unique_id]
if n.original_file_path == match_file.path.original_file_path
]
@@ -199,7 +195,7 @@ class ParseResult(dbtClassMixin, Writable, Replaceable):
node_id: str,
source_file: SourceFile,
old_file: SourceFile,
old_result: 'ParseResult',
old_result: "ParseResult",
) -> None:
"""Nodes are a special kind of complicated - there can be multiple
with the same name, as long as all but one are disabled.
@@ -224,14 +220,15 @@ class ParseResult(dbtClassMixin, Writable, Replaceable):
if not found:
raise CompilationException(
'Expected to find "{}" in cached "manifest.nodes" or '
'"manifest.disabled" based on cached file information: {}!'
.format(node_id, old_file)
'"manifest.disabled" based on cached file information: {}!'.format(
node_id, old_file
)
)
def sanitized_update(
self,
source_file: SourceFile,
old_result: 'ParseResult',
old_result: "ParseResult",
resource_type: NodeType,
) -> bool:
"""Perform a santized update. If the file can't be updated, invalidate
@@ -246,15 +243,11 @@ class ParseResult(dbtClassMixin, Writable, Replaceable):
self.add_doc(source_file, doc)
for macro_id in old_file.macros:
macro = _expect_value(
macro_id, old_result.macros, old_file, "macros"
)
macro = _expect_value(macro_id, old_result.macros, old_file, "macros")
self.add_macro(source_file, macro)
for source_id in old_file.sources:
source = _expect_value(
source_id, old_result.sources, old_file, "sources"
)
source = _expect_value(source_id, old_result.sources, old_file, "sources")
self.add_source(source_file, source)
# because we know this is how we _parsed_ the node, we can safely
@@ -265,7 +258,7 @@ class ParseResult(dbtClassMixin, Writable, Replaceable):
for node_id in old_file.nodes:
# cheat: look at the first part of the node ID and compare it to
# the parser resource type. On a mismatch, bail out.
if resource_type != node_id.split('.')[0]:
if resource_type != node_id.split(".")[0]:
continue
self._process_node(node_id, source_file, old_file, old_result)
@@ -277,9 +270,7 @@ class ParseResult(dbtClassMixin, Writable, Replaceable):
patched = False
for name in old_file.patches:
patch = _expect_value(
name, old_result.patches, old_file, "patches"
)
patch = _expect_value(name, old_result.patches, old_file, "patches")
self.add_patch(source_file, patch)
patched = True
if patched:
@@ -312,8 +303,8 @@ class ParseResult(dbtClassMixin, Writable, Replaceable):
return cls(FileHash.empty(), FileHash.empty(), {})
K_T = TypeVar('K_T')
V_T = TypeVar('V_T')
K_T = TypeVar("K_T")
V_T = TypeVar("V_T")
def _expect_value(
@@ -322,7 +313,6 @@ def _expect_value(
if key not in src:
raise CompilationException(
'Expected to find "{}" in cached "result.{}" based '
'on cached file information: {}!'
.format(key, name, old_file)
"on cached file information: {}!".format(key, name, old_file)
)
return src[key]

View File

@@ -38,11 +38,11 @@ class RPCCallParser(SimpleSQLParser[ParsedRPCNode]):
# we do it this way to make mypy happy
if not isinstance(block, RPCBlock):
raise InternalException(
'While parsing RPC calls, got an actual file block instead of '
'an RPC block: {}'.format(block)
"While parsing RPC calls, got an actual file block instead of "
"an RPC block: {}".format(block)
)
return os.path.join('rpc', block.name)
return os.path.join("rpc", block.name)
def parse_remote(self, sql: str, name: str) -> ParsedRPCNode:
source_file = SourceFile.remote(contents=sql)
@@ -53,8 +53,8 @@ class RPCCallParser(SimpleSQLParser[ParsedRPCNode]):
class RPCMacroParser(MacroParser):
def parse_remote(self, contents) -> Iterable[ParsedMacro]:
base = UnparsedMacro(
path='from remote system',
original_file_path='from remote system',
path="from remote system",
original_file_path="from remote system",
package_name=self.project.project_name,
raw_sql=contents,
root_path=self.project.project_root,

View File

@@ -3,7 +3,13 @@ import re
from copy import deepcopy
from dataclasses import dataclass
from typing import (
Generic, TypeVar, Dict, Any, Tuple, Optional, List,
Generic,
TypeVar,
Dict,
Any,
Tuple,
Optional,
List,
)
from dbt.clients.jinja import get_rendered, SCHEMA_TEST_KWARGS_NAME
@@ -25,7 +31,7 @@ def get_nice_schema_test_name(
flat_args = []
for arg_name in sorted(args):
# the model is already embedded in the name, so skip it
if arg_name == 'model':
if arg_name == "model":
continue
arg_val = args[arg_name]
@@ -38,17 +44,17 @@ def get_nice_schema_test_name(
flat_args.extend([str(part) for part in parts])
clean_flat_args = [re.sub('[^0-9a-zA-Z_]+', '_', arg) for arg in flat_args]
clean_flat_args = [re.sub("[^0-9a-zA-Z_]+", "_", arg) for arg in flat_args]
unique = "__".join(clean_flat_args)
cutoff = 32
if len(unique) <= cutoff:
label = unique
else:
label = hashlib.md5(unique.encode('utf-8')).hexdigest()
label = hashlib.md5(unique.encode("utf-8")).hexdigest()
filename = '{}_{}_{}'.format(test_type, test_name, label)
name = '{}_{}_{}'.format(test_type, test_name, unique)
filename = "{}_{}_{}".format(test_type, test_name, label)
name = "{}_{}_{}".format(test_type, test_name, unique)
return filename, name
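For reference, the function above returns a (filename, name) pair: the flattened test arguments are kept verbatim in the name, while the filename falls back to an md5 digest once that suffix exceeds 32 characters. A usage example with illustrative argument values:

filename, name = get_nice_schema_test_name(
    "not_null", "customers", {"column_name": "id", "model": "ref('customers')"}
)
# both are "not_null_customers_id" here: the "model" argument is skipped and the
# remaining suffix "id" is well under 32 characters, so no md5 fallback is needed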
@@ -65,19 +71,17 @@ class YamlBlock(FileBlock):
)
Testable = TypeVar(
'Testable', UnparsedNodeUpdate, UnpatchedSourceDefinition
)
Testable = TypeVar("Testable", UnparsedNodeUpdate, UnpatchedSourceDefinition)
ColumnTarget = TypeVar(
'ColumnTarget',
"ColumnTarget",
UnparsedNodeUpdate,
UnparsedAnalysisUpdate,
UnpatchedSourceDefinition,
)
Target = TypeVar(
'Target',
"Target",
UnparsedNodeUpdate,
UnparsedMacroUpdate,
UnparsedAnalysisUpdate,
@@ -103,9 +107,7 @@ class TargetBlock(YamlBlock, Generic[Target]):
return []
@classmethod
def from_yaml_block(
cls, src: YamlBlock, target: Target
) -> 'TargetBlock[Target]':
def from_yaml_block(cls, src: YamlBlock, target: Target) -> "TargetBlock[Target]":
return cls(
file=src.file,
data=src.data,
@@ -137,9 +139,7 @@ class TestBlock(TargetColumnsBlock[Testable], Generic[Testable]):
return self.target.quote_columns
@classmethod
def from_yaml_block(
cls, src: YamlBlock, target: Testable
) -> 'TestBlock[Testable]':
def from_yaml_block(cls, src: YamlBlock, target: Testable) -> "TestBlock[Testable]":
return cls(
file=src.file,
data=src.data,
@@ -160,7 +160,7 @@ class SchemaTestBlock(TestBlock[Testable], Generic[Testable]):
test: Dict[str, Any],
column_name: Optional[str],
tags: List[str],
) -> 'SchemaTestBlock':
) -> "SchemaTestBlock":
return cls(
file=src.file,
data=src.data,
@@ -179,13 +179,14 @@ class TestBuilder(Generic[Testable]):
- or it may not be namespaced (test)
"""
# The 'test_name' is used to find the 'macro' that implements the test
TEST_NAME_PATTERN = re.compile(
r'((?P<test_namespace>([a-zA-Z_][0-9a-zA-Z_]*))\.)?'
r'(?P<test_name>([a-zA-Z_][0-9a-zA-Z_]*))'
r"((?P<test_namespace>([a-zA-Z_][0-9a-zA-Z_]*))\.)?"
r"(?P<test_name>([a-zA-Z_][0-9a-zA-Z_]*))"
)
# map magic keys to default values
MODIFIER_ARGS = {'severity': 'ERROR', 'tags': []}
MODIFIER_ARGS = {"severity": "ERROR", "tags": []}
def __init__(
self,
@@ -197,25 +198,24 @@ class TestBuilder(Generic[Testable]):
) -> None:
test_name, test_args = self.extract_test_args(test, column_name)
self.args: Dict[str, Any] = test_args
if 'model' in self.args:
if "model" in self.args:
raise_compiler_error(
'Test arguments include "model", which is a reserved argument',
)
self.package_name: str = package_name
self.target: Testable = target
self.args['model'] = self.build_model_str()
self.args["model"] = self.build_model_str()
match = self.TEST_NAME_PATTERN.match(test_name)
if match is None:
raise_compiler_error(
'Test name string did not match expected pattern: {}'
.format(test_name)
"Test name string did not match expected pattern: {}".format(test_name)
)
groups = match.groupdict()
self.name: str = groups['test_name']
self.namespace: str = groups['test_namespace']
self.name: str = groups["test_name"]
self.namespace: str = groups["test_namespace"]
self.modifiers: Dict[str, Any] = {}
for key, default in self.MODIFIER_ARGS.items():
value = self.args.pop(key, default)
@@ -237,57 +237,52 @@ class TestBuilder(Generic[Testable]):
def extract_test_args(test, name=None) -> Tuple[str, Dict[str, Any]]:
if not isinstance(test, dict):
raise_compiler_error(
'test must be dict or str, got {} (value {})'.format(
type(test), test
)
"test must be dict or str, got {} (value {})".format(type(test), test)
)
test = list(test.items())
if len(test) != 1:
raise_compiler_error(
'test definition dictionary must have exactly one key, got'
' {} instead ({} keys)'.format(test, len(test))
"test definition dictionary must have exactly one key, got"
" {} instead ({} keys)".format(test, len(test))
)
test_name, test_args = test[0]
if not isinstance(test_args, dict):
raise_compiler_error(
'test arguments must be dict, got {} (value {})'.format(
"test arguments must be dict, got {} (value {})".format(
type(test_args), test_args
)
)
if not isinstance(test_name, str):
raise_compiler_error(
'test name must be a str, got {} (value {})'.format(
"test name must be a str, got {} (value {})".format(
type(test_name), test_name
)
)
test_args = deepcopy(test_args)
if name is not None:
test_args['column_name'] = name
test_args["column_name"] = name
return test_name, test_args
def severity(self) -> str:
return self.modifiers.get('severity', 'ERROR').upper()
return self.modifiers.get("severity", "ERROR").upper()
def tags(self) -> List[str]:
tags = self.modifiers.get('tags', [])
tags = self.modifiers.get("tags", [])
if isinstance(tags, str):
tags = [tags]
if not isinstance(tags, list):
raise_compiler_error(
f'got {tags} ({type(tags)}) for tags, expected a list of '
f'strings'
f"got {tags} ({type(tags)}) for tags, expected a list of " f"strings"
)
for tag in tags:
if not isinstance(tag, str):
raise_compiler_error(
f'got {tag} ({type(tag)}) for tag, expected a str'
)
raise_compiler_error(f"got {tag} ({type(tag)}) for tag, expected a str")
return tags[:]
def macro_name(self) -> str:
macro_name = 'test_{}'.format(self.name)
macro_name = "test_{}".format(self.name)
if self.namespace is not None:
macro_name = "{}.{}".format(self.namespace, macro_name)
return macro_name
@@ -296,11 +291,11 @@ class TestBuilder(Generic[Testable]):
if isinstance(self.target, UnparsedNodeUpdate):
name = self.name
elif isinstance(self.target, UnpatchedSourceDefinition):
name = 'source_' + self.name
name = "source_" + self.name
else:
raise self._bad_type()
if self.namespace is not None:
name = '{}_{}'.format(self.namespace, name)
name = "{}_{}".format(self.namespace, name)
return get_nice_schema_test_name(name, self.target.name, self.args)
# this is the 'raw_sql' that's used in 'render_update' and execution

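The TEST_NAME_PATTERN regex in TestBuilder above splits an optionally namespaced test name into its package and test parts. A small sketch of what the match yields (the test names are examples only):

import re

TEST_NAME_PATTERN = re.compile(
    r"((?P<test_namespace>([a-zA-Z_][0-9a-zA-Z_]*))\.)?"
    r"(?P<test_name>([a-zA-Z_][0-9a-zA-Z_]*))"
)

assert TEST_NAME_PATTERN.match("unique").groupdict() == {
    "test_namespace": None,
    "test_name": "unique",
}
assert TEST_NAME_PATTERN.match("dbt_utils.equal_rowcount").groupdict() == {
    "test_namespace": "dbt_utils",
    "test_name": "equal_rowcount",
}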
View File

@@ -2,9 +2,7 @@ import itertools
import os
from abc import ABCMeta, abstractmethod
from typing import (
Iterable, Dict, Any, Union, List, Optional, Generic, TypeVar, Type
)
from typing import Iterable, Dict, Any, Union, List, Optional, Generic, TypeVar, Type
from dbt.dataclass_schema import ValidationError, dbtClassMixin
@@ -20,9 +18,7 @@ from dbt.context.context_config import (
)
from dbt.context.configured import generate_schema_yml
from dbt.context.target import generate_target_context
from dbt.context.providers import (
generate_parse_exposure, generate_test_context
)
from dbt.context.providers import generate_parse_exposure, generate_test_context
from dbt.context.macro_resolver import MacroResolver
from dbt.contracts.files import FileHash
from dbt.contracts.graph.manifest import SourceFile
@@ -50,20 +46,26 @@ from dbt.contracts.graph.unparsed import (
UnparsedSourceDefinition,
)
from dbt.exceptions import (
validator_error_message, JSONValidationException,
raise_invalid_schema_yml_version, ValidationException,
CompilationException, warn_or_error, InternalException
validator_error_message,
JSONValidationException,
raise_invalid_schema_yml_version,
ValidationException,
CompilationException,
warn_or_error,
InternalException,
)
from dbt.node_types import NodeType
from dbt.parser.base import SimpleParser
from dbt.parser.search import FileBlock, FilesystemSearcher
from dbt.parser.schema_test_builders import (
TestBuilder, SchemaTestBlock, TargetBlock, YamlBlock,
TestBlock, Testable
)
from dbt.utils import (
get_pseudo_test_path, coerce_dict_str
TestBuilder,
SchemaTestBlock,
TargetBlock,
YamlBlock,
TestBlock,
Testable,
)
from dbt.utils import get_pseudo_test_path, coerce_dict_str
UnparsedSchemaYaml = Union[
@@ -80,19 +82,17 @@ def error_context(
path: str,
key: str,
data: Any,
cause: Union[str, ValidationException, JSONValidationException]
cause: Union[str, ValidationException, JSONValidationException],
) -> str:
"""Provide contextual information about an error while parsing
"""
"""Provide contextual information about an error while parsing"""
if isinstance(cause, str):
reason = cause
elif isinstance(cause, ValidationError):
reason = validator_error_message(cause)
else:
reason = cause.msg
return (
'Invalid {key} config given in {path} @ {key}: {data} - {reason}'
.format(key=key, path=path, data=data, reason=reason)
return "Invalid {key} config given in {path} @ {key}: {data} - {reason}".format(
key=key, path=path, data=data, reason=reason
)
@@ -110,7 +110,7 @@ class ParserRef:
meta: Dict[str, Any],
):
tags: List[str] = []
tags.extend(getattr(column, 'tags', ()))
tags.extend(getattr(column, "tags", ()))
quote: Optional[bool]
if isinstance(column, UnparsedColumn):
quote = column.quote
@@ -123,13 +123,11 @@ class ParserRef:
meta=meta,
tags=tags,
quote=quote,
_extra=column.extra
_extra=column.extra,
)
@classmethod
def from_target(
cls, target: Union[HasColumnDocs, HasColumnTests]
) -> 'ParserRef':
def from_target(cls, target: Union[HasColumnDocs, HasColumnTests]) -> "ParserRef":
refs = cls()
for column in target.columns:
description = column.description
@@ -142,7 +140,7 @@ class ParserRef:
def _trimmed(inp: str) -> str:
if len(inp) < 50:
return inp
return inp[:44] + '...' + inp[-3:]
return inp[:44] + "..." + inp[-3:]
def merge_freshness(
@@ -158,21 +156,20 @@ def merge_freshness(
class SchemaParser(SimpleParser[SchemaTestBlock, ParsedSchemaTestNode]):
def __init__(
self, results, project, root_project, macro_manifest,
self,
results,
project,
root_project,
macro_manifest,
) -> None:
super().__init__(results, project, root_project, macro_manifest)
all_v_2 = (
self.root_project.config_version == 2 and
self.project.config_version == 2
self.root_project.config_version == 2 and self.project.config_version == 2
)
if all_v_2:
ctx = generate_schema_yml(
self.root_project, self.project.project_name
)
ctx = generate_schema_yml(self.root_project, self.project.project_name)
else:
ctx = generate_target_context(
self.root_project, self.root_project.cli_vars
)
ctx = generate_target_context(self.root_project, self.root_project.cli_vars)
self.raw_renderer = SchemaYamlRenderer(ctx)
@@ -182,7 +179,7 @@ class SchemaParser(SimpleParser[SchemaTestBlock, ParsedSchemaTestNode]):
self.macro_resolver = MacroResolver(
self.macro_manifest.macros,
self.root_project.project_name,
internal_package_names
internal_package_names,
)
@classmethod
@@ -197,65 +194,55 @@ class SchemaParser(SimpleParser[SchemaTestBlock, ParsedSchemaTestNode]):
def get_paths(self):
# TODO: In order to support this, make FilesystemSearcher accept a list
# of file patterns. eg: ['.yml', '.yaml']
yaml_files = list(FilesystemSearcher(
self.project, self.project.all_source_paths, '.yaml'
))
yaml_files = list(
FilesystemSearcher(self.project, self.project.all_source_paths, ".yaml")
)
if yaml_files:
warn_or_error(
'A future version of dbt will parse files with both'
' .yml and .yaml file extensions. dbt found'
f' {len(yaml_files)} files with .yaml extensions in'
' your dbt project. To avoid errors when upgrading'
' to a future release, either remove these files from'
' your dbt project, or change their extensions.'
"A future version of dbt will parse files with both"
" .yml and .yaml file extensions. dbt found"
f" {len(yaml_files)} files with .yaml extensions in"
" your dbt project. To avoid errors when upgrading"
" to a future release, either remove these files from"
" your dbt project, or change their extensions."
)
return FilesystemSearcher(
self.project, self.project.all_source_paths, '.yml'
)
return FilesystemSearcher(self.project, self.project.all_source_paths, ".yml")
def parse_from_dict(self, dct, validate=True) -> ParsedSchemaTestNode:
if validate:
ParsedSchemaTestNode.validate(dct)
return ParsedSchemaTestNode.from_dict(dct)
def _check_format_version(
self, yaml: YamlBlock
) -> None:
def _check_format_version(self, yaml: YamlBlock) -> None:
path = yaml.path.relative_path
if 'version' not in yaml.data:
raise_invalid_schema_yml_version(path, 'no version is specified')
if "version" not in yaml.data:
raise_invalid_schema_yml_version(path, "no version is specified")
version = yaml.data['version']
version = yaml.data["version"]
# if it's not an integer, the version is malformed, or not
# set. Either way, only 'version: 2' is supported.
if not isinstance(version, int):
raise_invalid_schema_yml_version(
path, 'the version is not an integer'
)
raise_invalid_schema_yml_version(path, "the version is not an integer")
if version != 2:
raise_invalid_schema_yml_version(
path, 'version {} is not supported'.format(version)
path, "version {} is not supported".format(version)
)
def _yaml_from_file(
self, source_file: SourceFile
) -> Optional[Dict[str, Any]]:
"""If loading the yaml fails, raise an exception.
"""
def _yaml_from_file(self, source_file: SourceFile) -> Optional[Dict[str, Any]]:
"""If loading the yaml fails, raise an exception."""
path: str = source_file.path.relative_path
try:
return load_yaml_text(source_file.contents)
except ValidationException as e:
reason = validator_error_message(e)
raise CompilationException(
'Error reading {}: {} - {}'
.format(self.project.project_name, path, reason)
"Error reading {}: {} - {}".format(
self.project.project_name, path, reason
)
)
return None
def parse_column_tests(
self, block: TestBlock, column: UnparsedColumn
) -> None:
def parse_column_tests(self, block: TestBlock, column: UnparsedColumn) -> None:
if not column.tests:
return
@@ -267,9 +254,7 @@ class SchemaParser(SimpleParser[SchemaTestBlock, ParsedSchemaTestNode]):
if rendered:
generator = ContextConfigGenerator(self.root_project)
else:
generator = UnrenderedConfigGenerator(
self.root_project
)
generator = UnrenderedConfigGenerator(self.root_project)
return generator.calculate_node_config(
config_calls=[],
@@ -284,16 +269,14 @@ class SchemaParser(SimpleParser[SchemaTestBlock, ParsedSchemaTestNode]):
relation_cls = adapter.Relation
return str(relation_cls.create_from(self.root_project, node))
def parse_source(
self, target: UnpatchedSourceDefinition
) -> ParsedSourceDefinition:
def parse_source(self, target: UnpatchedSourceDefinition) -> ParsedSourceDefinition:
source = target.source
table = target.table
refs = ParserRef.from_target(table)
unique_id = target.unique_id
description = table.description or ''
description = table.description or ""
meta = table.meta or {}
source_description = source.description or ''
source_description = source.description or ""
loaded_at_field = table.loaded_at_field or source.loaded_at_field
freshness = merge_freshness(source.freshness, table.freshness)
@@ -316,8 +299,8 @@ class SchemaParser(SimpleParser[SchemaTestBlock, ParsedSchemaTestNode]):
if not isinstance(config, SourceConfig):
raise InternalException(
f'Calculated a {type(config)} for a source, but expected '
f'a SourceConfig'
f"Calculated a {type(config)} for a source, but expected "
f"a SourceConfig"
)
default_database = self.root_project.credentials.database
@@ -369,23 +352,23 @@ class SchemaParser(SimpleParser[SchemaTestBlock, ParsedSchemaTestNode]):
) -> ParsedSchemaTestNode:
dct = {
'alias': name,
'schema': self.default_schema,
'database': self.default_database,
'fqn': fqn,
'name': name,
'root_path': self.project.project_root,
'resource_type': self.resource_type,
'tags': tags,
'path': path,
'original_file_path': target.original_file_path,
'package_name': self.project.project_name,
'raw_sql': raw_sql,
'unique_id': self.generate_unique_id(name),
'config': self.config_dict(config),
'test_metadata': test_metadata,
'column_name': column_name,
'checksum': FileHash.empty().to_dict(omit_none=True),
"alias": name,
"schema": self.default_schema,
"database": self.default_database,
"fqn": fqn,
"name": name,
"root_path": self.project.project_root,
"resource_type": self.resource_type,
"tags": tags,
"path": path,
"original_file_path": target.original_file_path,
"package_name": self.project.project_name,
"raw_sql": raw_sql,
"unique_id": self.generate_unique_id(name),
"config": self.config_dict(config),
"test_metadata": test_metadata,
"column_name": column_name,
"checksum": FileHash.empty().to_dict(omit_none=True),
}
try:
ParsedSchemaTestNode.validate(dct)
@@ -424,18 +407,20 @@ class SchemaParser(SimpleParser[SchemaTestBlock, ParsedSchemaTestNode]):
)
except CompilationException as exc:
context = _trimmed(str(target))
msg = (
'Invalid test config given in {}:'
'\n\t{}\n\t@: {}'
.format(target.original_file_path, exc.msg, context)
msg = "Invalid test config given in {}:" "\n\t{}\n\t@: {}".format(
target.original_file_path, exc.msg, context
)
raise CompilationException(msg) from exc
original_name = os.path.basename(target.original_file_path)
compiled_path = get_pseudo_test_path(
builder.compiled_name, original_name, 'schema_test',
builder.compiled_name,
original_name,
"schema_test",
)
fqn_path = get_pseudo_test_path(
builder.fqn_name, original_name, 'schema_test',
builder.fqn_name,
original_name,
"schema_test",
)
# the fqn for tests actually happens in the test target's name, which
# is not necessarily this package's name
@@ -445,13 +430,13 @@ class SchemaParser(SimpleParser[SchemaTestBlock, ParsedSchemaTestNode]):
config = self.initial_config(fqn)
metadata = {
'namespace': builder.namespace,
'name': builder.name,
'kwargs': builder.args,
"namespace": builder.namespace,
"name": builder.name,
"kwargs": builder.args,
}
tags = sorted(set(itertools.chain(tags, builder.tags())))
if 'schema' not in tags:
tags.append('schema')
if "schema" not in tags:
tags.append("schema")
node = self.create_test_node(
target=target,
@@ -477,15 +462,15 @@ class SchemaParser(SimpleParser[SchemaTestBlock, ParsedSchemaTestNode]):
# parsing to avoid jinja overhead.
def render_test_update(self, node, config, builder):
macro_unique_id = self.macro_resolver.get_macro_id(
node.package_name, 'test_' + builder.name)
node.package_name, "test_" + builder.name
)
# Add the depends_on here so we can limit the macros added
# to the context in rendering processing
node.depends_on.add_macro(macro_unique_id)
if (macro_unique_id in
['macro.dbt.test_not_null', 'macro.dbt.test_unique']):
if macro_unique_id in ["macro.dbt.test_not_null", "macro.dbt.test_unique"]:
self.update_parsed_node(node, config)
node.unrendered_config['severity'] = builder.severity()
node.config['severity'] = builder.severity()
node.unrendered_config["severity"] = builder.severity()
node.config["severity"] = builder.severity()
# source node tests are processed at patch_source time
if isinstance(builder.target, UnpatchedSourceDefinition):
sources = [builder.target.fqn[-2], builder.target.fqn[-1]]
@@ -496,15 +481,16 @@ class SchemaParser(SimpleParser[SchemaTestBlock, ParsedSchemaTestNode]):
try:
# make a base context that doesn't have the magic kwargs field
context = generate_test_context(
node, self.root_project, self.macro_manifest, config,
node,
self.root_project,
self.macro_manifest,
config,
self.macro_resolver,
)
# update with rendered test kwargs (which collects any refs)
add_rendered_test_kwargs(context, node, capture_macros=True)
# the parsed node is not rendered in the native context.
get_rendered(
node.raw_sql, context, node, capture_macros=True
)
get_rendered(node.raw_sql, context, node, capture_macros=True)
self.update_parsed_node(node, config)
except ValidationError as exc:
# we got a ValidationError - probably bad types in config()
@@ -522,9 +508,8 @@ class SchemaParser(SimpleParser[SchemaTestBlock, ParsedSchemaTestNode]):
column_name = None
else:
column_name = column.name
should_quote = (
column.quote or
(column.quote is None and target.quote_columns)
should_quote = column.quote or (
column.quote is None and target.quote_columns
)
if should_quote:
column_name = get_adapter(self.root_project).quote(column_name)
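As an aside on the quoting decision above: a column name is quoted when the column explicitly sets quote: true, or when it leaves quote unset and the target sets quote_columns: true. A minimal, self-contained sketch of that rule (the helper name and sample values are illustrative, not part of this diff):
# Sketch of the should_quote expression above as a standalone helper.
def should_quote(column_quote, target_quote_columns):
    # quote if explicitly requested, or if unset and the target quotes all columns
    return column_quote or (column_quote is None and target_quote_columns)

assert should_quote(True, False)       # explicit quote: true wins
assert should_quote(None, True)        # unset column inherits quote_columns: true
assert not should_quote(None, False)   # unset column, quoting disabled
assert not should_quote(False, True)   # explicit quote: false overrides quote_columns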
@@ -535,10 +520,7 @@ class SchemaParser(SimpleParser[SchemaTestBlock, ParsedSchemaTestNode]):
tags = list(itertools.chain.from_iterable(tags_sources))
node = self._parse_generic_test(
target=target,
test=test,
tags=tags,
column_name=column_name
target=target, test=test, tags=tags, column_name=column_name
)
# we can't go through result.add_node - no file... instead!
if node.config.enabled:
@@ -562,7 +544,9 @@ class SchemaParser(SimpleParser[SchemaTestBlock, ParsedSchemaTestNode]):
return node
def render_with_context(
self, node: ParsedSchemaTestNode, config: ContextConfig,
self,
node: ParsedSchemaTestNode,
config: ContextConfig,
) -> None:
"""Given the parsed node and a ContextConfig to use during
parsing, collect all the refs that might be squirreled away in the test
@@ -574,9 +558,7 @@ class SchemaParser(SimpleParser[SchemaTestBlock, ParsedSchemaTestNode]):
add_rendered_test_kwargs(context, node, capture_macros=True)
# the parsed node is not rendered in the native context.
get_rendered(
node.raw_sql, context, node, capture_macros=True
)
get_rendered(node.raw_sql, context, node, capture_macros=True)
def parse_test(
self,
@@ -592,9 +574,8 @@ class SchemaParser(SimpleParser[SchemaTestBlock, ParsedSchemaTestNode]):
column_tags: List[str] = []
else:
column_name = column.name
should_quote = (
column.quote or
(column.quote is None and target_block.quote_columns)
should_quote = column.quote or (
column.quote is None and target_block.quote_columns
)
if should_quote:
column_name = get_adapter(self.root_project).quote(column_name)
@@ -632,8 +613,8 @@ class SchemaParser(SimpleParser[SchemaTestBlock, ParsedSchemaTestNode]):
dct = self.raw_renderer.render_data(dct)
except CompilationException as exc:
raise CompilationException(
f'Failed to render {block.path.original_file_path} from '
f'project {self.project.project_name}: {exc}'
f"Failed to render {block.path.original_file_path} from "
f"project {self.project.project_name}: {exc}"
) from exc
# contains the FileBlock and the data (dictionary)
@@ -649,66 +630,57 @@ class SchemaParser(SimpleParser[SchemaTestBlock, ParsedSchemaTestNode]):
# NonSourceParser.parse(), TestablePatchParser is a variety of
# NodePatchParser
if 'models' in dct:
parser = TestablePatchParser(self, yaml_block, 'models')
if "models" in dct:
parser = TestablePatchParser(self, yaml_block, "models")
for test_block in parser.parse():
self.parse_tests(test_block)
# NonSourceParser.parse()
if 'seeds' in dct:
parser = TestablePatchParser(self, yaml_block, 'seeds')
if "seeds" in dct:
parser = TestablePatchParser(self, yaml_block, "seeds")
for test_block in parser.parse():
self.parse_tests(test_block)
# NonSourceParser.parse()
if 'snapshots' in dct:
parser = TestablePatchParser(self, yaml_block, 'snapshots')
if "snapshots" in dct:
parser = TestablePatchParser(self, yaml_block, "snapshots")
for test_block in parser.parse():
self.parse_tests(test_block)
# This parser uses SourceParser.parse() which doesn't return
# any test blocks. Source tests are handled at a later point
# in the process.
if 'sources' in dct:
parser = SourceParser(self, yaml_block, 'sources')
if "sources" in dct:
parser = SourceParser(self, yaml_block, "sources")
parser.parse()
# NonSourceParser.parse()
if 'macros' in dct:
parser = MacroPatchParser(self, yaml_block, 'macros')
if "macros" in dct:
parser = MacroPatchParser(self, yaml_block, "macros")
for test_block in parser.parse():
self.parse_tests(test_block)
# NonSourceParser.parse()
if 'analyses' in dct:
parser = AnalysisPatchParser(self, yaml_block, 'analyses')
if "analyses" in dct:
parser = AnalysisPatchParser(self, yaml_block, "analyses")
for test_block in parser.parse():
self.parse_tests(test_block)
# parse exposures
if 'exposures' in dct:
if "exposures" in dct:
self.parse_exposures(yaml_block)
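For orientation, the dct dispatched above is the loaded (and rendered) data of one schema yaml file, keyed by resource type. A minimal illustrative example of such data, written here as a Python dict (the model and column names are made up):
# Illustrative only: the shape of data that would exercise the "models" branch above.
dct = {
    "version": 2,
    "models": [
        {
            "name": "fct_orders",  # hypothetical model name
            "columns": [
                {"name": "order_id", "tests": ["unique", "not_null"]},
            ],
        },
    ],
}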
Parsed = TypeVar(
'Parsed',
UnpatchedSourceDefinition, ParsedNodePatch, ParsedMacroPatch
)
NodeTarget = TypeVar(
'NodeTarget',
UnparsedNodeUpdate, UnparsedAnalysisUpdate
)
Parsed = TypeVar("Parsed", UnpatchedSourceDefinition, ParsedNodePatch, ParsedMacroPatch)
NodeTarget = TypeVar("NodeTarget", UnparsedNodeUpdate, UnparsedAnalysisUpdate)
NonSourceTarget = TypeVar(
'NonSourceTarget',
UnparsedNodeUpdate, UnparsedAnalysisUpdate, UnparsedMacroUpdate
"NonSourceTarget", UnparsedNodeUpdate, UnparsedAnalysisUpdate, UnparsedMacroUpdate
)
# abstract base class (ABCMeta)
class YamlReader(metaclass=ABCMeta):
def __init__(
self, schema_parser: SchemaParser, yaml: YamlBlock, key: str
) -> None:
def __init__(self, schema_parser: SchemaParser, yaml: YamlBlock, key: str) -> None:
self.schema_parser = schema_parser
# key: models, seeds, snapshots, sources, macros,
# analyses, exposures
@@ -738,8 +710,9 @@ class YamlReader(metaclass=ABCMeta):
data = self.yaml.data.get(self.key, [])
if not isinstance(data, list):
raise CompilationException(
'{} must be a list, got {} instead: ({})'
.format(self.key, type(data), _trimmed(str(data)))
"{} must be a list, got {} instead: ({})".format(
self.key, type(data), _trimmed(str(data))
)
)
path = self.yaml.path.original_file_path
@@ -751,7 +724,7 @@ class YamlReader(metaclass=ABCMeta):
yield entry
else:
msg = error_context(
path, self.key, data, 'expected a dict with string keys'
path, self.key, data, "expected a dict with string keys"
)
raise CompilationException(msg)
@@ -759,10 +732,10 @@ class YamlReader(metaclass=ABCMeta):
class YamlDocsReader(YamlReader):
@abstractmethod
def parse(self) -> List[TestBlock]:
raise NotImplementedError('parse is abstract')
raise NotImplementedError("parse is abstract")
T = TypeVar('T', bound=dbtClassMixin)
T = TypeVar("T", bound=dbtClassMixin)
class SourceParser(YamlDocsReader):
@@ -779,13 +752,11 @@ class SourceParser(YamlDocsReader):
def parse(self) -> List[TestBlock]:
# get a verified list of dicts for the key handled by this parser
for data in self.get_key_dicts():
data = self.project.credentials.translate_aliases(
data, recurse=True
)
data = self.project.credentials.translate_aliases(data, recurse=True)
is_override = 'overrides' in data
is_override = "overrides" in data
if is_override:
data['path'] = self.yaml.path.original_file_path
data["path"] = self.yaml.path.original_file_path
patch = self._target_from_dict(SourcePatch, data)
self.results.add_source_patch(self.yaml.file, patch)
else:
@@ -797,10 +768,9 @@ class SourceParser(YamlDocsReader):
original_file_path = self.yaml.path.original_file_path
fqn_path = self.yaml.path.relative_path
for table in source.tables:
unique_id = '.'.join([
NodeType.Source, self.project.project_name,
source.name, table.name
])
unique_id = ".".join(
[NodeType.Source, self.project.project_name, source.name, table.name]
)
# the FQN is project name / path elements /source_name /table_name
fqn = self.schema_parser.get_fqn_prefix(fqn_path)
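A small note on the unique_id built just above: NodeType.Source is a string-valued enum (it renders as "source"), so the join produces a dotted identifier of the form source.<project>.<source_name>.<table_name>. A sketch with made-up names:
# Assumes NodeType.Source stringifies to "source"; the names below are hypothetical.
unique_id = ".".join(["source", "my_project", "jaffle_shop", "orders"])
print(unique_id)  # source.my_project.jaffle_shop.orders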
@@ -825,17 +795,15 @@ class SourceParser(YamlDocsReader):
class NonSourceParser(YamlDocsReader, Generic[NonSourceTarget, Parsed]):
@abstractmethod
def _target_type(self) -> Type[NonSourceTarget]:
raise NotImplementedError('_target_type not implemented')
raise NotImplementedError("_target_type not implemented")
@abstractmethod
def get_block(self, node: NonSourceTarget) -> TargetBlock:
raise NotImplementedError('get_block is abstract')
raise NotImplementedError("get_block is abstract")
@abstractmethod
def parse_patch(
self, block: TargetBlock[NonSourceTarget], refs: ParserRef
) -> None:
raise NotImplementedError('parse_patch is abstract')
def parse_patch(self, block: TargetBlock[NonSourceTarget], refs: ParserRef) -> None:
raise NotImplementedError("parse_patch is abstract")
def parse(self) -> List[TestBlock]:
node: NonSourceTarget
@@ -874,11 +842,13 @@ class NonSourceParser(YamlDocsReader, Generic[NonSourceTarget, Parsed]):
for data in key_dicts:
# add extra data to each dict. This updates the dicts
# in the parser yaml
data.update({
'original_file_path': path,
'yaml_key': self.key,
'package_name': self.project.project_name,
})
data.update(
{
"original_file_path": path,
"yaml_key": self.key,
"package_name": self.project.project_name,
}
)
try:
# target_type: UnparsedNodeUpdate, UnparsedAnalysisUpdate,
# or UnparsedMacroUpdate
@@ -892,12 +862,9 @@ class NonSourceParser(YamlDocsReader, Generic[NonSourceTarget, Parsed]):
class NodePatchParser(
NonSourceParser[NodeTarget, ParsedNodePatch],
Generic[NodeTarget]
NonSourceParser[NodeTarget, ParsedNodePatch], Generic[NodeTarget]
):
def parse_patch(
self, block: TargetBlock[NodeTarget], refs: ParserRef
) -> None:
def parse_patch(self, block: TargetBlock[NodeTarget], refs: ParserRef) -> None:
result = ParsedNodePatch(
name=block.target.name,
original_file_path=block.target.original_file_path,
@@ -958,7 +925,7 @@ class ExposureParser(YamlReader):
def parse_exposure(self, unparsed: UnparsedExposure) -> ParsedExposure:
package_name = self.project.project_name
unique_id = f'{NodeType.Exposure}.{package_name}.{unparsed.name}'
unique_id = f"{NodeType.Exposure}.{package_name}.{unparsed.name}"
path = self.yaml.path.relative_path
fqn = self.schema_parser.get_fqn_prefix(path)
@@ -984,12 +951,10 @@ class ExposureParser(YamlReader):
self.schema_parser.macro_manifest,
package_name,
)
depends_on_jinja = '\n'.join(
'{{ ' + line + '}}' for line in unparsed.depends_on
)
get_rendered(
depends_on_jinja, ctx, parsed, capture_macros=True
depends_on_jinja = "\n".join(
"{{ " + line + "}}" for line in unparsed.depends_on
)
get_rendered(depends_on_jinja, ctx, parsed, capture_macros=True)
# parsed now has a populated refs/sources
return parsed
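To make the depends_on handling above concrete: each entry is a bare Jinja expression string; wrapping it in {{ ... }} and rendering it with capture_macros=True is what populates the exposure's refs and sources. A minimal sketch of just the wrapping step (the entries are hypothetical):
# Mirrors the "\n".join("{{ " + line + "}}" ...) expression above.
depends_on = ['ref("fct_orders")', 'source("jaffle_shop", "orders")']  # hypothetical entries
depends_on_jinja = "\n".join("{{ " + line + "}}" for line in depends_on)
print(depends_on_jinja)
# {{ ref("fct_orders")}}
# {{ source("jaffle_shop", "orders")}}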

View File

@@ -1,8 +1,6 @@
import os
from dataclasses import dataclass
from typing import (
List, Callable, Iterable, Set, Union, Iterator, TypeVar, Generic
)
from typing import List, Callable, Iterable, Set, Union, Iterator, TypeVar, Generic
from dbt.clients.jinja import extract_toplevel_blocks, BlockTag
from dbt.clients.system import find_matching
@@ -72,13 +70,13 @@ class FilesystemSearcher(Iterable[FilePath]):
root = self.project.project_root
for result in find_matching(root, self.relative_dirs, ext):
if 'searched_path' not in result or 'relative_path' not in result:
if "searched_path" not in result or "relative_path" not in result:
raise InternalException(
'Invalid result from find_matching: {}'.format(result)
"Invalid result from find_matching: {}".format(result)
)
file_match = FilePath(
searched_path=result['searched_path'],
relative_path=result['relative_path'],
searched_path=result["searched_path"],
relative_path=result["relative_path"],
project_root=root,
)
yield file_match
@@ -86,7 +84,7 @@ class FilesystemSearcher(Iterable[FilePath]):
Block = Union[BlockContents, FullBlock]
BlockSearchResult = TypeVar('BlockSearchResult', BlockContents, FullBlock)
BlockSearchResult = TypeVar("BlockSearchResult", BlockContents, FullBlock)
BlockSearchResultFactory = Callable[[SourceFile, BlockTag], BlockSearchResult]
@@ -96,7 +94,7 @@ class BlockSearcher(Generic[BlockSearchResult], Iterable[BlockSearchResult]):
self,
source: List[FileBlock],
allowed_blocks: Set[str],
source_tag_factory: BlockSearchResultFactory
source_tag_factory: BlockSearchResultFactory,
) -> None:
self.source = source
self.allowed_blocks = allowed_blocks
@@ -107,7 +105,7 @@ class BlockSearcher(Generic[BlockSearchResult], Iterable[BlockSearchResult]):
blocks = extract_toplevel_blocks(
source_file.contents,
allowed_blocks=self.allowed_blocks,
collect_raw_data=False
collect_raw_data=False,
)
# this makes mypy happy, and this is an invariant we really need
for block in blocks:

View File

@@ -8,9 +8,7 @@ from dbt.parser.search import FileBlock, FilesystemSearcher
class SeedParser(SimpleSQLParser[ParsedSeedNode]):
def get_paths(self):
return FilesystemSearcher(
self.project, self.project.data_paths, '.csv'
)
return FilesystemSearcher(self.project, self.project.data_paths, ".csv")
def parse_from_dict(self, dct, validate=True) -> ParsedSeedNode:
if validate:
@@ -30,9 +28,7 @@ class SeedParser(SimpleSQLParser[ParsedSeedNode]):
) -> None:
"""Seeds don't need to do any rendering."""
def load_file(
self, match: FilePath, *, set_contents: bool = False
) -> SourceFile:
def load_file(self, match: FilePath, *, set_contents: bool = False) -> SourceFile:
if match.seed_too_large():
# We don't want to calculate a hash of this file. Use the path.
return SourceFile.big_seed(match)

View File

@@ -3,27 +3,22 @@ from typing import List
from dbt.dataclass_schema import ValidationError
from dbt.contracts.graph.parsed import (
IntermediateSnapshotNode, ParsedSnapshotNode
)
from dbt.exceptions import (
CompilationException, validator_error_message
)
from dbt.contracts.graph.parsed import IntermediateSnapshotNode, ParsedSnapshotNode
from dbt.exceptions import CompilationException, validator_error_message
from dbt.node_types import NodeType
from dbt.parser.base import SQLParser
from dbt.parser.search import (
FilesystemSearcher, BlockContents, BlockSearcher, FileBlock
FilesystemSearcher,
BlockContents,
BlockSearcher,
FileBlock,
)
from dbt.utils import split_path
class SnapshotParser(
SQLParser[IntermediateSnapshotNode, ParsedSnapshotNode]
):
class SnapshotParser(SQLParser[IntermediateSnapshotNode, ParsedSnapshotNode]):
def get_paths(self):
return FilesystemSearcher(
self.project, self.project.snapshot_paths, '.sql'
)
return FilesystemSearcher(self.project, self.project.snapshot_paths, ".sql")
def parse_from_dict(self, dct, validate=True) -> IntermediateSnapshotNode:
if validate:
@@ -78,7 +73,7 @@ class SnapshotParser(
def parse_file(self, file_block: FileBlock) -> None:
blocks = BlockSearcher(
source=[file_block],
allowed_blocks={'snapshot'},
allowed_blocks={"snapshot"},
source_tag_factory=BlockContents,
)
for block in blocks:

View File

@@ -34,8 +34,7 @@ class SourcePatcher:
self.results = results
self.root_project = root_project
self.macro_manifest = MacroManifest(
macros=self.results.macros,
files=self.results.files
macros=self.results.macros, files=self.results.files
)
self.schema_parsers: Dict[str, SchemaParser] = {}
self.patches_used: Dict[SourceKey, Set[str]] = {}
@@ -65,9 +64,7 @@ class SourcePatcher:
source = UnparsedSourceDefinition.from_dict(source_dct)
table = UnparsedSourceTableDefinition.from_dict(table_dct)
return unpatched.replace(
source=source, table=table, patch_path=patch_path
)
return unpatched.replace(source=source, table=table, patch_path=patch_path)
def parse_source_docs(self, block: UnpatchedSourceDefinition) -> ParserRef:
refs = ParserRef()
@@ -78,7 +75,7 @@ class SourcePatcher:
refs.add(column, description, data_type, meta)
return refs
def get_schema_parser_for(self, package_name: str) -> 'SchemaParser':
def get_schema_parser_for(self, package_name: str) -> "SchemaParser":
if package_name in self.schema_parsers:
schema_parser = self.schema_parsers[package_name]
else:
@@ -157,31 +154,28 @@ class SourcePatcher:
if unused_tables:
msg = self.get_unused_msg(unused_tables)
warn_or_error(msg, log_fmt=ui.warning_tag('{}'))
warn_or_error(msg, log_fmt=ui.warning_tag("{}"))
def get_unused_msg(
self,
unused_tables: Dict[SourceKey, Optional[Set[str]]],
) -> str:
msg = [
'During parsing, dbt encountered source overrides that had no '
'target:',
"During parsing, dbt encountered source overrides that had no " "target:",
]
for key, table_names in unused_tables.items():
patch = self.results.source_patches[key]
patch_name = f'{patch.overrides}.{patch.name}'
patch_name = f"{patch.overrides}.{patch.name}"
if table_names is None:
msg.append(
f' - Source {patch_name} (in {patch.path})'
)
msg.append(f" - Source {patch_name} (in {patch.path})")
else:
for table_name in sorted(table_names):
msg.append(
f' - Source table {patch_name}.{table_name} '
f'(in {patch.path})'
f" - Source table {patch_name}.{table_name} "
f"(in {patch.path})"
)
msg.append('')
return '\n'.join(msg)
msg.append("")
return "\n".join(msg)
def patch_sources(

Some files were not shown because too many files have changed in this diff.