Mirror of https://github.com/dbt-labs/dbt-core (synced 2025-12-21 11:21:28 +00:00)

Compare commits: experiment...testing-pr (63 commits)
Commits in this comparison (SHA1 only):

42058de028, 6a5ed4f418, ef25698d3d, ab3f994626, db325d0fde, 24e4b75c35,
34174abf26, af778312cb, 280f5614ef, 034a44e625, 84155fdff7, b70fb543f5,
31c88f9f5a, 344a14416d, be47a0c5db, c7c057483d, 7f5170ae4d, 49b8693b11,
d7b0a14eb5, 8996cb1e18, 38f278cce0, bb4e475044, 4fbe36a8e9, a1a40b562a,
3a4a1bb005, 4833348769, ad07d59a78, e8aaabd1d3, d7d7396eeb, 41538860cd,
5c9f8a0cf0, 11c997c3e9, 1b1184a5e1, 4ffcc43ed9, 4ccaac46a6, ba88b84055,
e88f1f1edb, 13c7486f0e, 8e811ba141, c5d86afed6, 43a0cfbee1, 8567d5f302,
36d1bddc5b, bf992680af, e064298dfc, e01a10ced5, 2aa10fb1ed, 66f442ad76,
11f1ecebcf, e339cb27f6, bce3232b39, b08970ce39, 533f88ceaf, c8f0469a44,
a1fc24e532, d80daa48df, 92aae2803f, 77cbbbfaf2, 6c6649f912, 56c2518936,
e52a599be6, 99744bd318, 1b666d01cf
`.bumpversion.cfg` (the old and new sides of these hunks read identically here, so the changes appear to be whitespace-only; the touched hunks are shown for context):

```diff
@@ -4,7 +4,7 @@ parse = (?P<major>\d+)
 \.(?P<minor>\d+)
 \.(?P<patch>\d+)
 ((?P<prerelease>[a-z]+)(?P<num>\d+))?
 serialize =
 	{major}.{minor}.{patch}{prerelease}{num}
 	{major}.{minor}.{patch}
 commit = False
@@ -12,7 +12,7 @@ tag = False
 
 [bumpversion:part:prerelease]
 first_value = a
 values =
 	a
 	b
 	rc
@@ -41,4 +41,3 @@ first_value = 1
 [bumpversion:file:plugins/snowflake/dbt/adapters/snowflake/__version__.py]
 
 [bumpversion:file:plugins/bigquery/dbt/adapters/bigquery/__version__.py]
```
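As a quick illustration (not part of the diff), the `parse` pattern above decomposes a dbt version string into its parts. A minimal sketch, assuming the multi-line pattern is applied as one concatenated regex:

```python
import re

# The parse pattern from the config above, joined onto one line.
pattern = (
    r"(?P<major>\d+)"
    r"\.(?P<minor>\d+)"
    r"\.(?P<patch>\d+)"
    r"((?P<prerelease>[a-z]+)(?P<num>\d+))?"
)

match = re.match(pattern, "0.20.0rc1")
print(match.group("major", "minor", "patch", "prerelease", "num"))
# -> ('0', '20', '0', 'rc', '1'), which
# serialize = {major}.{minor}.{patch}{prerelease}{num} re-assembles
```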
`.github/dependabot.yml` (new file, 45 lines):

```yaml
version: 2
updates:
  # python dependencies
  - package-ecosystem: "pip"
    directory: "/"
    schedule:
      interval: "daily"
    rebase-strategy: "disabled"
  - package-ecosystem: "pip"
    directory: "/core"
    schedule:
      interval: "daily"
    rebase-strategy: "disabled"
  - package-ecosystem: "pip"
    directory: "/plugins/bigquery"
    schedule:
      interval: "daily"
    rebase-strategy: "disabled"
  - package-ecosystem: "pip"
    directory: "/plugins/postgres"
    schedule:
      interval: "daily"
    rebase-strategy: "disabled"
  - package-ecosystem: "pip"
    directory: "/plugins/redshift"
    schedule:
      interval: "daily"
    rebase-strategy: "disabled"
  - package-ecosystem: "pip"
    directory: "/plugins/snowflake"
    schedule:
      interval: "daily"
    rebase-strategy: "disabled"

  # docker dependencies
  - package-ecosystem: "docker"
    directory: "/"
    schedule:
      interval: "weekly"
    rebase-strategy: "disabled"
  - package-ecosystem: "docker"
    directory: "/docker"
    schedule:
      interval: "weekly"
    rebase-strategy: "disabled"
```
`.pre-commit-config.yaml` (new file, 20 lines):

```yaml
repos:
  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v2.3.0
    hooks:
      - id: check-yaml
      - id: end-of-file-fixer
      - id: trailing-whitespace
  - repo: https://github.com/psf/black
    rev: 20.8b1
    hooks:
      - id: black
  - repo: https://gitlab.com/PyCQA/flake8
    rev: 3.9.0
    hooks:
      - id: flake8
  - repo: https://github.com/pre-commit/mirrors-mypy
    rev: v0.812
    hooks:
      - id: mypy
        files: ^core/dbt/
```
`ARCHITECTURE.md` (new file, 49 lines):

The core function of dbt is SQL compilation and execution. Users create projects of dbt resources (models, tests, seeds, snapshots, ...), defined in SQL and YAML files, and they invoke dbt to create, update, or query associated views and tables. Today, dbt makes heavy use of Jinja2 to enable the templating of SQL, and to construct a DAG (Directed Acyclic Graph) from all of the resources in a project. Users can also extend their projects by installing resources (including Jinja macros) from other projects, called "packages."

## dbt-core

Most of the python code in the repository is within the `core/dbt` directory. Currently the main subdirectories are:

- [`adapters`](core/dbt/adapters): Define base classes for behavior that is likely to differ across databases
- [`clients`](core/dbt/clients): Interface with dependencies (agate, jinja) or across operating systems
- [`config`](core/dbt/config): Reconcile user-supplied configuration from connection profiles, project files, and Jinja macros
- [`context`](core/dbt/context): Build and expose dbt-specific Jinja functionality
- [`contracts`](core/dbt/contracts): Define Python objects (dataclasses) that dbt expects to create and validate
- [`deps`](core/dbt/deps): Package installation and dependency resolution
- [`graph`](core/dbt/graph): Produce a `networkx` DAG of project resources, and select those resources given user-supplied criteria
- [`include`](core/dbt/include): The dbt "global project," which defines default implementations of Jinja2 macros
- [`parser`](core/dbt/parser): Read project files, validate, and construct python objects
- [`rpc`](core/dbt/rpc): Provide a remote procedure call server for invoking dbt, following the JSON-RPC 2.0 spec
- [`task`](core/dbt/task): Set forth the actions that dbt can perform when invoked

### Invoking dbt

There are two supported ways of invoking dbt: from the command line and using an RPC server.

The "tasks" map to top-level dbt commands, so `dbt run` => task.run.RunTask, and so on. Some are more like abstract base classes (GraphRunnableTask, for example), but all the concrete types outside of task/rpc should map to tasks. Currently, only one task executes at a time. Each task kicks off its "runners," and those do execute in parallel; the parallelism is managed via a thread pool in GraphRunnableTask.
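For orientation, a minimal sketch of that mapping (illustrative only; `RunTask` is named in the text above, while the `compile` and `test` module paths are assumptions that follow the same `core/dbt/task/<command>.py` convention):

```python
# Illustrative mapping sketch: each dbt CLI command corresponds to a Task class
# under core/dbt/task/. RunTask is named in the text above; the other two import
# paths are assumed to follow the same convention.
from dbt.task.run import RunTask          # dbt run
from dbt.task.compile import CompileTask  # dbt compile (assumed path)
from dbt.task.test import TestTask        # dbt test (assumed path)
```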
`core/dbt/include/index.html` is the docs website code. It comes from the dbt-docs repository, and is generated when a release is packaged.

## Adapters

dbt uses an adapter-plugin pattern to extend support to different databases, warehouses, query engines, etc. The four core adapters in the main repository, contained within the [`plugins`](plugins) subdirectory, are Postgres, Redshift, Snowflake, and BigQuery. Other warehouses use adapter plugins defined in separate repositories (e.g. [dbt-spark](https://github.com/fishtown-analytics/dbt-spark), [dbt-presto](https://github.com/fishtown-analytics/dbt-presto)).

Each adapter is a mix of python, Jinja2, and SQL. The adapter code also makes heavy use of Jinja2 to wrap modular chunks of SQL functionality, define default implementations, and allow plugins to override them.

Each adapter plugin is a standalone python package that includes:

- `dbt/include/[name]`: A "sub-global" dbt project, of YAML and SQL files, that reimplements Jinja macros to use the adapter's supported SQL syntax
- `dbt/adapters/[name]`: Python modules that inherit, and optionally reimplement, the base adapter classes defined in dbt-core (see the sketch below)
- `setup.py`

The Postgres adapter code is the most central, and many of its implementations are used as the default defined in the dbt-core global project. The greater the distance of a data technology from Postgres, the more its adapter plugin may need to reimplement.
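To make that inheritance pattern concrete, here is a small hypothetical sketch (not taken from any real plugin) of an adapter overriding one of the base classes that appears later in this comparison, `dbt.adapters.base.Column`:

```python
from dbt.adapters.base import Column


class ExampleAdapterColumn(Column):
    # Hypothetical type translations for an imaginary warehouse; real plugins
    # (dbt-postgres, dbt-bigquery, ...) define their own mappings.
    TYPE_LABELS = {
        "STRING": "VARCHAR",
        "INTEGER": "BIGINT",
    }


# translate_type() consults TYPE_LABELS, falling back to the raw dtype.
print(ExampleAdapterColumn.translate_type("string"))   # -> "VARCHAR"
print(ExampleAdapterColumn.translate_type("boolean"))  # -> "boolean" (no mapping)
```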
## Testing dbt

The [`test/`](test/) subdirectory includes unit and integration tests that run as continuous integration checks against open pull requests. Unit tests check mock inputs and outputs of specific python functions. Integration tests perform end-to-end dbt invocations against real adapters (Postgres, Redshift, Snowflake, BigQuery) and assert that the results match expectations. See [the contributing guide](CONTRIBUTING.md) for a step-by-step walkthrough of setting up a local development and testing environment.

## Everything else

- [docker](docker/): All dbt versions are published as Docker images on DockerHub. This subfolder contains the `Dockerfile` (constant) and `requirements.txt` (one for each version).
- [etc](etc/): Images for the README
- [scripts](scripts/): Helper scripts for testing, releasing, and producing JSON schemas. These are not included in distributions of dbt, nor are they rigorously tested—they're just handy tools for the dbt maintainers :)
`CHANGELOG.md`:

```diff
@@ -1,24 +1,43 @@
 ## dbt 0.20.0 (Release TBD)
 
 ### Fixes
 
 - Fix exit code from dbt debug not returning a failure when one of the tests fail ([#3017](https://github.com/fishtown-analytics/dbt/issues/3017))
 - Auto-generated CTEs in tests and ephemeral models have lowercase names to comply with dbt coding conventions ([#3027](https://github.com/fishtown-analytics/dbt/issues/3027), [#3028](https://github.com/fishtown-analytics/dbt/issues/3028))
+- Fix incorrect error message when a selector does not match any node [#3036](https://github.com/fishtown-analytics/dbt/issues/3036))
+- Fix variable `_dbt_max_partition` declaration and initialization for BigQuery incremental models ([#2940](https://github.com/fishtown-analytics/dbt/issues/2940), [#2976](https://github.com/fishtown-analytics/dbt/pull/2976))
+- Moving from 'master' to 'HEAD' default branch in git ([#3057](https://github.com/fishtown-analytics/dbt/issues/3057), [#3104](https://github.com/fishtown-analytics/dbt/issues/3104), [#3117](https://github.com/fishtown-analytics/dbt/issues/3117)))
+- Requirement on `dataclasses` is relaxed to be between `>=0.6,<0.9` allowing dbt to cohabit with other libraries which required higher versions. ([#3150](https://github.com/fishtown-analytics/dbt/issues/3150), [#3151](https://github.com/fishtown-analytics/dbt/pull/3151))
 
 ### Features
 - Add optional configs for `require_partition_filter` and `partition_expiration_days` in BigQuery ([#1843](https://github.com/fishtown-analytics/dbt/issues/1843), [#2928](https://github.com/fishtown-analytics/dbt/pull/2928))
 - Fix for EOL SQL comments prevent entire line execution ([#2731](https://github.com/fishtown-analytics/dbt/issues/2731), [#2974](https://github.com/fishtown-analytics/dbt/pull/2974))
 
+### Under the hood
+- Add dependabot configuration for alerting maintainers about keeping dependencies up to date and secure. ([#3061](https://github.com/fishtown-analytics/dbt/issues/3061), [#3062](https://github.com/fishtown-analytics/dbt/pull/3062))
+- Update script to collect and write json schema for dbt artifacts ([#2870](https://github.com/fishtown-analytics/dbt/issues/2870), [#3065](https://github.com/fishtown-analytics/dbt/pull/3065))
+
 Contributors:
 - [@yu-iskw](https://github.com/yu-iskw) ([#2928](https://github.com/fishtown-analytics/dbt/pull/2928))
 - [@sdebruyn](https://github.com/sdebruyn) / [@lynxcare](https://github.com/lynxcare) ([#3018](https://github.com/fishtown-analytics/dbt/pull/3018))
 - [@rvacaru](https://github.com/rvacaru) ([#2974](https://github.com/fishtown-analytics/dbt/pull/2974))
 - [@NiallRees](https://github.com/NiallRees) ([#3028](https://github.com/fishtown-analytics/dbt/pull/3028))
+- [ran-eh](https://github.com/ran-eh) ([#3036](https://github.com/fishtown-analytics/dbt/pull/3036))
+- [@pcasteran](https://github.com/pcasteran) ([#2976](https://github.com/fishtown-analytics/dbt/pull/2976))
+- [@VasiliiSurov](https://github.com/VasiliiSurov) ([#3104](https://github.com/fishtown-analytics/dbt/pull/3104))
+- [@bastienboutonnet](https://github.com/bastienboutonnet) ([#3151](https://github.com/fishtown-analytics/dbt/pull/3151))
 
 ## dbt 0.19.1 (Release TBD)
 
+### Fixes
+
+- On BigQuery, fix regressions for `insert_overwrite` incremental strategy with `int64` and `timestamp` partition columns ([#3063](https://github.com/fishtown-analytics/dbt/issues/3063), [#3095](https://github.com/fishtown-analytics/dbt/issues/3095), [#3098](https://github.com/fishtown-analytics/dbt/issues/3098))
+
 ### Under the hood
 - Bump werkzeug upper bound dependency to `<v2.0` ([#3011](https://github.com/fishtown-analytics/dbt/pull/3011))
+- Performance fixes for many different things ([#2862](https://github.com/fishtown-analytics/dbt/issues/2862), [#3034](https://github.com/fishtown-analytics/dbt/pull/3034))
+- Update code to use Mashumaro 2.0 ([#3138](https://github.com/fishtown-analytics/dbt/pull/3138))
+- Add an event to track resource counts ([#3050](https://github.com/fishtown-analytics/dbt/issues/3050), [#3156](https://github.com/fishtown-analytics/dbt/pull/3156))
+- Pin `agate<1.6.2` to avoid installation errors relating to its new dependency `PyICU` ([#3160](https://github.com/fishtown-analytics/dbt/issues/3160), [#3161](https://github.com/fishtown-analytics/dbt/pull/3161))
 
 Contributors:
 - [@Bl3f](https://github.com/Bl3f) ([#3011](https://github.com/fishtown-analytics/dbt/pull/3011))
```
`CONTRIBUTING.md`:

````diff
@@ -62,6 +62,13 @@ The dbt maintainers use labels to categorize open issues. Some labels indicate t
 | [stale](https://github.com/fishtown-analytics/dbt/labels/stale) | This is an old issue which has not recently been updated. Stale issues will periodically be closed by dbt maintainers, but they can be re-opened if the discussion is restarted. |
 | [wontfix](https://github.com/fishtown-analytics/dbt/labels/wontfix) | This issue does not require a code change in the dbt repository, or the maintainers are unwilling/unable to merge a Pull Request which implements the behavior described in the issue. |
 
+#### Branching Strategy
+
+dbt has three types of branches:
+
+- **Trunks** are where active development of the next release takes place. There is one trunk named `develop` at the time of writing this, and will be the default branch of the repository.
+- **Release Branches** track a specific, not yet complete release of dbt. Each minor version release has a corresponding release branch. For example, the `0.11.x` series of releases has a branch called `0.11.latest`. This allows us to release new patch versions under `0.11` without necessarily needing to pull them into the latest version of dbt.
+- **Feature Branches** track individual features and fixes. On completion they should be merged into the trunk brnach or a specific release branch.
+
 ## Getting the code
 
@@ -81,10 +88,9 @@ If you are not a member of the `fishtown-analytics` GitHub organization, you can
 
 ### Core contributors
 
 If you are a member of the `fishtown-analytics` GitHub organization, you will have push access to the dbt repo. Rather than
 forking dbt to make your changes, just clone the repository, check out a new branch, and push directly to that branch.
 
-
 ## Setting up an environment
 
 There are some tools that will be helpful to you in developing locally. While this is the list relevant for dbt development, many of these tools are used commonly across open-source python projects.
@@ -115,7 +121,7 @@ This will create and activate a new Python virtual environment.
 
 #### docker and docker-compose
 
-Docker and docker-compose are both used in testing. For macOS, the easiest thing to do is to [download docker for mac](https://store.docker.com/editions/community/docker-ce-desktop-mac). You'll need to make an account. On Linux, you can use one of the packages [here](https://docs.docker.com/install/#server). We recommend installing from docker.com instead of from your package manager. On Linux you also have to install docker-compose separately, following [these instructions](https://docs.docker.com/compose/install/#install-compose).
+Docker and docker-compose are both used in testing. Specific instructions for you OS can be found [here](https://docs.docker.com/get-docker/).
 
 
 #### postgres (optional)
@@ -133,7 +139,7 @@ brew install postgresql
 First make sure that you set up your `virtualenv` as described in section _Setting up an environment_. Next, install dbt (and its dependencies) with:
 
 ```
-pip install -r editable_requirements.txt
+pip install -r requirements-editable.txt
 ```
 
 When dbt is installed from source in this way, any changes you make to the dbt source code will be reflected immediately in your next `dbt` run.
@@ -159,7 +165,6 @@ dbt uses test credentials specified in a `test.env` file in the root of the repo
 
 ```
 cp test.env.sample test.env
-atom test.env # supply your credentials
 ```
 
 We recommend starting with dbt's Postgres tests. These tests cover most of the functionality in dbt, are the fastest to run, and are the easiest to set up. dbt's test suite runs Postgres in a Docker container, so no setup should be required to run these tests.
````
`Makefile`:

```diff
@@ -1,10 +1,5 @@
 .PHONY: install test test-unit test-integration
 
-changed_tests := `git status --porcelain | grep '^\(M\| M\|A\| A\)' | awk '{ print $$2 }' | grep '\/test_[a-zA-Z_\-\.]\+.py'`
-
-install:
-	pip install -e .
-
 test: .env
 	@echo "Full test run starting..."
 	@time docker-compose run --rm test tox
@@ -18,7 +13,7 @@ test-integration: .env
 	@time docker-compose run --rm test tox -e integration-postgres-py36,integration-redshift-py36,integration-snowflake-py36,integration-bigquery-py36
 
 test-quick: .env
-	@echo "Integration test run starting..."
+	@echo "Integration test run starting, will exit on first failure..."
 	@time docker-compose run --rm test tox -e integration-postgres-py36 -- -x
 
 # This rule creates a file named .env that is used by docker-compose for passing
```
`README.md`:

```diff
@@ -1,5 +1,5 @@
 <p align="center">
-  <img src="/etc/dbt-logo-full.svg" alt="dbt logo" width="500"/>
+  <img src="https://raw.githubusercontent.com/fishtown-analytics/dbt/6c6649f9129d5d108aa3b0526f634cd8f3a9d1ed/etc/dbt-logo-full.svg" alt="dbt logo" width="500"/>
 </p>
 <p align="center">
   <a href="https://codeclimate.com/github/fishtown-analytics/dbt">
@@ -20,7 +20,7 @@
 
 dbt is the T in ELT. Organize, cleanse, denormalize, filter, rename, and pre-aggregate the raw data in your warehouse so that it's ready for analysis.
 
-
+
 
 dbt can be used to [aggregate pageviews into sessions](https://github.com/fishtown-analytics/snowplow), calculate [ad spend ROI](https://github.com/fishtown-analytics/facebook-ads), or report on [email campaign performance](https://github.com/fishtown-analytics/mailchimp).
 
@@ -30,7 +30,7 @@ Analysts using dbt can transform their data by simply writing select statements,
 
 These select statements, or "models", form a dbt project. Models frequently build on top of one another – dbt makes it easy to [manage relationships](https://docs.getdbt.com/docs/ref) between models, and [visualize these relationships](https://docs.getdbt.com/docs/documentation), as well as assure the quality of your transformations through [testing](https://docs.getdbt.com/docs/testing).
 
-
+
 
 ## Getting started
 
@@ -51,7 +51,7 @@ These select statements, or "models", form a dbt project. Models frequently buil
 ## Reporting bugs and contributing code
 
 - Want to report a bug or request a feature? Let us know on [Slack](http://community.getdbt.com/), or open [an issue](https://github.com/fishtown-analytics/dbt/issues/new).
-- Want to help us build dbt? Check out the [Contributing Getting Started Guide](/CONTRIBUTING.md)
+- Want to help us build dbt? Check out the [Contributing Getting Started Guide](https://github.com/fishtown-analytics/dbt/blob/HEAD/CONTRIBUTING.md)
 
 ## Code of Conduct
 
```
`RELEASE.md` (deleted; 92 lines removed). The removed content:

### Release Procedure :shipit:

#### Branching Strategy

dbt has three types of branches:

- **Trunks** track the latest release of a minor version of dbt. Historically, we used the `master` branch as the trunk. Each minor version release has a corresponding trunk. For example, the `0.11.x` series of releases has a branch called `0.11.latest`. This allows us to release new patch versions under `0.11` without necessarily needing to pull them into the latest version of dbt.
- **Release Branches** track a specific, not yet complete release of dbt. These releases are codenamed since we don't always know what their semantic version will be. Example: `dev/lucretia-mott` became `0.11.1`.
- **Feature Branches** track individual features and fixes. On completion they should be merged into a release branch.

#### Git & PyPI

1. Update CHANGELOG.md with the most recent changes
2. If this is a release candidate, you want to create it off of your release branch. If it's an actual release, you must first merge to a master branch. Open a Pull Request in Github to merge it into the appropriate trunk (`X.X.latest`)
3. Bump the version using `bumpversion`:
   - Dry run first by running `bumpversion --new-version <desired-version> <part>` and checking the diff. If it looks correct, clean up the chanages and move on:
   - Alpha releases: `bumpversion --commit --no-tag --new-version 0.10.2a1 num`
   - Patch releases: `bumpversion --commit --no-tag --new-version 0.10.2 patch`
   - Minor releases: `bumpversion --commit --no-tag --new-version 0.11.0 minor`
   - Major releases: `bumpversion --commit --no-tag --new-version 1.0.0 major`
4. (If this is a not a release candidate) Merge to `x.x.latest` and (optionally) `master`.
5. Update the default branch to the next dev release branch.
6. Build source distributions for all packages by running `./scripts/build-sdists.sh`. Note that this will clean out your `dist/` folder, so if you have important stuff in there, don't run it!!!
7. Deploy to pypi
   - `twine upload dist/*`
8. Deploy to homebrew (see below)
9. Deploy to conda-forge (see below)
10. Git release notes (points to changelog)
11. Post to slack (point to changelog)

After releasing a new version, it's important to merge the changes back into the other outstanding release branches. This avoids merge conflicts moving forward.

In some cases, where the branches have diverged wildly, it's ok to skip this step. But this means that the changes you just released won't be included in future releases.

#### Homebrew Release Process

1. Clone the `homebrew-dbt` repository:

```
git clone git@github.com:fishtown-analytics/homebrew-dbt.git
```

2. For ALL releases (prereleases and version releases), copy the relevant formula. To copy from the latest version release of dbt, do:

```bash
cp Formula/dbt.rb Formula/dbt@{NEW-VERSION}.rb
```

To copy from a different version, simply copy the corresponding file.

3. Open the file, and edit the following:
   - the name of the ruby class: this is important, homebrew won't function properly if the class name is wrong. Check historical versions to figure out the right name.
   - under the `bottle` section, remove all of the hashes (lines starting with `sha256`)

4. Create a **Python 3.7** virtualenv, activate it, and then install two packages: `homebrew-pypi-poet`, and the version of dbt you are preparing. I use:

```
pyenv virtualenv 3.7.0 homebrew-dbt-{VERSION}
pyenv activate homebrew-dbt-{VERSION}
pip install dbt=={VERSION} homebrew-pypi-poet
```

homebrew-pypi-poet is a program that generates a valid homebrew formula for an installed pip package. You want to use it to generate a diff against the existing formula. Then you want to apply the diff for the dependency packages only -- e.g. it will tell you that `google-api-core` has been updated and that you need to use the latest version.

5. reinstall, test, and audit dbt. if the test or audit fails, fix the formula with step 1.

```bash
brew uninstall --force Formula/{YOUR-FILE}.rb
brew install Formula/{YOUR-FILE}.rb
brew test dbt
brew audit --strict dbt
```

6. Ask Connor to bottle the change (only his laptop can do it!)

#### Conda Forge Release Process

1. Clone the fork of `conda-forge/dbt-feedstock` [here](https://github.com/fishtown-analytics/dbt-feedstock)

```bash
git clone git@github.com:fishtown-analytics/dbt-feedstock.git
```

2. Update the version and sha256 in `recipe/meta.yml`. To calculate the sha256, run:

```bash
wget https://github.com/fishtown-analytics/dbt/archive/v{version}.tar.gz
openssl sha256 v{version}.tar.gz
```

3. Push the changes and create a PR against `conda-forge/dbt-feedstock`

4. Confirm that all automated conda-forge tests are passing
CI pipeline config (Azure Pipelines):

```diff
@@ -139,7 +139,7 @@ jobs:
       inputs:
         versionSpec: '3.7'
         architecture: 'x64'
-    - script: python -m pip install --upgrade pip setuptools && python -m pip install -r requirements.txt && python -m pip install -r dev_requirements.txt
+    - script: python -m pip install --upgrade pip setuptools && python -m pip install -r requirements.txt && python -m pip install -r requirements-dev.txt
       displayName: Install dependencies
     - task: ShellScript@2
       inputs:
```
A Python helper script (package registry JSON):

```diff
@@ -63,11 +63,13 @@ def main():
     packages = registry.packages()
     project_json = init_project_in_packages(args, packages)
     if args.project["version"] in project_json["versions"]:
-        raise Exception("Version {} already in packages JSON"
-                        .format(args.project["version"]),
-                        file=sys.stderr)
+        raise Exception(
+            "Version {} already in packages JSON".format(args.project["version"]),
+            file=sys.stderr,
+        )
     add_version_to_package(args, project_json)
     print(json.dumps(packages, indent=2))
 
 
 if __name__ == "__main__":
     main()
```
`core/dbt/adapters/base/column.py`:

```diff
@@ -1,19 +1,17 @@
 from dataclasses import dataclass
 import re
 
-from hologram import JsonSchemaMixin
-from dbt.exceptions import RuntimeException
-
 from typing import Dict, ClassVar, Any, Optional
 
+from dbt.exceptions import RuntimeException
+
 
 @dataclass
-class Column(JsonSchemaMixin):
+class Column:
     TYPE_LABELS: ClassVar[Dict[str, str]] = {
-        'STRING': 'TEXT',
-        'TIMESTAMP': 'TIMESTAMP',
-        'FLOAT': 'FLOAT',
-        'INTEGER': 'INT'
+        "STRING": "TEXT",
+        "TIMESTAMP": "TIMESTAMP",
+        "FLOAT": "FLOAT",
+        "INTEGER": "INT",
     }
     column: str
     dtype: str
@@ -26,7 +24,7 @@ class Column(JsonSchemaMixin):
         return cls.TYPE_LABELS.get(dtype.upper(), dtype)
 
     @classmethod
-    def create(cls, name, label_or_dtype: str) -> 'Column':
+    def create(cls, name, label_or_dtype: str) -> "Column":
         column_type = cls.translate_type(label_or_dtype)
         return cls(name, column_type)
 
@@ -43,14 +41,19 @@ class Column(JsonSchemaMixin):
         if self.is_string():
             return Column.string_type(self.string_size())
         elif self.is_numeric():
-            return Column.numeric_type(self.dtype, self.numeric_precision,
-                                       self.numeric_scale)
+            return Column.numeric_type(
+                self.dtype, self.numeric_precision, self.numeric_scale
+            )
         else:
             return self.dtype
 
     def is_string(self) -> bool:
-        return self.dtype.lower() in ['text', 'character varying', 'character',
-                                      'varchar']
+        return self.dtype.lower() in [
+            "text",
+            "character varying",
+            "character",
+            "varchar",
+        ]
 
     def is_number(self):
         return any([self.is_integer(), self.is_numeric(), self.is_float()])
@@ -58,33 +61,45 @@ class Column(JsonSchemaMixin):
     def is_float(self):
         return self.dtype.lower() in [
             # floats
-            'real', 'float4', 'float', 'double precision', 'float8'
+            "real",
+            "float4",
+            "float",
+            "double precision",
+            "float8",
         ]
 
     def is_integer(self) -> bool:
         return self.dtype.lower() in [
             # real types
-            'smallint', 'integer', 'bigint',
-            'smallserial', 'serial', 'bigserial',
+            "smallint",
+            "integer",
+            "bigint",
+            "smallserial",
+            "serial",
+            "bigserial",
             # aliases
-            'int2', 'int4', 'int8',
-            'serial2', 'serial4', 'serial8',
+            "int2",
+            "int4",
+            "int8",
+            "serial2",
+            "serial4",
+            "serial8",
         ]
 
     def is_numeric(self) -> bool:
-        return self.dtype.lower() in ['numeric', 'decimal']
+        return self.dtype.lower() in ["numeric", "decimal"]
 
     def string_size(self) -> int:
         if not self.is_string():
             raise RuntimeException("Called string_size() on non-string field!")
 
-        if self.dtype == 'text' or self.char_size is None:
+        if self.dtype == "text" or self.char_size is None:
             # char_size should never be None. Handle it reasonably just in case
             return 256
         else:
             return int(self.char_size)
 
-    def can_expand_to(self, other_column: 'Column') -> bool:
+    def can_expand_to(self, other_column: "Column") -> bool:
         """returns True if this column can be expanded to the size of the
         other column"""
         if not self.is_string() or not other_column.is_string():
@@ -112,12 +127,10 @@ class Column(JsonSchemaMixin):
         return "<Column {} ({})>".format(self.name, self.data_type)
 
     @classmethod
-    def from_description(cls, name: str, raw_data_type: str) -> 'Column':
-        match = re.match(r'([^(]+)(\([^)]+\))?', raw_data_type)
+    def from_description(cls, name: str, raw_data_type: str) -> "Column":
+        match = re.match(r"([^(]+)(\([^)]+\))?", raw_data_type)
         if match is None:
-            raise RuntimeException(
-                f'Could not interpret data type "{raw_data_type}"'
-            )
+            raise RuntimeException(f'Could not interpret data type "{raw_data_type}"')
         data_type, size_info = match.groups()
         char_size = None
         numeric_precision = None
@@ -125,7 +138,7 @@ class Column(JsonSchemaMixin):
         if size_info is not None:
             # strip out the parentheses
             size_info = size_info[1:-1]
-            parts = size_info.split(',')
+            parts = size_info.split(",")
             if len(parts) == 1:
                 try:
                     char_size = int(parts[0])
@@ -150,6 +163,4 @@ class Column(JsonSchemaMixin):
                     f'could not convert "{parts[1]}" to an integer'
                 )
 
-        return cls(
-            name, data_type, char_size, numeric_precision, numeric_scale
-        )
+        return cls(name, data_type, char_size, numeric_precision, numeric_scale)
```
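As a quick usage sketch of the methods reformatted above (illustrative, not part of the diff; the import path matches the one used in `impl.py` below):

```python
from dbt.adapters.base import Column

# from_description() splits "character varying(256)" into a dtype and a size.
col = Column.from_description("email", "character varying(256)")
print(col.dtype)          # 'character varying'
print(col.is_string())    # True  -- matches the string dtype list above
print(col.string_size())  # 256   -- parsed from the parenthesized size
```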
`core/dbt/adapters/base/connections.py`:

```diff
@@ -1,18 +1,21 @@
 import abc
 import os
 
 # multiprocessing.RLock is a function returning this type
 from multiprocessing.synchronize import RLock
 from threading import get_ident
-from typing import (
-    Dict, Tuple, Hashable, Optional, ContextManager, List, Union
-)
+from typing import Dict, Tuple, Hashable, Optional, ContextManager, List, Union
 
 import agate
 
 import dbt.exceptions
 from dbt.contracts.connection import (
-    Connection, Identifier, ConnectionState,
-    AdapterRequiredConfig, LazyHandle, AdapterResponse
+    Connection,
+    Identifier,
+    ConnectionState,
+    AdapterRequiredConfig,
+    LazyHandle,
+    AdapterResponse,
 )
 from dbt.contracts.graph.manifest import Manifest
 from dbt.adapters.base.query_headers import (
@@ -35,6 +38,7 @@ class BaseConnectionManager(metaclass=abc.ABCMeta):
     You must also set the 'TYPE' class attribute with a class-unique constant
     string.
     """
+
     TYPE: str = NotImplemented
 
     def __init__(self, profile: AdapterRequiredConfig):
@@ -65,7 +69,7 @@ class BaseConnectionManager(metaclass=abc.ABCMeta):
         key = self.get_thread_identifier()
         if key in self.thread_connections:
             raise dbt.exceptions.InternalException(
-                'In set_thread_connection, existing connection exists for {}'
+                "In set_thread_connection, existing connection exists for {}"
             )
         self.thread_connections[key] = conn
 
@@ -105,18 +109,19 @@ class BaseConnectionManager(metaclass=abc.ABCMeta):
         underlying database.
         """
         raise dbt.exceptions.NotImplementedException(
-            '`exception_handler` is not implemented for this adapter!')
+            "`exception_handler` is not implemented for this adapter!"
+        )
 
     def set_connection_name(self, name: Optional[str] = None) -> Connection:
         conn_name: str
         if name is None:
             # if a name isn't specified, we'll re-use a single handle
             # named 'master'
-            conn_name = 'master'
+            conn_name = "master"
         else:
             if not isinstance(name, str):
                 raise dbt.exceptions.CompilerException(
-                    f'For connection name, got {name} - not a string!'
+                    f"For connection name, got {name} - not a string!"
                 )
             assert isinstance(name, str)
             conn_name = name
@@ -129,20 +134,20 @@ class BaseConnectionManager(metaclass=abc.ABCMeta):
             state=ConnectionState.INIT,
             transaction_open=False,
             handle=None,
-            credentials=self.profile.credentials
+            credentials=self.profile.credentials,
         )
         self.set_thread_connection(conn)
 
-        if conn.name == conn_name and conn.state == 'open':
+        if conn.name == conn_name and conn.state == "open":
             return conn
 
-        logger.debug(
-            'Acquiring new {} connection "{}".'.format(self.TYPE, conn_name))
+        logger.debug('Acquiring new {} connection "{}".'.format(self.TYPE, conn_name))
 
-        if conn.state == 'open':
+        if conn.state == "open":
             logger.debug(
-                'Re-using an available connection from the pool (formerly {}).'
-                .format(conn.name)
+                "Re-using an available connection from the pool (formerly {}).".format(
+                    conn.name
+                )
             )
         else:
             conn.handle = LazyHandle(self.open)
@@ -154,7 +159,7 @@ class BaseConnectionManager(metaclass=abc.ABCMeta):
     def cancel_open(self) -> Optional[List[str]]:
         """Cancel all open connections on the adapter. (passable)"""
         raise dbt.exceptions.NotImplementedException(
-            '`cancel_open` is not implemented for this adapter!'
+            "`cancel_open` is not implemented for this adapter!"
         )
 
     @abc.abstractclassmethod
@@ -168,7 +173,7 @@ class BaseConnectionManager(metaclass=abc.ABCMeta):
         connection should not be in either in_use or available.
         """
         raise dbt.exceptions.NotImplementedException(
-            '`open` is not implemented for this adapter!'
+            "`open` is not implemented for this adapter!"
         )
 
     def release(self) -> None:
@@ -189,12 +194,14 @@ class BaseConnectionManager(metaclass=abc.ABCMeta):
     def cleanup_all(self) -> None:
         with self.lock:
             for connection in self.thread_connections.values():
-                if connection.state not in {'closed', 'init'}:
-                    logger.debug("Connection '{}' was left open."
-                                 .format(connection.name))
+                if connection.state not in {"closed", "init"}:
+                    logger.debug(
+                        "Connection '{}' was left open.".format(connection.name)
+                    )
                 else:
-                    logger.debug("Connection '{}' was properly closed."
-                                 .format(connection.name))
+                    logger.debug(
+                        "Connection '{}' was properly closed.".format(connection.name)
+                    )
                 self.close(connection)
 
             # garbage collect these connections
@@ -204,14 +211,14 @@ class BaseConnectionManager(metaclass=abc.ABCMeta):
     def begin(self) -> None:
         """Begin a transaction. (passable)"""
         raise dbt.exceptions.NotImplementedException(
-            '`begin` is not implemented for this adapter!'
+            "`begin` is not implemented for this adapter!"
         )
 
     @abc.abstractmethod
     def commit(self) -> None:
         """Commit a transaction. (passable)"""
         raise dbt.exceptions.NotImplementedException(
-            '`commit` is not implemented for this adapter!'
+            "`commit` is not implemented for this adapter!"
        )
 
     @classmethod
@@ -220,20 +227,17 @@ class BaseConnectionManager(metaclass=abc.ABCMeta):
         try:
             connection.handle.rollback()
         except Exception:
-            logger.debug(
-                'Failed to rollback {}'.format(connection.name),
-                exc_info=True
-            )
+            logger.debug("Failed to rollback {}".format(connection.name), exc_info=True)
 
     @classmethod
     def _close_handle(cls, connection: Connection) -> None:
         """Perform the actual close operation."""
         # On windows, sometimes connection handles don't have a close() attr.
-        if hasattr(connection.handle, 'close'):
-            logger.debug(f'On {connection.name}: Close')
+        if hasattr(connection.handle, "close"):
+            logger.debug(f"On {connection.name}: Close")
             connection.handle.close()
         else:
-            logger.debug(f'On {connection.name}: No close available on handle')
+            logger.debug(f"On {connection.name}: No close available on handle")
 
     @classmethod
     def _rollback(cls, connection: Connection) -> None:
@@ -241,16 +245,16 @@ class BaseConnectionManager(metaclass=abc.ABCMeta):
         if flags.STRICT_MODE:
             if not isinstance(connection, Connection):
                 raise dbt.exceptions.CompilerException(
-                    f'In _rollback, got {connection} - not a Connection!'
+                    f"In _rollback, got {connection} - not a Connection!"
                 )
 
         if connection.transaction_open is False:
             raise dbt.exceptions.InternalException(
-                f'Tried to rollback transaction on connection '
+                f"Tried to rollback transaction on connection "
                 f'"{connection.name}", but it does not have one open!'
             )
 
-        logger.debug(f'On {connection.name}: ROLLBACK')
+        logger.debug(f"On {connection.name}: ROLLBACK")
         cls._rollback_handle(connection)
 
         connection.transaction_open = False
@@ -260,7 +264,7 @@ class BaseConnectionManager(metaclass=abc.ABCMeta):
         if flags.STRICT_MODE:
             if not isinstance(connection, Connection):
                 raise dbt.exceptions.CompilerException(
-                    f'In close, got {connection} - not a Connection!'
+                    f"In close, got {connection} - not a Connection!"
                 )
 
         # if the connection is in closed or init, there's nothing to do
@@ -268,7 +272,7 @@ class BaseConnectionManager(metaclass=abc.ABCMeta):
             return connection
 
         if connection.transaction_open and connection.handle:
-            logger.debug('On {}: ROLLBACK'.format(connection.name))
+            logger.debug("On {}: ROLLBACK".format(connection.name))
             cls._rollback_handle(connection)
         connection.transaction_open = False
 
@@ -302,5 +306,5 @@ class BaseConnectionManager(metaclass=abc.ABCMeta):
         :rtype: Tuple[Union[str, AdapterResponse], agate.Table]
         """
         raise dbt.exceptions.NotImplementedException(
-            '`execute` is not implemented for this adapter!'
+            "`execute` is not implemented for this adapter!"
         )
```
@@ -4,17 +4,31 @@ from contextlib import contextmanager
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from itertools import chain
|
from itertools import chain
|
||||||
from typing import (
|
from typing import (
|
||||||
Optional, Tuple, Callable, Iterable, Type, Dict, Any, List, Mapping,
|
Optional,
|
||||||
Iterator, Union, Set
|
Tuple,
|
||||||
|
Callable,
|
||||||
|
Iterable,
|
||||||
|
Type,
|
||||||
|
Dict,
|
||||||
|
Any,
|
||||||
|
List,
|
||||||
|
Mapping,
|
||||||
|
Iterator,
|
||||||
|
Union,
|
||||||
|
Set,
|
||||||
)
|
)
|
||||||
|
|
||||||
import agate
|
import agate
|
||||||
import pytz
|
import pytz
|
||||||
|
|
||||||
from dbt.exceptions import (
|
from dbt.exceptions import (
|
||||||
raise_database_error, raise_compiler_error, invalid_type_error,
|
raise_database_error,
|
||||||
|
raise_compiler_error,
|
||||||
|
invalid_type_error,
|
||||||
get_relation_returned_multiple_results,
|
get_relation_returned_multiple_results,
|
||||||
InternalException, NotImplementedException, RuntimeException,
|
InternalException,
|
||||||
|
NotImplementedException,
|
||||||
|
RuntimeException,
|
||||||
)
|
)
|
||||||
from dbt import flags
|
from dbt import flags
|
||||||
|
|
||||||
@@ -25,10 +39,8 @@ from dbt.adapters.protocol import (
|
|||||||
)
|
)
|
||||||
from dbt.clients.agate_helper import empty_table, merge_tables, table_from_rows
|
from dbt.clients.agate_helper import empty_table, merge_tables, table_from_rows
|
||||||
from dbt.clients.jinja import MacroGenerator
|
from dbt.clients.jinja import MacroGenerator
|
||||||
from dbt.contracts.graph.compiled import (
|
from dbt.contracts.graph.compiled import CompileResultNode, CompiledSeedNode
|
||||||
CompileResultNode, CompiledSeedNode
|
from dbt.contracts.graph.manifest import Manifest, MacroManifest
|
||||||
)
|
|
||||||
from dbt.contracts.graph.manifest import Manifest
|
|
||||||
from dbt.contracts.graph.parsed import ParsedSeedNode
|
from dbt.contracts.graph.parsed import ParsedSeedNode
|
||||||
from dbt.exceptions import warn_or_error
|
from dbt.exceptions import warn_or_error
|
||||||
from dbt.node_types import NodeType
|
from dbt.node_types import NodeType
|
||||||
@@ -38,7 +50,10 @@ from dbt.utils import filter_null_values, executor
 from dbt.adapters.base.connections import Connection, AdapterResponse
 from dbt.adapters.base.meta import AdapterMeta, available
 from dbt.adapters.base.relation import (
-ComponentName, BaseRelation, InformationSchema, SchemaSearchMap
+ComponentName,
+BaseRelation,
+InformationSchema,
+SchemaSearchMap,
 )
 from dbt.adapters.base import Column as BaseColumn
 from dbt.adapters.cache import RelationsCache
@@ -47,15 +62,14 @@ from dbt.adapters.cache import RelationsCache
 SeedModel = Union[ParsedSeedNode, CompiledSeedNode]


-GET_CATALOG_MACRO_NAME = 'get_catalog'
-FRESHNESS_MACRO_NAME = 'collect_freshness'
+GET_CATALOG_MACRO_NAME = "get_catalog"
+FRESHNESS_MACRO_NAME = "collect_freshness"


 def _expect_row_value(key: str, row: agate.Row):
 if key not in row.keys():
 raise InternalException(
-'Got a row without "{}" column, columns: {}'
-.format(key, row.keys())
+'Got a row without "{}" column, columns: {}'.format(key, row.keys())
 )
 return row[key]

@@ -64,40 +78,37 @@ def _catalog_filter_schemas(manifest: Manifest) -> Callable[[agate.Row], bool]:
 """Return a function that takes a row and decides if the row should be
 included in the catalog output.
 """
-schemas = frozenset((d.lower(), s.lower())
-for d, s in manifest.get_used_schemas())
+schemas = frozenset((d.lower(), s.lower()) for d, s in manifest.get_used_schemas())

 def test(row: agate.Row) -> bool:
-table_database = _expect_row_value('table_database', row)
-table_schema = _expect_row_value('table_schema', row)
+table_database = _expect_row_value("table_database", row)
+table_schema = _expect_row_value("table_schema", row)
 # the schema may be present but None, which is not an error and should
 # be filtered out
 if table_schema is None:
 return False
 return (table_database.lower(), table_schema.lower()) in schemas

 return test


-def _utc(
-dt: Optional[datetime], source: BaseRelation, field_name: str
-) -> datetime:
+def _utc(dt: Optional[datetime], source: BaseRelation, field_name: str) -> datetime:
 """If dt has a timezone, return a new datetime that's in UTC. Otherwise,
 assume the datetime is already for UTC and add the timezone.
 """
 if dt is None:
 raise raise_database_error(
 "Expected a non-null value when querying field '{}' of table "
-" {} but received value 'null' instead".format(
-field_name,
-source))
+" {} but received value 'null' instead".format(field_name, source)
+)

-elif not hasattr(dt, 'tzinfo'):
+elif not hasattr(dt, "tzinfo"):
 raise raise_database_error(
 "Expected a timestamp value when querying field '{}' of table "
 "{} but received value of type '{}' instead".format(
-field_name,
-source,
-type(dt).__name__))
+field_name, source, type(dt).__name__
+)
+)

 elif dt.tzinfo:
 return dt.astimezone(pytz.UTC)
@@ -107,7 +118,7 @@ def _utc(

 def _relation_name(rel: Optional[BaseRelation]) -> str:
 if rel is None:
-return 'null relation'
+return "null relation"
 else:
 return str(rel)

@@ -148,6 +159,7 @@ class BaseAdapter(metaclass=AdapterMeta):
 Macros:
 - get_catalog
 """

 Relation: Type[BaseRelation] = BaseRelation
 Column: Type[BaseColumn] = BaseColumn
 ConnectionManager: Type[ConnectionManagerProtocol]
@@ -160,7 +172,7 @@ class BaseAdapter(metaclass=AdapterMeta):
 self.config = config
 self.cache = RelationsCache()
 self.connections = self.ConnectionManager(config)
-self._macro_manifest_lazy: Optional[Manifest] = None
+self._macro_manifest_lazy: Optional[MacroManifest] = None

 ###
 # Methods that pass through to the connection manager
@@ -181,12 +193,12 @@ class BaseAdapter(metaclass=AdapterMeta):
 self.connections.commit_if_has_connection()

 def debug_query(self) -> None:
-self.execute('select 1 as id')
+self.execute("select 1 as id")

 def nice_connection_name(self) -> str:
 conn = self.connections.get_if_exists()
 if conn is None or conn.name is None:
-return '<None>'
+return "<None>"
 return conn.name

 @contextmanager
@@ -204,13 +216,11 @@ class BaseAdapter(metaclass=AdapterMeta):
 self.connections.query_header.reset()

 @contextmanager
-def connection_for(
-self, node: CompileResultNode
-) -> Iterator[None]:
+def connection_for(self, node: CompileResultNode) -> Iterator[None]:
 with self.connection_named(node.unique_id, node):
 yield

-@available.parse(lambda *a, **k: ('', empty_table()))
+@available.parse(lambda *a, **k: ("", empty_table()))
 def execute(
 self, sql: str, auto_begin: bool = False, fetch: bool = False
 ) -> Tuple[Union[str, AdapterResponse], agate.Table]:
@@ -224,16 +234,10 @@ class BaseAdapter(metaclass=AdapterMeta):
 :return: A tuple of the status and the results (empty if fetch=False).
 :rtype: Tuple[Union[str, AdapterResponse], agate.Table]
 """
-return self.connections.execute(
-sql=sql,
-auto_begin=auto_begin,
-fetch=fetch
-)
+return self.connections.execute(sql=sql, auto_begin=auto_begin, fetch=fetch)

-@available.parse(lambda *a, **k: ('', empty_table()))
-def get_partitions_metadata(
-self, table: str
-) -> Tuple[agate.Table]:
+@available.parse(lambda *a, **k: ("", empty_table()))
+def get_partitions_metadata(self, table: str) -> Tuple[agate.Table]:
 """Obtain partitions metadata for a BigQuery partitioned table.

 :param str table_id: a partitioned table id, in standard SQL format.
@@ -241,9 +245,7 @@ class BaseAdapter(metaclass=AdapterMeta):
 https://cloud.google.com/bigquery/docs/creating-partitioned-tables#getting_partition_metadata_using_meta_tables.
 :rtype: agate.Table
 """
-return self.connections.get_partitions_metadata(
-table=table
-)
+return self.connections.get_partitions_metadata(table=table)

 ###
 # Methods that should never be overridden
@@ -259,21 +261,22 @@ class BaseAdapter(metaclass=AdapterMeta):
 return cls.ConnectionManager.TYPE

 @property
-def _macro_manifest(self) -> Manifest:
+def _macro_manifest(self) -> MacroManifest:
 if self._macro_manifest_lazy is None:
 return self.load_macro_manifest()
 return self._macro_manifest_lazy

-def check_macro_manifest(self) -> Optional[Manifest]:
+def check_macro_manifest(self) -> Optional[MacroManifest]:
 """Return the internal manifest (used for executing macros) if it's
 been initialized, otherwise return None.
 """
 return self._macro_manifest_lazy

-def load_macro_manifest(self) -> Manifest:
+def load_macro_manifest(self) -> MacroManifest:
 if self._macro_manifest_lazy is None:
 # avoid a circular import
 from dbt.parser.manifest import load_macro_manifest

 manifest = load_macro_manifest(
 self.config, self.connections.set_query_header
 )
@@ -294,8 +297,9 @@ class BaseAdapter(metaclass=AdapterMeta):
 return False
 elif (database, schema) not in self.cache:
 logger.debug(
-'On "{}": cache miss for schema "{}.{}", this is inefficient'
-.format(self.nice_connection_name(), database, schema)
+'On "{}": cache miss for schema "{}.{}", this is inefficient'.format(
+self.nice_connection_name(), database, schema
+)
 )
 return False
 else:
@@ -310,8 +314,8 @@ class BaseAdapter(metaclass=AdapterMeta):
 self.Relation.create_from(self.config, node).without_identifier()
 for node in manifest.nodes.values()
 if (
-node.resource_type in NodeType.executable() and
-not node.is_ephemeral_model
+node.resource_type in NodeType.executable()
+and not node.is_ephemeral_model
 )
 }

@@ -351,9 +355,9 @@ class BaseAdapter(metaclass=AdapterMeta):
 for cache_schema in cache_schemas:
 fut = tpe.submit_connected(
 self,
-f'list_{cache_schema.database}_{cache_schema.schema}',
+f"list_{cache_schema.database}_{cache_schema.schema}",
 self.list_relations_without_caching,
-cache_schema
+cache_schema,
 )
 futures.append(fut)

@@ -371,9 +375,7 @@ class BaseAdapter(metaclass=AdapterMeta):
 cache_update.add((relation.database, relation.schema))
 self.cache.update_schemas(cache_update)

-def set_relations_cache(
-self, manifest: Manifest, clear: bool = False
-) -> None:
+def set_relations_cache(self, manifest: Manifest, clear: bool = False) -> None:
 """Run a query that gets a populated cache of the relations in the
 database and set the cache on this adapter.
 """
@@ -391,12 +393,12 @@ class BaseAdapter(metaclass=AdapterMeta):
 if relation is None:
 name = self.nice_connection_name()
 raise_compiler_error(
-'Attempted to cache a null relation for {}'.format(name)
+"Attempted to cache a null relation for {}".format(name)
 )
 if flags.USE_CACHE:
 self.cache.add(relation)
 # so jinja doesn't render things
-return ''
+return ""

 @available
 def cache_dropped(self, relation: Optional[BaseRelation]) -> str:
@@ -406,11 +408,11 @@ class BaseAdapter(metaclass=AdapterMeta):
 if relation is None:
 name = self.nice_connection_name()
 raise_compiler_error(
-'Attempted to drop a null relation for {}'.format(name)
+"Attempted to drop a null relation for {}".format(name)
 )
 if flags.USE_CACHE:
 self.cache.drop(relation)
-return ''
+return ""

 @available
 def cache_renamed(
@@ -426,13 +428,12 @@ class BaseAdapter(metaclass=AdapterMeta):
 src_name = _relation_name(from_relation)
 dst_name = _relation_name(to_relation)
 raise_compiler_error(
-'Attempted to rename {} to {} for {}'
-.format(src_name, dst_name, name)
+"Attempted to rename {} to {} for {}".format(src_name, dst_name, name)
 )

 if flags.USE_CACHE:
 self.cache.rename(from_relation, to_relation)
-return ''
+return ""

 ###
 # Abstract methods for database-specific values, attributes, and types
@@ -441,12 +442,13 @@ class BaseAdapter(metaclass=AdapterMeta):
 def date_function(cls) -> str:
 """Get the date function used by this adapter's database."""
 raise NotImplementedException(
-'`date_function` is not implemented for this adapter!')
+"`date_function` is not implemented for this adapter!"
+)

 @abc.abstractclassmethod
 def is_cancelable(cls) -> bool:
 raise NotImplementedException(
-'`is_cancelable` is not implemented for this adapter!'
+"`is_cancelable` is not implemented for this adapter!"
 )

 ###
@@ -456,7 +458,7 @@ class BaseAdapter(metaclass=AdapterMeta):
 def list_schemas(self, database: str) -> List[str]:
 """Get a list of existing schemas in database"""
 raise NotImplementedException(
-'`list_schemas` is not implemented for this adapter!'
+"`list_schemas` is not implemented for this adapter!"
 )

 @available.parse(lambda *a, **k: False)
@@ -467,10 +469,7 @@ class BaseAdapter(metaclass=AdapterMeta):
 and adapters should implement it if there is an optimized path (and
 there probably is)
 """
-search = (
-s.lower() for s in
-self.list_schemas(database=database)
-)
+search = (s.lower() for s in self.list_schemas(database=database))
 return schema.lower() in search

 ###
@@ -484,7 +483,7 @@ class BaseAdapter(metaclass=AdapterMeta):
 *Implementors must call self.cache.drop() to preserve cache state!*
 """
 raise NotImplementedException(
-'`drop_relation` is not implemented for this adapter!'
+"`drop_relation` is not implemented for this adapter!"
 )

 @abc.abstractmethod
@@ -492,7 +491,7 @@ class BaseAdapter(metaclass=AdapterMeta):
 def truncate_relation(self, relation: BaseRelation) -> None:
 """Truncate the given relation."""
 raise NotImplementedException(
-'`truncate_relation` is not implemented for this adapter!'
+"`truncate_relation` is not implemented for this adapter!"
 )

 @abc.abstractmethod
@@ -505,36 +504,30 @@ class BaseAdapter(metaclass=AdapterMeta):
 Implementors must call self.cache.rename() to preserve cache state.
 """
 raise NotImplementedException(
-'`rename_relation` is not implemented for this adapter!'
+"`rename_relation` is not implemented for this adapter!"
 )

 @abc.abstractmethod
 @available.parse_list
-def get_columns_in_relation(
-self, relation: BaseRelation
-) -> List[BaseColumn]:
+def get_columns_in_relation(self, relation: BaseRelation) -> List[BaseColumn]:
 """Get a list of the columns in the given Relation."""
 raise NotImplementedException(
-'`get_columns_in_relation` is not implemented for this adapter!'
+"`get_columns_in_relation` is not implemented for this adapter!"
 )

-@available.deprecated('get_columns_in_relation', lambda *a, **k: [])
-def get_columns_in_table(
-self, schema: str, identifier: str
-) -> List[BaseColumn]:
+@available.deprecated("get_columns_in_relation", lambda *a, **k: [])
+def get_columns_in_table(self, schema: str, identifier: str) -> List[BaseColumn]:
 """DEPRECATED: Get a list of the columns in the given table."""
 relation = self.Relation.create(
 database=self.config.credentials.database,
 schema=schema,
 identifier=identifier,
-quote_policy=self.config.quoting
+quote_policy=self.config.quoting,
 )
 return self.get_columns_in_relation(relation)

 @abc.abstractmethod
-def expand_column_types(
-self, goal: BaseRelation, current: BaseRelation
-) -> None:
+def expand_column_types(self, goal: BaseRelation, current: BaseRelation) -> None:
 """Expand the current table's types to match the goal table. (passable)

 :param self.Relation goal: A relation that currently exists in the
@@ -543,7 +536,7 @@ class BaseAdapter(metaclass=AdapterMeta):
 database with columns of unspecified types.
 """
 raise NotImplementedException(
-'`expand_target_column_types` is not implemented for this adapter!'
+"`expand_target_column_types` is not implemented for this adapter!"
 )

 @abc.abstractmethod
@@ -560,8 +553,7 @@ class BaseAdapter(metaclass=AdapterMeta):
 :rtype: List[self.Relation]
 """
 raise NotImplementedException(
-'`list_relations_without_caching` is not implemented for this '
-'adapter!'
+"`list_relations_without_caching` is not implemented for this " "adapter!"
 )

 ###
@@ -576,32 +568,33 @@ class BaseAdapter(metaclass=AdapterMeta):
 """
 if not isinstance(from_relation, self.Relation):
 invalid_type_error(
-method_name='get_missing_columns',
-arg_name='from_relation',
+method_name="get_missing_columns",
+arg_name="from_relation",
 got_value=from_relation,
-expected_type=self.Relation)
+expected_type=self.Relation,
+)

 if not isinstance(to_relation, self.Relation):
 invalid_type_error(
-method_name='get_missing_columns',
-arg_name='to_relation',
+method_name="get_missing_columns",
+arg_name="to_relation",
 got_value=to_relation,
-expected_type=self.Relation)
+expected_type=self.Relation,
+)

 from_columns = {
-col.name: col for col in
-self.get_columns_in_relation(from_relation)
+col.name: col for col in self.get_columns_in_relation(from_relation)
 }

 to_columns = {
-col.name: col for col in
-self.get_columns_in_relation(to_relation)
+col.name: col for col in self.get_columns_in_relation(to_relation)
 }

 missing_columns = set(from_columns.keys()) - set(to_columns.keys())

 return [
-col for (col_name, col) in from_columns.items()
+col
+for (col_name, col) in from_columns.items()
 if col_name in missing_columns
 ]

@@ -616,18 +609,19 @@ class BaseAdapter(metaclass=AdapterMeta):
 """
 if not isinstance(relation, self.Relation):
 invalid_type_error(
-method_name='valid_snapshot_target',
-arg_name='relation',
+method_name="valid_snapshot_target",
+arg_name="relation",
 got_value=relation,
-expected_type=self.Relation)
+expected_type=self.Relation,
+)

 columns = self.get_columns_in_relation(relation)
 names = set(c.name.lower() for c in columns)
-expanded_keys = ('scd_id', 'valid_from', 'valid_to')
+expanded_keys = ("scd_id", "valid_from", "valid_to")
 extra = []
 missing = []
 for legacy in expanded_keys:
-desired = 'dbt_' + legacy
+desired = "dbt_" + legacy
 if desired not in names:
 missing.append(desired)
 if legacy in names:
@@ -637,13 +631,13 @@ class BaseAdapter(metaclass=AdapterMeta):
 if extra:
 msg = (
 'Snapshot target has ("{}") but not ("{}") - is it an '
-'unmigrated previous version archive?'
-.format('", "'.join(extra), '", "'.join(missing))
+"unmigrated previous version archive?".format(
+'", "'.join(extra), '", "'.join(missing)
+)
 )
 else:
-msg = (
-'Snapshot target is not a snapshot table (missing "{}")'
-.format('", "'.join(missing))
+msg = 'Snapshot target is not a snapshot table (missing "{}")'.format(
+'", "'.join(missing)
 )
 raise_compiler_error(msg)

@@ -653,17 +647,19 @@ class BaseAdapter(metaclass=AdapterMeta):
 ) -> None:
 if not isinstance(from_relation, self.Relation):
 invalid_type_error(
-method_name='expand_target_column_types',
-arg_name='from_relation',
+method_name="expand_target_column_types",
+arg_name="from_relation",
 got_value=from_relation,
-expected_type=self.Relation)
+expected_type=self.Relation,
+)

 if not isinstance(to_relation, self.Relation):
 invalid_type_error(
-method_name='expand_target_column_types',
-arg_name='to_relation',
+method_name="expand_target_column_types",
+arg_name="to_relation",
 got_value=to_relation,
-expected_type=self.Relation)
+expected_type=self.Relation,
+)

 self.expand_column_types(from_relation, to_relation)

@@ -676,38 +672,41 @@ class BaseAdapter(metaclass=AdapterMeta):
 schema_relation = self.Relation.create(
 database=database,
 schema=schema,
-identifier='',
-quote_policy=self.config.quoting
+identifier="",
+quote_policy=self.config.quoting,
 ).without_identifier()

 # we can't build the relations cache because we don't have a
 # manifest so we can't run any operations.
-relations = self.list_relations_without_caching(
-schema_relation
-)
+relations = self.list_relations_without_caching(schema_relation)

-logger.debug('with database={}, schema={}, relations={}'
-.format(database, schema, relations))
+logger.debug(
+"with database={}, schema={}, relations={}".format(
+database, schema, relations
+)
+)
 return relations

 def _make_match_kwargs(
 self, database: str, schema: str, identifier: str
 ) -> Dict[str, str]:
 quoting = self.config.quoting
-if identifier is not None and quoting['identifier'] is False:
+if identifier is not None and quoting["identifier"] is False:
 identifier = identifier.lower()

-if schema is not None and quoting['schema'] is False:
+if schema is not None and quoting["schema"] is False:
 schema = schema.lower()

-if database is not None and quoting['database'] is False:
+if database is not None and quoting["database"] is False:
 database = database.lower()

-return filter_null_values({
-'database': database,
-'identifier': identifier,
-'schema': schema,
-})
+return filter_null_values(
+{
+"database": database,
+"identifier": identifier,
+"schema": schema,
+}
+)

 def _make_match(
 self,
@@ -733,25 +732,22 @@ class BaseAdapter(metaclass=AdapterMeta):
 ) -> Optional[BaseRelation]:
 relations_list = self.list_relations(database, schema)

-matches = self._make_match(relations_list, database, schema,
-identifier)
+matches = self._make_match(relations_list, database, schema, identifier)

 if len(matches) > 1:
 kwargs = {
-'identifier': identifier,
-'schema': schema,
-'database': database,
+"identifier": identifier,
+"schema": schema,
+"database": database,
 }
-get_relation_returned_multiple_results(
-kwargs, matches
-)
+get_relation_returned_multiple_results(kwargs, matches)

 elif matches:
 return matches[0]

 return None

-@available.deprecated('get_relation', lambda *a, **k: False)
+@available.deprecated("get_relation", lambda *a, **k: False)
 def already_exists(self, schema: str, name: str) -> bool:
 """DEPRECATED: Return if a model already exists in the database"""
 database = self.config.credentials.database
@@ -767,7 +763,7 @@ class BaseAdapter(metaclass=AdapterMeta):
 def create_schema(self, relation: BaseRelation):
 """Create the given schema if it does not exist."""
 raise NotImplementedException(
-'`create_schema` is not implemented for this adapter!'
+"`create_schema` is not implemented for this adapter!"
 )

 @abc.abstractmethod
@@ -775,16 +771,14 @@ class BaseAdapter(metaclass=AdapterMeta):
 def drop_schema(self, relation: BaseRelation):
 """Drop the given schema (and everything in it) if it exists."""
 raise NotImplementedException(
-'`drop_schema` is not implemented for this adapter!'
+"`drop_schema` is not implemented for this adapter!"
 )

 @available
 @abc.abstractclassmethod
 def quote(cls, identifier: str) -> str:
 """Quote the given identifier, as appropriate for the database."""
-raise NotImplementedException(
-'`quote` is not implemented for this adapter!'
-)
+raise NotImplementedException("`quote` is not implemented for this adapter!")

 @available
 def quote_as_configured(self, identifier: str, quote_key: str) -> str:
@@ -806,19 +800,17 @@ class BaseAdapter(metaclass=AdapterMeta):
 return identifier

 @available
-def quote_seed_column(
-self, column: str, quote_config: Optional[bool]
-) -> str:
+def quote_seed_column(self, column: str, quote_config: Optional[bool]) -> str:
 # this is the default for now
 quote_columns: bool = False
 if isinstance(quote_config, bool):
 quote_columns = quote_config
 elif quote_config is None:
-deprecations.warn('column-quoting-unset')
+deprecations.warn("column-quoting-unset")
 else:
 raise_compiler_error(
 f'The seed configuration value of "quote_columns" has an '
-f'invalid type {type(quote_config)}'
+f"invalid type {type(quote_config)}"
 )

 if quote_columns:
@@ -831,9 +823,7 @@ class BaseAdapter(metaclass=AdapterMeta):
 # converting agate types into their sql equivalents.
 ###
 @abc.abstractclassmethod
-def convert_text_type(
-cls, agate_table: agate.Table, col_idx: int
-) -> str:
+def convert_text_type(cls, agate_table: agate.Table, col_idx: int) -> str:
 """Return the type in the database that best maps to the agate.Text
 type for the given agate table and column index.

@@ -842,12 +832,11 @@ class BaseAdapter(metaclass=AdapterMeta):
 :return: The name of the type in the database
 """
 raise NotImplementedException(
-'`convert_text_type` is not implemented for this adapter!')
+"`convert_text_type` is not implemented for this adapter!"
+)

 @abc.abstractclassmethod
-def convert_number_type(
-cls, agate_table: agate.Table, col_idx: int
-) -> str:
+def convert_number_type(cls, agate_table: agate.Table, col_idx: int) -> str:
 """Return the type in the database that best maps to the agate.Number
 type for the given agate table and column index.

@@ -856,12 +845,11 @@ class BaseAdapter(metaclass=AdapterMeta):
 :return: The name of the type in the database
 """
 raise NotImplementedException(
-'`convert_number_type` is not implemented for this adapter!')
+"`convert_number_type` is not implemented for this adapter!"
+)

 @abc.abstractclassmethod
-def convert_boolean_type(
-cls, agate_table: agate.Table, col_idx: int
-) -> str:
+def convert_boolean_type(cls, agate_table: agate.Table, col_idx: int) -> str:
 """Return the type in the database that best maps to the agate.Boolean
 type for the given agate table and column index.

@@ -870,12 +858,11 @@ class BaseAdapter(metaclass=AdapterMeta):
 :return: The name of the type in the database
 """
 raise NotImplementedException(
-'`convert_boolean_type` is not implemented for this adapter!')
+"`convert_boolean_type` is not implemented for this adapter!"
+)

 @abc.abstractclassmethod
-def convert_datetime_type(
-cls, agate_table: agate.Table, col_idx: int
-) -> str:
+def convert_datetime_type(cls, agate_table: agate.Table, col_idx: int) -> str:
 """Return the type in the database that best maps to the agate.DateTime
 type for the given agate table and column index.

@@ -884,7 +871,8 @@ class BaseAdapter(metaclass=AdapterMeta):
 :return: The name of the type in the database
 """
 raise NotImplementedException(
-'`convert_datetime_type` is not implemented for this adapter!')
+"`convert_datetime_type` is not implemented for this adapter!"
+)

 @abc.abstractclassmethod
 def convert_date_type(cls, agate_table: agate.Table, col_idx: int) -> str:
@@ -896,7 +884,8 @@ class BaseAdapter(metaclass=AdapterMeta):
 :return: The name of the type in the database
 """
 raise NotImplementedException(
-'`convert_date_type` is not implemented for this adapter!')
+"`convert_date_type` is not implemented for this adapter!"
+)

 @abc.abstractclassmethod
 def convert_time_type(cls, agate_table: agate.Table, col_idx: int) -> str:
@@ -908,13 +897,12 @@ class BaseAdapter(metaclass=AdapterMeta):
 :return: The name of the type in the database
 """
 raise NotImplementedException(
-'`convert_time_type` is not implemented for this adapter!')
+"`convert_time_type` is not implemented for this adapter!"
+)

 @available
 @classmethod
-def convert_type(
-cls, agate_table: agate.Table, col_idx: int
-) -> Optional[str]:
+def convert_type(cls, agate_table: agate.Table, col_idx: int) -> Optional[str]:
 return cls.convert_agate_type(agate_table, col_idx)

 @classmethod
@@ -963,7 +951,7 @@ class BaseAdapter(metaclass=AdapterMeta):
 :param release: Ignored.
 """
 if release is not False:
-deprecations.warn('execute-macro-release')
+deprecations.warn("execute-macro-release")
 if kwargs is None:
 kwargs = {}
 if context_override is None:
@@ -977,28 +965,27 @@ class BaseAdapter(metaclass=AdapterMeta):
 )
 if macro is None:
 if project is None:
-package_name = 'any package'
+package_name = "any package"
 else:
 package_name = 'the "{}" package'.format(project)

 raise RuntimeException(
-'dbt could not find a macro with the name "{}" in {}'
-.format(macro_name, package_name)
+'dbt could not find a macro with the name "{}" in {}'.format(
+macro_name, package_name
+)
 )
 # This causes a reference cycle, as generate_runtime_macro()
 # ends up calling get_adapter, so the import has to be here.
 from dbt.context.providers import generate_runtime_macro

 macro_context = generate_runtime_macro(
-macro=macro,
-config=self.config,
-manifest=manifest,
-package_name=project
+macro=macro, config=self.config, manifest=manifest, package_name=project
 )
 macro_context.update(context_override)

 macro_function = MacroGenerator(macro, macro_context)

-with self.connections.exception_handler(f'macro {macro_name}'):
+with self.connections.exception_handler(f"macro {macro_name}"):
 result = macro_function(**kwargs)
 return result

@@ -1013,7 +1000,7 @@ class BaseAdapter(metaclass=AdapterMeta):
 table = table_from_rows(
 table.rows,
 table.column_names,
-text_only_columns=['table_database', 'table_schema', 'table_name']
+text_only_columns=["table_database", "table_schema", "table_name"],
 )
 return table.where(_catalog_filter_schemas(manifest))

@@ -1024,10 +1011,7 @@ class BaseAdapter(metaclass=AdapterMeta):
 manifest: Manifest,
 ) -> agate.Table:

-kwargs = {
-'information_schema': information_schema,
-'schemas': schemas
-}
+kwargs = {"information_schema": information_schema, "schemas": schemas}
 table = self.execute_macro(
 GET_CATALOG_MACRO_NAME,
 kwargs=kwargs,
@@ -1039,9 +1023,7 @@ class BaseAdapter(metaclass=AdapterMeta):
 results = self._catalog_filter_table(table, manifest)
 return results

-def get_catalog(
-self, manifest: Manifest
-) -> Tuple[agate.Table, List[Exception]]:
+def get_catalog(self, manifest: Manifest) -> Tuple[agate.Table, List[Exception]]:
 schema_map = self._get_catalog_schemas(manifest)

 with executor(self.config) as tpe:
@@ -1049,14 +1031,10 @@ class BaseAdapter(metaclass=AdapterMeta):
 for info, schemas in schema_map.items():
 if len(schemas) == 0:
 continue
-name = '.'.join([
-str(info.database),
-'information_schema'
-])
+name = ".".join([str(info.database), "information_schema"])

 fut = tpe.submit_connected(
-self, name,
-self._get_one_catalog, info, schemas, manifest
+self, name, self._get_one_catalog, info, schemas, manifest
 )
 futures.append(fut)

@@ -1073,20 +1051,18 @@ class BaseAdapter(metaclass=AdapterMeta):
 source: BaseRelation,
 loaded_at_field: str,
 filter: Optional[str],
-manifest: Optional[Manifest] = None
+manifest: Optional[Manifest] = None,
 ) -> Dict[str, Any]:
 """Calculate the freshness of sources in dbt, and return it"""
 kwargs: Dict[str, Any] = {
-'source': source,
-'loaded_at_field': loaded_at_field,
-'filter': filter,
+"source": source,
+"loaded_at_field": loaded_at_field,
+"filter": filter,
 }

 # run the macro
 table = self.execute_macro(
-FRESHNESS_MACRO_NAME,
-kwargs=kwargs,
-manifest=manifest
+FRESHNESS_MACRO_NAME, kwargs=kwargs, manifest=manifest
 )
 # now we have a 1-row table of the maximum `loaded_at_field` value and
 # the current time according to the db.
@@ -1106,9 +1082,9 @@ class BaseAdapter(metaclass=AdapterMeta):
 snapshotted_at = _utc(table[0][1], source, loaded_at_field)
 age = (snapshotted_at - max_loaded_at).total_seconds()
 return {
-'max_loaded_at': max_loaded_at,
-'snapshotted_at': snapshotted_at,
-'age': age,
+"max_loaded_at": max_loaded_at,
+"snapshotted_at": snapshotted_at,
+"age": age,
 }

 def pre_model_hook(self, config: Mapping[str, Any]) -> Any:
@@ -1138,6 +1114,7 @@ class BaseAdapter(metaclass=AdapterMeta):

 def get_compiler(self):
 from dbt.compilation import Compiler

 return Compiler(self.config)

 # Methods used in adapter tests
@@ -1148,13 +1125,13 @@ class BaseAdapter(metaclass=AdapterMeta):
 clause: str,
 where_clause: Optional[str] = None,
 ) -> str:
-clause = f'update {dst_name} set {dst_column} = {clause}'
+clause = f"update {dst_name} set {dst_column} = {clause}"
 if where_clause is not None:
-clause += f' where {where_clause}'
+clause += f" where {where_clause}"
 return clause

 def timestamp_add_sql(
-self, add_to: str, number: int = 1, interval: str = 'hour'
+self, add_to: str, number: int = 1, interval: str = "hour"
 ) -> str:
 # for backwards compatibility, we're compelled to set some sort of
 # default. A lot of searching has lead me to believe that the
@@ -1163,23 +1140,24 @@ class BaseAdapter(metaclass=AdapterMeta):
 return f"{add_to} + interval '{number} {interval}'"

 def string_add_sql(
-self, add_to: str, value: str, location='append',
+self,
+add_to: str,
+value: str,
+location="append",
 ) -> str:
-if location == 'append':
+if location == "append":
 return f"{add_to} || '{value}'"
-elif location == 'prepend':
+elif location == "prepend":
 return f"'{value}' || {add_to}"
 else:
-raise RuntimeException(
-f'Got an unexpected location value of "{location}"'
-)
+raise RuntimeException(f'Got an unexpected location value of "{location}"')

 def get_rows_different_sql(
 self,
 relation_a: BaseRelation,
 relation_b: BaseRelation,
 column_names: Optional[List[str]] = None,
-except_operator: str = 'EXCEPT',
+except_operator: str = "EXCEPT",
 ) -> str:
 """Generate SQL for a query that returns a single row with a two
 columns: the number of rows that are different between the two
@@ -1192,7 +1170,7 @@ class BaseAdapter(metaclass=AdapterMeta):
 names = sorted((self.quote(c.name) for c in columns))
 else:
 names = sorted((self.quote(n) for n in column_names))
-columns_csv = ', '.join(names)
+columns_csv = ", ".join(names)

 sql = COLUMNS_EQUAL_SQL.format(
 columns=columns_csv,
@@ -1204,7 +1182,7 @@ class BaseAdapter(metaclass=AdapterMeta):
 return sql


-COLUMNS_EQUAL_SQL = '''
+COLUMNS_EQUAL_SQL = """
 with diff_count as (
 SELECT
 1 as id,
@@ -1230,11 +1208,11 @@ select
 diff_count.num_missing as num_mismatched
 from row_count_diff
 join diff_count using (id)
-'''.strip()
+""".strip()


 def catch_as_completed(
-futures  # typing: List[Future[agate.Table]]
+futures,  # typing: List[Future[agate.Table]]
 ) -> Tuple[agate.Table, List[Exception]]:

 # catalogs: agate.Table = agate.Table(rows=[])
@@ -1247,15 +1225,10 @@ def catch_as_completed(
 if exc is None:
 catalog = future.result()
 tables.append(catalog)
-elif (
-isinstance(exc, KeyboardInterrupt) or
-not isinstance(exc, Exception)
-):
+elif isinstance(exc, KeyboardInterrupt) or not isinstance(exc, Exception):
 raise exc
 else:
-warn_or_error(
-f'Encountered an error while generating catalog: {str(exc)}'
-)
+warn_or_error(f"Encountered an error while generating catalog: {str(exc)}")
 # exc is not None, derives from Exception, and isn't ctrl+c
 exceptions.append(exc)
 return merge_tables(tables), exceptions
@@ -30,9 +30,11 @@ class _Available:
 x.update(big_expensive_db_query())
 return x
 """

 def inner(func):
 func._parse_replacement_ = parse_replacement
 return self(func)

 return inner

 def deprecated(
@@ -57,13 +59,14 @@ class _Available:
 The optional parse_replacement, if provided, will provide a parse-time
 replacement for the actual method (see `available.parse`).
 """

 def wrapper(func):
 func_name = func.__name__
 renamed_method(func_name, supported_name)

 @wraps(func)
 def inner(*args, **kwargs):
-warn('adapter:{}'.format(func_name))
+warn("adapter:{}".format(func_name))
 return func(*args, **kwargs)

 if parse_replacement:
@@ -71,6 +74,7 @@ class _Available:
 else:
 available_function = self
 return available_function(inner)

 return wrapper

 def parse_none(self, func: Callable) -> Callable:
@@ -109,14 +113,14 @@ class AdapterMeta(abc.ABCMeta):

 # collect base class data first
 for base in bases:
-available.update(getattr(base, '_available_', set()))
-replacements.update(getattr(base, '_parse_replacements_', set()))
+available.update(getattr(base, "_available_", set()))
+replacements.update(getattr(base, "_parse_replacements_", set()))

 # override with local data if it exists
 for name, value in namespace.items():
-if getattr(value, '_is_available_', False):
+if getattr(value, "_is_available_", False):
 available.add(name)
-parse_replacement = getattr(value, '_parse_replacement_', None)
+parse_replacement = getattr(value, "_parse_replacement_", None)
 if parse_replacement is not None:
 replacements[name] = parse_replacement

@@ -8,11 +8,10 @@ from dbt.adapters.protocol import AdapterProtocol
 def project_name_from_path(include_path: str) -> str:
 # avoid an import cycle
 from dbt.config.project import Project

 partial = Project.partial_load(include_path)
 if partial.project_name is None:
-raise CompilationException(
-f'Invalid project at {include_path}: name not set!'
-)
+raise CompilationException(f"Invalid project at {include_path}: name not set!")
 return partial.project_name


@@ -23,12 +22,13 @@ class AdapterPlugin:
 :param dependencies: A list of adapter names that this adapter depends
 upon.
 """

 def __init__(
 self,
 adapter: Type[AdapterProtocol],
 credentials: Type[Credentials],
 include_path: str,
-dependencies: Optional[List[str]] = None
+dependencies: Optional[List[str]] = None,
 ):

 self.adapter: Type[AdapterProtocol] = adapter
@@ -15,7 +15,7 @@ class NodeWrapper:
 self._inner_node = node

 def __getattr__(self, name):
-return getattr(self._inner_node, name, '')
+return getattr(self._inner_node, name, "")


 class _QueryComment(local):
@@ -24,6 +24,7 @@ class _QueryComment(local):
 - the current thread's query comment.
 - a source_name indicating what set the current thread's query comment
 """

 def __init__(self, initial):
 self.query_comment: Optional[str] = initial
 self.append = False
@@ -35,16 +36,16 @@ class _QueryComment(local):
 if self.append:
 # replace last ';' with '<comment>;'
 sql = sql.rstrip()
-if sql[-1] == ';':
+if sql[-1] == ";":
 sql = sql[:-1]
-return '{}\n/* {} */;'.format(sql, self.query_comment.strip())
+return "{}\n/* {} */;".format(sql, self.query_comment.strip())

-return '{}\n/* {} */'.format(sql, self.query_comment.strip())
+return "{}\n/* {} */".format(sql, self.query_comment.strip())

-return '/* {} */\n{}'.format(self.query_comment.strip(), sql)
+return "/* {} */\n{}".format(self.query_comment.strip(), sql)

 def set(self, comment: Optional[str], append: bool):
-if isinstance(comment, str) and '*/' in comment:
+if isinstance(comment, str) and "*/" in comment:
 # tell the user "no" so they don't hurt themselves by writing
 # garbage
 raise RuntimeException(
@@ -63,15 +64,17 @@ class MacroQueryStringSetter:
|
|||||||
self.config = config
|
self.config = config
|
||||||
|
|
||||||
comment_macro = self._get_comment_macro()
|
comment_macro = self._get_comment_macro()
|
||||||
self.generator: QueryStringFunc = lambda name, model: ''
|
self.generator: QueryStringFunc = lambda name, model: ""
|
||||||
# if the comment value was None or the empty string, just skip it
|
# if the comment value was None or the empty string, just skip it
|
||||||
if comment_macro:
|
if comment_macro:
|
||||||
assert isinstance(comment_macro, str)
|
assert isinstance(comment_macro, str)
|
||||||
macro = '\n'.join((
|
macro = "\n".join(
|
||||||
'{%- macro query_comment_macro(connection_name, node) -%}',
|
(
|
||||||
comment_macro,
|
"{%- macro query_comment_macro(connection_name, node) -%}",
|
||||||
'{% endmacro %}'
|
comment_macro,
|
||||||
))
|
"{% endmacro %}",
|
||||||
|
)
|
||||||
|
)
|
||||||
ctx = self._get_context()
|
ctx = self._get_context()
|
||||||
self.generator = QueryStringGenerator(macro, ctx)
|
self.generator = QueryStringGenerator(macro, ctx)
|
||||||
self.comment = _QueryComment(None)
|
self.comment = _QueryComment(None)
|
||||||
@@ -87,7 +90,7 @@ class MacroQueryStringSetter:
|
|||||||
return self.comment.add(sql)
|
return self.comment.add(sql)
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
self.set('master', None)
|
self.set("master", None)
|
||||||
|
|
||||||
def set(self, name: str, node: Optional[CompileResultNode]):
|
def set(self, name: str, node: Optional[CompileResultNode]):
|
||||||
wrapped: Optional[NodeWrapper] = None
|
wrapped: Optional[NodeWrapper] = None
|
||||||
|
|||||||
@@ -1,13 +1,16 @@
from collections.abc import Hashable
from dataclasses import dataclass
from typing import Optional, TypeVar, Any, Type, Dict, Union, Iterator, Tuple, Set

from dbt.contracts.graph.compiled import CompiledNode
from dbt.contracts.graph.parsed import ParsedSourceDefinition, ParsedNode
from dbt.contracts.relation import (
    RelationType,
    ComponentName,
    HasQuoting,
    FakeAPIObject,
    Policy,
    Path,
)
from dbt.exceptions import InternalException
from dbt.node_types import NodeType
@@ -16,13 +19,13 @@ from dbt.utils import filter_null_values, deep_merge, classproperty
import dbt.exceptions


Self = TypeVar("Self", bound="BaseRelation")


@dataclass(frozen=True, eq=False, repr=False)
class BaseRelation(FakeAPIObject, Hashable):
    path: Path
    type: Optional[RelationType] = None
    quote_character: str = '"'
    include_policy: Policy = Policy()
    quote_policy: Policy = Policy()
@@ -40,29 +43,27 @@ class BaseRelation(FakeAPIObject, Hashable):
            if field.name == field_name:
                return field
        # this should be unreachable
        raise ValueError(f"BaseRelation has no {field_name} field!")

    def __eq__(self, other):
        if not isinstance(other, self.__class__):
            return False
        return self.to_dict(omit_none=True) == other.to_dict(omit_none=True)

    @classmethod
    def get_default_quote_policy(cls) -> Policy:
        return cls._get_field_named("quote_policy").default

    @classmethod
    def get_default_include_policy(cls) -> Policy:
        return cls._get_field_named("include_policy").default

    def get(self, key, default=None):
        """Override `.get` to return a metadata object so we don't break
        dbt_utils.
        """
        if key == "metadata":
            return {"type": self.__class__.__name__}
        return super().get(key, default)

    def matches(
@@ -71,16 +72,19 @@ class BaseRelation(FakeAPIObject, Hashable):
        schema: Optional[str] = None,
        identifier: Optional[str] = None,
    ) -> bool:
        search = filter_null_values(
            {
                ComponentName.Database: database,
                ComponentName.Schema: schema,
                ComponentName.Identifier: identifier,
            }
        )

        if not search:
            # nothing was passed in
            raise dbt.exceptions.RuntimeException(
                "Tried to match relation, but no search path was passed!"
            )

        exact_match = True
        approximate_match = True
@@ -109,11 +113,13 @@ class BaseRelation(FakeAPIObject, Hashable):
        schema: Optional[bool] = None,
        identifier: Optional[bool] = None,
    ) -> Self:
        policy = filter_null_values(
            {
                ComponentName.Database: database,
                ComponentName.Schema: schema,
                ComponentName.Identifier: identifier,
            }
        )

        new_quote_policy = self.quote_policy.replace_dict(policy)
        return self.replace(quote_policy=new_quote_policy)
@@ -124,16 +130,18 @@ class BaseRelation(FakeAPIObject, Hashable):
        schema: Optional[bool] = None,
        identifier: Optional[bool] = None,
    ) -> Self:
        policy = filter_null_values(
            {
                ComponentName.Database: database,
                ComponentName.Schema: schema,
                ComponentName.Identifier: identifier,
            }
        )

        new_include_policy = self.include_policy.replace_dict(policy)
        return self.replace(include_policy=new_include_policy)

    def information_schema(self, view_name=None) -> "InformationSchema":
        # some of our data comes from jinja, where things can be `Undefined`.
        if not isinstance(view_name, str):
            view_name = None
@@ -143,10 +151,10 @@ class BaseRelation(FakeAPIObject, Hashable):
        info_schema = InformationSchema.from_relation(self, view_name)
        return info_schema.incorporate(path={"schema": None})

    def information_schema_only(self) -> "InformationSchema":
        return self.information_schema()

    def without_identifier(self) -> "BaseRelation":
        """Return a form of this relation that only has the database and schema
        set to included. To get the appropriately-quoted form the schema out of
        the result (for use as part of a query), use `.render()`. To get the
@@ -157,7 +165,7 @@ class BaseRelation(FakeAPIObject, Hashable):
        return self.include(identifier=False).replace_path(identifier=None)

    def _render_iterator(
        self,
    ) -> Iterator[Tuple[Optional[ComponentName], Optional[str]]]:

        for key in ComponentName:
@@ -170,13 +178,10 @@ class BaseRelation(FakeAPIObject, Hashable):

    def render(self) -> str:
        # if there is nothing set, this will return the empty string.
        return ".".join(part for _, part in self._render_iterator() if part is not None)

    def quoted(self, identifier):
        return "{quote_char}{identifier}{quote_char}".format(
            quote_char=self.quote_character,
            identifier=identifier,
        )
@@ -185,12 +190,12 @@ class BaseRelation(FakeAPIObject, Hashable):
    def create_from_source(
        cls: Type[Self], source: ParsedSourceDefinition, **kwargs: Any
    ) -> Self:
        source_quoting = source.quoting.to_dict(omit_none=True)
        source_quoting.pop("column", None)
        quote_policy = deep_merge(
            cls.get_default_quote_policy().to_dict(omit_none=True),
            source_quoting,
            kwargs.get("quote_policy", {}),
        )

        return cls.create(
@@ -198,12 +203,12 @@ class BaseRelation(FakeAPIObject, Hashable):
            schema=source.schema,
            identifier=source.identifier,
            quote_policy=quote_policy,
            **kwargs,
        )

    @staticmethod
    def add_ephemeral_prefix(name: str):
        return f"__dbt__cte__{name}"

    @classmethod
    def create_ephemeral_from_node(
@@ -236,7 +241,8 @@ class BaseRelation(FakeAPIObject, Hashable):
            schema=node.schema,
            identifier=node.alias,
            quote_policy=quote_policy,
            **kwargs,
        )

    @classmethod
    def create_from(
@@ -248,15 +254,16 @@ class BaseRelation(FakeAPIObject, Hashable):
        if node.resource_type == NodeType.Source:
            if not isinstance(node, ParsedSourceDefinition):
                raise InternalException(
                    "type mismatch, expected ParsedSourceDefinition but got {}".format(
                        type(node)
                    )
                )
            return cls.create_from_source(node, **kwargs)
        else:
            if not isinstance(node, (ParsedNode, CompiledNode)):
                raise InternalException(
                    "type mismatch, expected ParsedNode or CompiledNode but "
                    "got {}".format(type(node))
                )
            return cls.create_from_node(config, node, **kwargs)

@@ -269,14 +276,16 @@ class BaseRelation(FakeAPIObject, Hashable):
        type: Optional[RelationType] = None,
        **kwargs,
    ) -> Self:
        kwargs.update(
            {
                "path": {
                    "database": database,
                    "schema": schema,
                    "identifier": identifier,
                },
                "type": type,
            }
        )
        return cls.from_dict(kwargs)

    def __repr__(self) -> str:
@@ -342,7 +351,7 @@ class BaseRelation(FakeAPIObject, Hashable):
        return RelationType


Info = TypeVar("Info", bound="InformationSchema")


@dataclass(frozen=True, eq=False, repr=False)
@@ -352,7 +361,7 @@ class InformationSchema(BaseRelation):
    def __post_init__(self):
        if not isinstance(self.information_schema_view, (type(None), str)):
            raise dbt.exceptions.CompilationException(
                "Got an invalid name: {}".format(self.information_schema_view)
            )

    @classmethod
@@ -362,7 +371,7 @@ class InformationSchema(BaseRelation):
        return Path(
            database=relation.database,
            schema=relation.schema,
            identifier="INFORMATION_SCHEMA",
        )

    @classmethod
@@ -393,9 +402,7 @@ class InformationSchema(BaseRelation):
        relation: BaseRelation,
        information_schema_view: Optional[str],
    ) -> Info:
        include_policy = cls.get_include_policy(relation, information_schema_view)
        quote_policy = cls.get_quote_policy(relation, information_schema_view)
        path = cls.get_path(relation, information_schema_view)
        return cls(
@@ -417,6 +424,7 @@ class SchemaSearchMap(Dict[InformationSchema, Set[Optional[str]]]):
    search for what schemas. The schema values are all lowercased to avoid
    duplication.
    """

    def add(self, relation: BaseRelation):
        key = relation.information_schema_only()
        if key not in self:
@@ -426,9 +434,7 @@ class SchemaSearchMap(Dict[InformationSchema, Set[Optional[str]]]):
            schema = relation.schema.lower()
        self[key].add(schema)

    def search(self) -> Iterator[Tuple[InformationSchema, Optional[str]]]:
        for information_schema_name, schemas in self.items():
            for schema in schemas:
                yield information_schema_name, schema
@@ -442,14 +448,13 @@ class SchemaSearchMap(Dict[InformationSchema, Set[Optional[str]]]):
            dbt.exceptions.raise_compiler_error(str(seen))

        for information_schema_name, schema in self.search():
            path = {"database": information_schema_name.database, "schema": schema}
            new.add(
                information_schema_name.incorporate(
                    path=path,
                    quote_policy={"database": False},
                    include_policy={"database": False},
                )
            )

        return new
@@ -7,7 +7,7 @@ from dbt.logger import CACHE_LOGGER as logger
from dbt.utils import lowercase
import dbt.exceptions

_ReferenceKey = namedtuple("_ReferenceKey", "database schema identifier")


def _make_key(relation) -> _ReferenceKey:
@@ -15,9 +15,11 @@ def _make_key(relation) -> _ReferenceKey:
    to keep track of quoting
    """
    # databases and schemas can both be None
    return _ReferenceKey(
        lowercase(relation.database),
        lowercase(relation.schema),
        lowercase(relation.identifier),
    )


def dot_separated(key: _ReferenceKey) -> str:
@@ -25,7 +27,7 @@ def dot_separated(key: _ReferenceKey) -> str:

    :param _ReferenceKey key: The key to stringify.
    """
    return ".".join(map(str, key))


class _CachedRelation:
@@ -37,13 +39,14 @@ class _CachedRelation:
    that refer to this relation.
    :attr BaseRelation inner: The underlying dbt relation.
    """

    def __init__(self, inner):
        self.referenced_by = {}
        self.inner = inner

    def __str__(self) -> str:
        return (
            "_CachedRelation(database={}, schema={}, identifier={}, inner={})"
        ).format(self.database, self.schema, self.identifier, self.inner)

    @property
@@ -78,7 +81,7 @@ class _CachedRelation:
        """
        return _make_key(self)

    def add_reference(self, referrer: "_CachedRelation"):
        """Add a reference from referrer to self, indicating that if this node
        were drop...cascaded, the referrer would be dropped as well.

@@ -122,9 +125,9 @@ class _CachedRelation:
        # table_name is ever anything but the identifier (via .create())
        self.inner = self.inner.incorporate(
            path={
                "database": new_relation.inner.database,
                "schema": new_relation.inner.schema,
                "identifier": new_relation.inner.identifier,
            },
        )

@@ -140,8 +143,9 @@ class _CachedRelation:
        """
        if new_key in self.referenced_by:
            dbt.exceptions.raise_cache_inconsistent(
                'in rename of "{}" -> "{}", new name is in the cache already'.format(
                    old_key, new_key
                )
            )

        if old_key not in self.referenced_by:
@@ -172,13 +176,16 @@ class RelationsCache:
    The adapters also hold this lock while filling the cache.
    :attr Set[str] schemas: The set of known/cached schemas, all lowercased.
    """

    def __init__(self) -> None:
        self.relations: Dict[_ReferenceKey, _CachedRelation] = {}
        self.lock = threading.RLock()
        self.schemas: Set[Tuple[Optional[str], Optional[str]]] = set()

    def add_schema(
        self,
        database: Optional[str],
        schema: Optional[str],
    ) -> None:
        """Add a schema to the set of known schemas (case-insensitive)

@@ -188,7 +195,9 @@ class RelationsCache:
        self.schemas.add((lowercase(database), lowercase(schema)))

    def drop_schema(
        self,
        database: Optional[str],
        schema: Optional[str],
    ) -> None:
        """Drop the given schema and remove it from the set of known schemas.

@@ -263,15 +272,15 @@ class RelationsCache:
            return
        if referenced is None:
            dbt.exceptions.raise_cache_inconsistent(
                "in add_link, referenced link key {} not in cache!".format(
                    referenced_key
                )
            )

        dependent = self.relations.get(dependent_key)
        if dependent is None:
            dbt.exceptions.raise_cache_inconsistent(
                "in add_link, dependent link key {} not in cache!".format(dependent_key)
            )

        assert dependent is not None  # we just raised!
@@ -298,28 +307,23 @@ class RelationsCache:
            # referring to a table outside our control. There's no need to make
            # a link - we will never drop the referenced relation during a run.
            logger.debug(
                "{dep!s} references {ref!s} but {ref.database}.{ref.schema} "
                "is not in the cache, skipping assumed external relation".format(
                    dep=dependent, ref=ref_key
                )
            )
            return
        if ref_key not in self.relations:
            # Insert a dummy "external" relation.
            referenced = referenced.replace(type=referenced.External)
            self.add(referenced)

        dep_key = _make_key(dependent)
        if dep_key not in self.relations:
            # Insert a dummy "external" relation.
            dependent = dependent.replace(type=referenced.External)
            self.add(dependent)
        logger.debug("adding link, {!s} references {!s}".format(dep_key, ref_key))
        with self.lock:
            self._add_link(ref_key, dep_key)

@@ -330,14 +334,14 @@ class RelationsCache:
        :param BaseRelation relation: The underlying relation.
        """
        cached = _CachedRelation(relation)
        logger.debug("Adding relation: {!s}".format(cached))

        lazy_log("before adding: {!s}", self.dump_graph)

        with self.lock:
            self._setdefault(cached)

        lazy_log("after adding: {!s}", self.dump_graph)

    def _remove_refs(self, keys):
        """Removes all references to all entries in keys. This does not
@@ -359,13 +363,10 @@ class RelationsCache:
        :param _CachedRelation dropped: An existing _CachedRelation to drop.
        """
        if dropped not in self.relations:
            logger.debug("dropped a nonexistent relationship: {!s}".format(dropped))
            return
        consequences = self.relations[dropped].collect_consequences()
        logger.debug("drop {} is cascading to {}".format(dropped, consequences))
        self._remove_refs(consequences)

    def drop(self, relation):
@@ -380,7 +381,7 @@ class RelationsCache:
        :param str identifier: The identifier of the relation to drop.
        """
        dropped = _make_key(relation)
        logger.debug("Dropping relation: {!s}".format(dropped))
        with self.lock:
            self._drop_cascade_relation(dropped)

@@ -404,8 +405,9 @@ class RelationsCache:
        for cached in self.relations.values():
            if cached.is_referenced_by(old_key):
                logger.debug(
                    "updated reference from {0} -> {2} to {1} -> {2}".format(
                        old_key, new_key, cached.key()
                    )
                )
                cached.rename_key(old_key, new_key)

@@ -430,14 +432,16 @@ class RelationsCache:
        """
        if new_key in self.relations:
            dbt.exceptions.raise_cache_inconsistent(
                "in rename, new key {} already in cache: {}".format(
                    new_key, list(self.relations.keys())
                )
            )

        if old_key not in self.relations:
            logger.debug(
                "old key {} not found in self.relations, assuming temporary".format(
                    old_key
                )
            )
            return False
        return True
@@ -456,11 +460,9 @@ class RelationsCache:
        """
        old_key = _make_key(old)
        new_key = _make_key(new)
        logger.debug("Renaming relation {!s} to {!s}".format(old_key, new_key))

        lazy_log("before rename: {!s}", self.dump_graph)

        with self.lock:
            if self._check_rename_constraints(old_key, new_key):
@@ -468,7 +470,7 @@ class RelationsCache:
            else:
                self._setdefault(_CachedRelation(new))

        lazy_log("after rename: {!s}", self.dump_graph)

    def get_relations(
        self, database: Optional[str], schema: Optional[str]
@@ -483,14 +485,14 @@ class RelationsCache:
        schema = lowercase(schema)
        with self.lock:
            results = [
                r.inner
                for r in self.relations.values()
                if (lowercase(r.schema) == schema and lowercase(r.database) == database)
            ]

        if None in results:
            dbt.exceptions.raise_cache_inconsistent(
                "in get_relations, a None relation was found in the cache!"
            )
        return results

@@ -50,9 +50,7 @@ class AdapterContainer:
        adapter = self.get_adapter_class_by_name(name)
        return adapter.Relation

    def get_config_class_by_name(self, name: str) -> Type[AdapterConfig]:
        adapter = self.get_adapter_class_by_name(name)
        return adapter.AdapterSpecificConfigs

@@ -62,24 +60,24 @@ class AdapterContainer:
        # singletons
        try:
            # mypy doesn't think modules have any attributes.
            mod: Any = import_module("." + name, "dbt.adapters")
        except ModuleNotFoundError as exc:
            # if we failed to import the target module in particular, inform
            # the user about it via a runtime error
            if exc.name == "dbt.adapters." + name:
                raise RuntimeException(f"Could not find adapter type {name}!")
            logger.info(f"Error importing adapter: {exc}")
            # otherwise, the error had to have come from some underlying
            # library. Log the stack trace.
            logger.debug("", exc_info=True)
            raise
        plugin: AdapterPlugin = mod.Plugin
        plugin_type = plugin.adapter.type()

        if plugin_type != name:
            raise RuntimeException(
                f"Expected to find adapter with type named {name}, got "
                f"adapter with type {plugin_type}"
            )

        with self.lock:
@@ -109,8 +107,7 @@ class AdapterContainer:
            return self.adapters[adapter_name]

    def reset_adapters(self):
        """Clear the adapters. This is useful for tests, which change configs."""
        with self.lock:
            for adapter in self.adapters.values():
                adapter.cleanup_connections()
@@ -140,9 +137,7 @@ class AdapterContainer:
            try:
                plugin = self.plugins[plugin_name]
            except KeyError:
                raise InternalException(f"No plugin found for {plugin_name}") from None
            plugins.append(plugin)
            seen.add(plugin_name)
            if plugin.dependencies is None:
@@ -166,7 +161,7 @@ class AdapterContainer:
                path = self.packages[package_name]
            except KeyError:
                raise InternalException(
                    f"No internal package listing found for {package_name}"
                )
            paths.append(path)
        return paths
@@ -187,8 +182,7 @@ def get_adapter(config: AdapterRequiredConfig):


def reset_adapters():
    """Clear the adapters. This is useful for tests, which change configs."""
    FACTORY.reset_adapters()

@@ -1,17 +1,27 @@
from dataclasses import dataclass
from typing import (
    Type,
    Hashable,
    Optional,
    ContextManager,
    List,
    Generic,
    TypeVar,
    ClassVar,
    Tuple,
    Union,
    Dict,
    Any,
)
from typing_extensions import Protocol

import agate

from dbt.contracts.connection import Connection, AdapterRequiredConfig, AdapterResponse
from dbt.contracts.graph.compiled import (
    CompiledNode,
    ManifestNode,
    NonSourceCompiledNode,
)
from dbt.contracts.graph.parsed import ParsedNode, ParsedSourceDefinition
from dbt.contracts.graph.model_config import BaseConfig
@@ -34,7 +44,7 @@ class ColumnProtocol(Protocol):
    pass


Self = TypeVar("Self", bound="RelationProtocol")


class RelationProtocol(Protocol):
@@ -64,19 +74,11 @@ class CompilerProtocol(Protocol):
    ...


AdapterConfig_T = TypeVar("AdapterConfig_T", bound=AdapterConfig)
ConnectionManager_T = TypeVar("ConnectionManager_T", bound=ConnectionManagerProtocol)
Relation_T = TypeVar("Relation_T", bound=RelationProtocol)
Column_T = TypeVar("Column_T", bound=ColumnProtocol)
Compiler_T = TypeVar("Compiler_T", bound=CompilerProtocol)


class AdapterProtocol(
@@ -87,7 +89,7 @@ class AdapterProtocol(
        Relation_T,
        Column_T,
        Compiler_T,
    ],
):
    AdapterSpecificConfigs: ClassVar[Type[AdapterConfig_T]]
    Column: ClassVar[Type[Column_T]]
@@ -7,9 +7,7 @@ import agate
import dbt.clients.agate_helper
import dbt.exceptions
from dbt.adapters.base import BaseConnectionManager
from dbt.contracts.connection import Connection, ConnectionState, AdapterResponse
from dbt.logger import GLOBAL_LOGGER as logger
from dbt import flags

@@ -23,11 +21,12 @@ class SQLConnectionManager(BaseConnectionManager):
    - get_response
    - open
    """

    @abc.abstractmethod
    def cancel(self, connection: Connection):
        """Cancel the given connection."""
        raise dbt.exceptions.NotImplementedException(
            "`cancel` is not implemented for this adapter!"
        )

    def cancel_open(self) -> List[str]:
@@ -41,8 +40,8 @@ class SQLConnectionManager(BaseConnectionManager):
            # if the connection failed, the handle will be None so we have
            # nothing to cancel.
            if (
                connection.handle is not None
                and connection.state == ConnectionState.OPEN
            ):
                self.cancel(connection)
            if connection.name is not None:
@@ -54,23 +53,22 @@ class SQLConnectionManager(BaseConnectionManager):
        sql: str,
        auto_begin: bool = True,
        bindings: Optional[Any] = None,
        abridge_sql_log: bool = False,
    ) -> Tuple[Connection, Any]:
        connection = self.get_thread_connection()
        if auto_begin and connection.transaction_open is False:
            self.begin()

        logger.debug('Using {} connection "{}".'.format(self.TYPE, connection.name))

        with self.exception_handler(sql):
            if abridge_sql_log:
                log_sql = "{}...".format(sql[:512])
            else:
                log_sql = sql

            logger.debug(
                "On {connection_name}: {sql}",
                connection_name=connection.name,
                sql=log_sql,
            )
@@ -81,7 +79,7 @@ class SQLConnectionManager(BaseConnectionManager):
            logger.debug(
                "SQL status: {status} in {elapsed:0.2f} seconds",
                status=self.get_response(cursor),
                elapsed=(time.time() - pre),
            )

            return connection, cursor
@@ -90,14 +88,12 @@ class SQLConnectionManager(BaseConnectionManager):
    def get_response(cls, cursor: Any) -> Union[AdapterResponse, str]:
        """Get the status of the cursor."""
        raise dbt.exceptions.NotImplementedException(
            "`get_response` is not implemented for this adapter!"
        )

    @classmethod
    def process_results(
        cls, column_names: Iterable[str], rows: Iterable[Any]
    ) -> List[Dict[str, Any]]:

        return [dict(zip(column_names, row)) for row in rows]
@@ -112,10 +108,7 @@ class SQLConnectionManager(BaseConnectionManager):
        rows = cursor.fetchall()
        data = cls.process_results(column_names, rows)

        return dbt.clients.agate_helper.table_from_data_flat(data, column_names)

    def execute(
        self, sql: str, auto_begin: bool = False, fetch: bool = False
@@ -130,10 +123,10 @@ class SQLConnectionManager(BaseConnectionManager):
        return response, table

    def add_begin_query(self):
        return self.add_query("BEGIN", auto_begin=False)

    def add_commit_query(self):
        return self.add_query("COMMIT", auto_begin=False)

    def begin(self):
        connection = self.get_thread_connection()
@@ -141,13 +134,14 @@ class SQLConnectionManager(BaseConnectionManager):
        if flags.STRICT_MODE:
            if not isinstance(connection, Connection):
                raise dbt.exceptions.CompilerException(
                    f"In begin, got {connection} - not a Connection!"
                )

        if connection.transaction_open is True:
            raise dbt.exceptions.InternalException(
                'Tried to begin a new transaction on connection "{}", but '
                "it already had one open!".format(connection.name)
            )

        self.add_begin_query()

@@ -159,15 +153,16 @@ class SQLConnectionManager(BaseConnectionManager):
        if flags.STRICT_MODE:
            if not isinstance(connection, Connection):
                raise dbt.exceptions.CompilerException(
                    f"In commit, got {connection} - not a Connection!"
                )

        if connection.transaction_open is False:
            raise dbt.exceptions.InternalException(
                'Tried to commit transaction on connection "{}", but '
                "it does not have one open!".format(connection.name)
            )

        logger.debug("On {}: COMMIT".format(connection.name))
        self.add_commit_query()

        connection.transaction_open = False
@@ -10,16 +10,16 @@ from dbt.logger import GLOBAL_LOGGER as logger

from dbt.adapters.base.relation import BaseRelation

LIST_RELATIONS_MACRO_NAME = "list_relations_without_caching"
GET_COLUMNS_IN_RELATION_MACRO_NAME = "get_columns_in_relation"
LIST_SCHEMAS_MACRO_NAME = "list_schemas"
CHECK_SCHEMA_EXISTS_MACRO_NAME = "check_schema_exists"
CREATE_SCHEMA_MACRO_NAME = "create_schema"
DROP_SCHEMA_MACRO_NAME = "drop_schema"
RENAME_RELATION_MACRO_NAME = "rename_relation"
TRUNCATE_RELATION_MACRO_NAME = "truncate_relation"
DROP_RELATION_MACRO_NAME = "drop_relation"
ALTER_COLUMN_TYPE_MACRO_NAME = "alter_column_type"


class SQLAdapter(BaseAdapter):
@@ -60,30 +60,23 @@ class SQLAdapter(BaseAdapter):
        :param abridge_sql_log: If set, limit the raw sql logged to 512
            characters
        """
        return self.connections.add_query(sql, auto_begin, bindings, abridge_sql_log)

    @classmethod
    def convert_text_type(cls, agate_table: agate.Table, col_idx: int) -> str:
        return "text"

    @classmethod
    def convert_number_type(cls, agate_table: agate.Table, col_idx: int) -> str:
        decimals = agate_table.aggregate(agate.MaxPrecision(col_idx))  # type: ignore
        return "float8" if decimals else "integer"

    @classmethod
    def convert_boolean_type(cls, agate_table: agate.Table, col_idx: int) -> str:
        return "boolean"

    @classmethod
    def convert_datetime_type(cls, agate_table: agate.Table, col_idx: int) -> str:
        return "timestamp without time zone"

    @classmethod
@@ -99,31 +92,28 @@ class SQLAdapter(BaseAdapter):
        return True

    def expand_column_types(self, goal, current):
        reference_columns = {c.name: c for c in self.get_columns_in_relation(goal)}

        target_columns = {c.name: c for c in self.get_columns_in_relation(current)}

        for column_name, reference_column in reference_columns.items():
            target_column = target_columns.get(column_name)

            if target_column is not None and target_column.can_expand_to(
                reference_column
            ):
                col_string_size = reference_column.string_size()
                new_type = self.Column.string_type(col_string_size)
                logger.debug(
                    "Changing col type from {} to {} in table {}",
                    target_column.data_type,
                    new_type,
                    current,
                )

                self.alter_column_type(current, column_name, new_type)

    def alter_column_type(self, relation, column_name, new_column_type) -> None:
        """
        1. Create a new column (w/ temp name and correct type)
        2. Copy data over to it
@@ -131,53 +121,40 @@ class SQLAdapter(BaseAdapter):
        4. Rename the new column to existing column
        """
        kwargs = {
            "relation": relation,
            "column_name": column_name,
            "new_column_type": new_column_type,
        }
        self.execute_macro(ALTER_COLUMN_TYPE_MACRO_NAME, kwargs=kwargs)

    def drop_relation(self, relation):
        if relation.type is None:
            dbt.exceptions.raise_compiler_error(
                "Tried to drop relation {}, but its type is null.".format(relation)
            )

        self.cache_dropped(relation)
        self.execute_macro(DROP_RELATION_MACRO_NAME, kwargs={"relation": relation})

    def truncate_relation(self, relation):
        self.execute_macro(TRUNCATE_RELATION_MACRO_NAME, kwargs={"relation": relation})

    def rename_relation(self, from_relation, to_relation):
        self.cache_renamed(from_relation, to_relation)

        kwargs = {"from_relation": from_relation, "to_relation": to_relation}
        self.execute_macro(RENAME_RELATION_MACRO_NAME, kwargs=kwargs)

    def get_columns_in_relation(self, relation):
        return self.execute_macro(
            GET_COLUMNS_IN_RELATION_MACRO_NAME, kwargs={"relation": relation}
        )

    def create_schema(self, relation: BaseRelation) -> None:
        relation = relation.without_identifier()
        logger.debug('Creating schema "{}"', relation)
        kwargs = {
            "relation": relation,
        }
        self.execute_macro(CREATE_SCHEMA_MACRO_NAME, kwargs=kwargs)
        self.commit_if_has_connection()
@@ -188,39 +165,35 @@ class SQLAdapter(BaseAdapter):
        relation = relation.without_identifier()
        logger.debug('Dropping schema "{}".', relation)
        kwargs = {
            "relation": relation,
        }
        self.execute_macro(DROP_SCHEMA_MACRO_NAME, kwargs=kwargs)
        # we can update the cache here
        self.cache.drop_schema(relation.database, relation.schema)

    def list_relations_without_caching(
        self,
        schema_relation: BaseRelation,
    ) -> List[BaseRelation]:
        kwargs = {"schema_relation": schema_relation}
        results = self.execute_macro(
|
results = self.execute_macro(LIST_RELATIONS_MACRO_NAME, kwargs=kwargs)
|
||||||
LIST_RELATIONS_MACRO_NAME,
|
|
||||||
kwargs=kwargs
|
|
||||||
)
|
|
||||||
|
|
||||||
relations = []
|
relations = []
|
||||||
quote_policy = {
|
quote_policy = {"database": True, "schema": True, "identifier": True}
|
||||||
'database': True,
|
|
||||||
'schema': True,
|
|
||||||
'identifier': True
|
|
||||||
}
|
|
||||||
for _database, name, _schema, _type in results:
|
for _database, name, _schema, _type in results:
|
||||||
try:
|
try:
|
||||||
_type = self.Relation.get_relation_type(_type)
|
_type = self.Relation.get_relation_type(_type)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
_type = self.Relation.External
|
_type = self.Relation.External
|
||||||
relations.append(self.Relation.create(
|
relations.append(
|
||||||
database=_database,
|
self.Relation.create(
|
||||||
schema=_schema,
|
database=_database,
|
||||||
identifier=name,
|
schema=_schema,
|
||||||
quote_policy=quote_policy,
|
identifier=name,
|
||||||
type=_type
|
quote_policy=quote_policy,
|
||||||
))
|
type=_type,
|
||||||
|
)
|
||||||
|
)
|
||||||
return relations
|
return relations
|
||||||
|
|
||||||
def quote(self, identifier):
|
def quote(self, identifier):
|
||||||
@@ -228,8 +201,7 @@ class SQLAdapter(BaseAdapter):
|
|||||||
|
|
||||||
def list_schemas(self, database: str) -> List[str]:
|
def list_schemas(self, database: str) -> List[str]:
|
||||||
results = self.execute_macro(
|
results = self.execute_macro(
|
||||||
LIST_SCHEMAS_MACRO_NAME,
|
LIST_SCHEMAS_MACRO_NAME, kwargs={"database": database}
|
||||||
kwargs={'database': database}
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return [row[0] for row in results]
|
return [row[0] for row in results]
|
||||||
@@ -238,13 +210,10 @@ class SQLAdapter(BaseAdapter):
|
|||||||
information_schema = self.Relation.create(
|
information_schema = self.Relation.create(
|
||||||
database=database,
|
database=database,
|
||||||
schema=schema,
|
schema=schema,
|
||||||
identifier='INFORMATION_SCHEMA',
|
identifier="INFORMATION_SCHEMA",
|
||||||
quote_policy=self.config.quoting
|
quote_policy=self.config.quoting,
|
||||||
).information_schema()
|
).information_schema()
|
||||||
|
|
||||||
kwargs = {'information_schema': information_schema, 'schema': schema}
|
kwargs = {"information_schema": information_schema, "schema": schema}
|
||||||
results = self.execute_macro(
|
results = self.execute_macro(CHECK_SCHEMA_EXISTS_MACRO_NAME, kwargs=kwargs)
|
||||||
CHECK_SCHEMA_EXISTS_MACRO_NAME,
|
|
||||||
kwargs=kwargs
|
|
||||||
)
|
|
||||||
return results[0][0] > 0
|
return results[0][0] > 0
|
||||||
|
|||||||
@@ -10,79 +10,89 @@ def regex(pat):

class BlockData:
    """raw plaintext data from the top level of the file."""

    def __init__(self, contents):
        self.block_type_name = "__dbt__data"
        self.contents = contents
        self.full_block = contents


class BlockTag:
    def __init__(
        self, block_type_name, block_name, contents=None, full_block=None, **kw
    ):
        self.block_type_name = block_type_name
        self.block_name = block_name
        self.contents = contents
        self.full_block = full_block

    def __str__(self):
        return "BlockTag({!r}, {!r})".format(self.block_type_name, self.block_name)

    def __repr__(self):
        return str(self)

    @property
    def end_block_type_name(self):
        return "end{}".format(self.block_type_name)

    def end_pat(self):
        # we don't want to use string formatting here because jinja uses most
        # of the string formatting operators in its syntax...
        pattern = "".join(
            (
                r"(?P<endblock>((?:\s*\{\%\-|\{\%)\s*",
                self.end_block_type_name,
                r"\s*(?:\-\%\}\s*|\%\})))",
            )
        )
        return regex(pattern)


Tag = namedtuple("Tag", "block_type_name block_name start end")


_NAME_PATTERN = r"[A-Za-z_][A-Za-z_0-9]*"

COMMENT_START_PATTERN = regex(r"(?:(?P<comment_start>(\s*\{\#)))")
COMMENT_END_PATTERN = regex(r"(.*?)(\s*\#\})")
RAW_START_PATTERN = regex(
    r"(?:\s*\{\%\-|\{\%)\s*(?P<raw_start>(raw))\s*(?:\-\%\}\s*|\%\})"
)
EXPR_START_PATTERN = regex(r"(?P<expr_start>(\{\{\s*))")
EXPR_END_PATTERN = regex(r"(?P<expr_end>(\s*\}\}))")

BLOCK_START_PATTERN = regex(
    "".join(
        (
            r"(?:\s*\{\%\-|\{\%)\s*",
            r"(?P<block_type_name>({}))".format(_NAME_PATTERN),
            # some blocks have a 'block name'.
            r"(?:\s+(?P<block_name>({})))?".format(_NAME_PATTERN),
        )
    )
)


RAW_BLOCK_PATTERN = regex(
    "".join(
        (
            r"(?:\s*\{\%\-|\{\%)\s*raw\s*(?:\-\%\}\s*|\%\})",
            r"(?:.*?)",
            r"(?:\s*\{\%\-|\{\%)\s*endraw\s*(?:\-\%\}\s*|\%\})",
        )
    )
)

TAG_CLOSE_PATTERN = regex(r"(?:(?P<tag_close>(\-\%\}\s*|\%\})))")

# stolen from jinja's lexer. Note that we've consumed all prefix whitespace by
# the time we want to use this.
STRING_PATTERN = regex(
    r"(?P<string>('([^'\\]*(?:\\.[^'\\]*)*)'|" r'"([^"\\]*(?:\\.[^"\\]*)*)"))'
)

QUOTE_START_PATTERN = regex(r"""(?P<quote>(['"]))""")


class TagIterator:
@@ -99,10 +109,10 @@ class TagIterator:
        end_val: int = self.pos if end is None else end
        data = self.data[:end_val]
        # if not found, rfind returns -1, and -1+1=0, which is perfect!
        last_line_start = data.rfind("\n") + 1
        # it's easy to forget this, but line numbers are 1-indexed
        line_number = data.count("\n") + 1
        return f"{line_number}:{end_val - last_line_start}"

    def advance(self, new_position):
        self.pos = new_position
@@ -120,7 +130,7 @@ class TagIterator:
        matches = []
        for pattern in patterns:
            # default to 'search', but sometimes we want to 'match'.
            if kwargs.get("method", "search") == "search":
                match = self._search(pattern)
            else:
                match = self._match(pattern)
@@ -136,7 +146,7 @@ class TagIterator:
        match = self._first_match(*patterns, **kwargs)
        if match is None:
            msg = 'unexpected EOF, expected {}, got "{}"'.format(
                expected_name, self.data[self.pos :]
            )
            dbt.exceptions.raise_compiler_error(msg)
        return match
@@ -156,22 +166,20 @@ class TagIterator:
        """
        self.advance(match.end())
        while True:
            match = self._expect_match("}}", EXPR_END_PATTERN, QUOTE_START_PATTERN)
            if match.groupdict().get("expr_end") is not None:
                break
            else:
                # it's a quote. we haven't advanced for this match yet, so
                # just slurp up the whole string, no need to rewind.
                match = self._expect_match("string", STRING_PATTERN)
                self.advance(match.end())

        self.advance(match.end())

    def handle_comment(self, match):
        self.advance(match.end())
        match = self._expect_match("#}", COMMENT_END_PATTERN)
        self.advance(match.end())

    def _expect_block_close(self):
@@ -188,22 +196,19 @@ class TagIterator:
        """
        while True:
            end_match = self._expect_match(
                'tag close ("%}")', QUOTE_START_PATTERN, TAG_CLOSE_PATTERN
            )
            self.advance(end_match.end())
            if end_match.groupdict().get("tag_close") is not None:
                return
            # must be a string. Rewind to its start and advance past it.
            self.rewind()
            string_match = self._expect_match("string", STRING_PATTERN)
            self.advance(string_match.end())

    def handle_raw(self):
        # raw blocks are super special, they are a single complete regex
        match = self._expect_match("{% raw %}...{% endraw %}", RAW_BLOCK_PATTERN)
        self.advance(match.end())
        return match.end()

@@ -220,13 +225,12 @@ class TagIterator:
        """
        groups = match.groupdict()
        # always a value
        block_type_name = groups["block_type_name"]
        # might be None
        block_name = groups.get("block_name")
        start_pos = self.pos
        if block_type_name == "raw":
            match = self._expect_match("{% raw %}...{% endraw %}", RAW_BLOCK_PATTERN)
            self.advance(match.end())
        else:
            self.advance(match.end())
@@ -235,15 +239,13 @@ class TagIterator:
            block_type_name=block_type_name,
            block_name=block_name,
            start=start_pos,
            end=self.pos,
        )

    def find_tags(self):
        while True:
            match = self._first_match(
                BLOCK_START_PATTERN, COMMENT_START_PATTERN, EXPR_START_PATTERN
            )
            if match is None:
                break
@@ -252,9 +254,9 @@ class TagIterator:
            # start = self.pos

            groups = match.groupdict()
            comment_start = groups.get("comment_start")
            expr_start = groups.get("expr_start")
            block_type_name = groups.get("block_type_name")

            if comment_start is not None:
                self.handle_comment(match)
@@ -264,8 +266,8 @@ class TagIterator:
                yield self.handle_tag(match)
            else:
                raise dbt.exceptions.InternalException(
                    "Invalid regex match in next_block, expected block start, "
                    "expr start, or comment start"
                )

    def __iter__(self):
@@ -273,21 +275,18 @@ class TagIterator:


duplicate_tags = (
    "Got nested tags: {outer.block_type_name} (started at {outer.start}) did "
    "not have a matching {{% end{outer.block_type_name} %}} before a "
    "subsequent {inner.block_type_name} was found (started at {inner.start})"
)


_CONTROL_FLOW_TAGS = {
    "if": "endif",
    "for": "endfor",
}

_CONTROL_FLOW_END_TAGS = {v: k for k, v in _CONTROL_FLOW_TAGS.items()}


class BlockIterator:
@@ -310,15 +309,15 @@ class BlockIterator:

    def is_current_end(self, tag):
        return (
            tag.block_type_name.startswith("end")
            and self.current is not None
            and tag.block_type_name[3:] == self.current.block_type_name
        )

    def find_blocks(self, allowed_blocks=None, collect_raw_data=True):
        """Find all top-level blocks in the data."""
        if allowed_blocks is None:
            allowed_blocks = {"snapshot", "macro", "materialization", "docs"}

        for tag in self.tag_parser.find_tags():
            if tag.block_type_name in _CONTROL_FLOW_TAGS:
@@ -329,37 +328,43 @@ class BlockIterator:
                    found = self.stack.pop()
                else:
                    expected = _CONTROL_FLOW_END_TAGS[tag.block_type_name]
                    dbt.exceptions.raise_compiler_error(
                        (
                            "Got an unexpected control flow end tag, got {} but "
                            "never saw a preceeding {} (@ {})"
                        ).format(
                            tag.block_type_name,
                            expected,
                            self.tag_parser.linepos(tag.start),
                        )
                    )
                expected = _CONTROL_FLOW_TAGS[found]
                if expected != tag.block_type_name:
                    dbt.exceptions.raise_compiler_error(
                        (
                            "Got an unexpected control flow end tag, got {} but "
                            "expected {} next (@ {})"
                        ).format(
                            tag.block_type_name,
                            expected,
                            self.tag_parser.linepos(tag.start),
                        )
                    )

            if tag.block_type_name in allowed_blocks:
                if self.stack:
                    dbt.exceptions.raise_compiler_error(
                        (
                            "Got a block definition inside control flow at {}. "
                            "All dbt block definitions must be at the top level"
                        ).format(self.tag_parser.linepos(tag.start))
                    )
                if self.current is not None:
                    dbt.exceptions.raise_compiler_error(
                        duplicate_tags.format(outer=self.current, inner=tag)
                    )
                if collect_raw_data:
                    raw_data = self.data[self.last_position : tag.start]
                    self.last_position = tag.start
                    if raw_data:
                        yield BlockData(raw_data)
@@ -371,23 +376,28 @@ class BlockIterator:
                yield BlockTag(
                    block_type_name=self.current.block_type_name,
                    block_name=self.current.block_name,
                    contents=self.data[self.current.end : tag.start],
                    full_block=self.data[self.current.start : tag.end],
                )
                self.current = None

        if self.current:
            linecount = self.data[: self.current.end].count("\n") + 1
            dbt.exceptions.raise_compiler_error(
                (
                    "Reached EOF without finding a close tag for "
                    "{} (searched from line {})"
                ).format(self.current.block_type_name, linecount)
            )

        if collect_raw_data:
            raw_data = self.data[self.last_position :]
            if raw_data:
                yield BlockData(raw_data)

    def lex_for_blocks(self, allowed_blocks=None, collect_raw_data=True):
        return list(
            self.find_blocks(
                allowed_blocks=allowed_blocks, collect_raw_data=collect_raw_data
            )
        )
@@ -10,7 +10,7 @@ from typing import Iterable, List, Dict, Union, Optional, Any
from dbt.exceptions import RuntimeException


BOM = BOM_UTF8.decode("utf-8")  # '\ufeff'


class ISODateTime(agate.data_types.DateTime):
@@ -30,28 +30,23 @@ class ISODateTime(agate.data_types.DateTime):
        except:  # noqa
            pass

        raise agate.exceptions.CastError('Can not parse value "%s" as datetime.' % d)


def build_type_tester(text_columns: Iterable[str]) -> agate.TypeTester:
    types = [
        agate.data_types.Number(null_values=("null", "")),
        agate.data_types.Date(null_values=("null", ""), date_format="%Y-%m-%d"),
        agate.data_types.DateTime(
            null_values=("null", ""), datetime_format="%Y-%m-%d %H:%M:%S"
        ),
        ISODateTime(null_values=("null", "")),
        agate.data_types.Boolean(
            true_values=("true",), false_values=("false",), null_values=("null", "")
        ),
        agate.data_types.Text(null_values=("null", "")),
    ]
    force = {k: agate.data_types.Text(null_values=("null", "")) for k in text_columns}
    return agate.TypeTester(force=force, types=types)


@@ -115,7 +110,7 @@ def as_matrix(table):

def from_csv(abspath, text_columns):
    type_tester = build_type_tester(text_columns=text_columns)
    with open(abspath, encoding="utf-8") as fp:
        if fp.read(1) != BOM:
            fp.seek(0)
        return agate.Table.from_csv(fp, column_types=type_tester)
@@ -147,8 +142,8 @@ class ColumnTypeBuilder(Dict[str, NullableAgateType]):
        elif not isinstance(value, type(existing_type)):
            # actual type mismatch!
            raise RuntimeException(
                f"Tables contain columns with the same names ({key}), "
                f"but different types ({value} vs {existing_type})"
            )

    def finalize(self) -> Dict[str, agate.data_types.DataType]:
@@ -163,7 +158,7 @@ class ColumnTypeBuilder(Dict[str, NullableAgateType]):


def _merged_column_types(
    tables: List[agate.Table],
) -> Dict[str, agate.data_types.DataType]:
    # this is a lot like agate.Table.merge, but with handling for all-null
    # rows being "any type".
@@ -190,10 +185,7 @@ def merge_tables(tables: List[agate.Table]) -> agate.Table:

    rows: List[agate.Row] = []
    for table in tables:
        if table.column_names == column_names and table.column_types == column_types:
            rows.extend(table.rows)
        else:
            for row in table.rows:
@@ -12,7 +12,7 @@ https://cloud.google.com/sdk/

def gcloud_installed():
    try:
        run_cmd(".", ["gcloud", "--version"])
        return True
    except OSError as e:
        logger.debug(e)
@@ -21,6 +21,6 @@ def gcloud_installed():

def setup_default_credentials():
    if gcloud_installed():
        run_cmd(".", ["gcloud", "auth", "application-default", "login"])
    else:
        raise dbt.exceptions.RuntimeException(NOT_INSTALLED_MSG)
@@ -7,77 +7,74 @@ import dbt.exceptions


def clone(repo, cwd, dirname=None, remove_git_dir=False, branch=None):
    clone_cmd = ["git", "clone", "--depth", "1"]

    if branch is not None:
        clone_cmd.extend(["--branch", branch])

    clone_cmd.append(repo)

    if dirname is not None:
        clone_cmd.append(dirname)

    result = run_cmd(cwd, clone_cmd, env={"LC_ALL": "C"})

    if remove_git_dir:
        rmdir(os.path.join(dirname, ".git"))

    return result


def list_tags(cwd):
    out, err = run_cmd(cwd, ["git", "tag", "--list"], env={"LC_ALL": "C"})
    tags = out.decode("utf-8").strip().split("\n")
    return tags


def _checkout(cwd, repo, branch):
    logger.debug(" Checking out branch {}.".format(branch))

    run_cmd(cwd, ["git", "remote", "set-branches", "origin", branch])
    run_cmd(cwd, ["git", "fetch", "--tags", "--depth", "1", "origin", branch])

    tags = list_tags(cwd)

    # Prefer tags to branches if one exists
    if branch in tags:
        spec = "tags/{}".format(branch)
    else:
        spec = "origin/{}".format(branch)

    out, err = run_cmd(cwd, ["git", "reset", "--hard", spec], env={"LC_ALL": "C"})
    return out, err


def checkout(cwd, repo, branch=None):
    if branch is None:
        branch = "HEAD"
    try:
        return _checkout(cwd, repo, branch)
    except dbt.exceptions.CommandResultError as exc:
        stderr = exc.stderr.decode("utf-8").strip()
        dbt.exceptions.bad_package_spec(repo, branch, stderr)


def get_current_sha(cwd):
    out, err = run_cmd(cwd, ["git", "rev-parse", "HEAD"], env={"LC_ALL": "C"})

    return out.decode("utf-8")


def remove_remote(cwd):
    return run_cmd(cwd, ["git", "remote", "rm", "origin"], env={"LC_ALL": "C"})


def clone_and_checkout(repo, cwd, dirname=None, remove_git_dir=False, branch=None):
    exists = None
    try:
        _, err = clone(repo, cwd, dirname=dirname, remove_git_dir=remove_git_dir)
    except dbt.exceptions.CommandResultError as exc:
        err = exc.stderr.decode("utf-8")
        exists = re.match("fatal: destination path '(.+)' already exists", err)
        if not exists:  # something else is wrong, raise it
            raise
@@ -86,25 +83,26 @@ def clone_and_checkout(repo, cwd, dirname=None, remove_git_dir=False,
    start_sha = None
    if exists:
        directory = exists.group(1)
        logger.debug("Updating existing dependency {}.", directory)
    else:
        matches = re.match("Cloning into '(.+)'", err.decode("utf-8"))
        if matches is None:
            raise dbt.exceptions.RuntimeException(
                f'Error cloning {repo} - never saw "Cloning into ..." from git'
            )
        directory = matches.group(1)
        logger.debug("Pulling new dependency {}.", directory)
    full_path = os.path.join(cwd, directory)
    start_sha = get_current_sha(full_path)
    checkout(full_path, repo, branch)
    end_sha = get_current_sha(full_path)
    if exists:
        if start_sha == end_sha:
            logger.debug(" Already at {}, nothing to do.", start_sha[:7])
        else:
            logger.debug(
                " Updated checkout from {} to {}.", start_sha[:7], end_sha[:7]
            )
    else:
        logger.debug(" Checked out at {}.", end_sha[:7])
    return directory
@@ -8,8 +8,17 @@ from ast import literal_eval
from contextlib import contextmanager
from itertools import chain, islice
from typing import (
    List,
    Union,
    Set,
    Optional,
    Dict,
    Any,
    Iterator,
    Type,
    NoReturn,
    Tuple,
    Callable,
)

import jinja2
@@ -20,16 +29,22 @@ import jinja2.parser
import jinja2.sandbox

from dbt.utils import (
    get_dbt_macro_name,
    get_docs_macro_name,
    get_materialization_macro_name,
    deep_map,
)

from dbt.clients._jinja_blocks import BlockIterator, BlockData, BlockTag
from dbt.contracts.graph.compiled import CompiledSchemaTestNode
from dbt.contracts.graph.parsed import ParsedSchemaTestNode
from dbt.exceptions import (
    InternalException,
    raise_compiler_error,
    CompilationException,
    invalid_materialization_argument,
    MacroReturn,
    JinjaRenderingException,
)
from dbt import flags
from dbt.logger import GLOBAL_LOGGER as logger  # noqa
@@ -40,26 +55,26 @@ def _linecache_inject(source, write):
        # this is the only reliable way to accomplish this. Obviously, it's
        # really darn noisy and will fill your temporary directory
        tmp_file = tempfile.NamedTemporaryFile(
            prefix="dbt-macro-compiled-",
            suffix=".py",
            delete=False,
            mode="w+",
            encoding="utf-8",
        )
        tmp_file.write(source)
        filename = tmp_file.name
    else:
        # `codecs.encode` actually takes a `bytes` as the first argument if
        # the second argument is 'hex' - mypy does not know this.
        rnd = codecs.encode(os.urandom(12), "hex")  # type: ignore
        filename = rnd.decode("ascii")

    # put ourselves in the cache
    cache_entry = (
        len(source),
        None,
        [line + "\n" for line in source.splitlines()],
        filename,
    )
    # linecache does in fact have an attribute `cache`, thanks
    linecache.cache[filename] = cache_entry  # type: ignore
@@ -73,12 +88,10 @@ class MacroFuzzParser(jinja2.parser.Parser):
        # modified to fuzz macros defined in the same file. this way
        # dbt can understand the stack of macros being called.
        # - @cmcarthur
        node.name = get_dbt_macro_name(self.parse_assign_target(name_only=True).name)

        self.parse_signature(node)
        node.body = self.parse_statements(("name:endmacro",), drop_needle=True)
        return node


@@ -94,8 +107,8 @@ class MacroFuzzEnvironment(jinja2.sandbox.SandboxedEnvironment):
        If the value is 'write', also write the files to disk.
        WARNING: This can write a ton of data if you aren't careful.
        """
        if filename == "<template>" and flags.MACRO_DEBUGGING:
            write = flags.MACRO_DEBUGGING == "write"
            filename = _linecache_inject(source, write)

        return super()._compile(source, filename)  # type: ignore
@@ -138,7 +151,7 @@ def quoted_native_concat(nodes):
    head = list(islice(nodes, 2))

    if not head:
        return ""

    if len(head) == 1:
        raw = head[0]
@@ -180,9 +193,7 @@ class NativeSandboxTemplate(jinja2.nativetypes.NativeTemplate):  # mypy: ignore
        vars = dict(*args, **kwargs)

        try:
            return quoted_native_concat(self.root_render_func(self.new_context(vars)))
        except Exception:
            return self.environment.handle_exception()

@@ -221,16 +232,17 @@ class BaseMacroGenerator:
        self.context: Optional[Dict[str, Any]] = context

    def get_template(self):
        raise NotImplementedError("get_template not implemented!")

    def get_name(self) -> str:
        raise NotImplementedError("get_name not implemented!")

    def get_macro(self):
        name = self.get_name()
        template = self.get_template()
        # make the module. previously we set both vars and local, but that's
        # redundant: They both end up in the same place
        # make_module is in jinja2.environment. It returns a TemplateModule
        module = template.make_module(vars=self.context, shared=False)
        macro = module.__dict__[get_dbt_macro_name(name)]
        module.__dict__.update(self.context)
@@ -244,10 +256,9 @@ class BaseMacroGenerator:
            raise_compiler_error(str(e))

    def call_macro(self, *args, **kwargs):
        # called from __call__ methods
        if self.context is None:
            raise InternalException("Context is still None in call_macro!")
        assert self.context is not None

        macro = self.get_macro()
@@ -274,7 +285,7 @@ class MacroStack(threading.local):
    def pop(self, name):
        got = self.call_stack.pop()
        if got != name:
            raise InternalException(f"popped {got}, expected {name}")


class MacroGenerator(BaseMacroGenerator):
@@ -283,7 +294,7 @@ class MacroGenerator(BaseMacroGenerator):
        macro,
        context: Optional[Dict[str, Any]] = None,
        node: Optional[Any] = None,
        stack: Optional[MacroStack] = None,
    ) -> None:
        super().__init__(context)
        self.macro = macro
@@ -306,8 +317,10 @@ class MacroGenerator(BaseMacroGenerator):
            e.stack.append(self.macro)
            raise e

    # This adds the macro's unique id to the node's 'depends_on'
    @contextmanager
    def track_call(self):
        # This is only called from __call__
        if self.stack is None or self.node is None:
            yield
        else:
@@ -322,15 +335,14 @@ class MacroGenerator(BaseMacroGenerator):
            finally:
                self.stack.pop(unique_id)

    # this makes MacroGenerator objects callable like functions
    def __call__(self, *args, **kwargs):
        with self.track_call():
            return self.call_macro(*args, **kwargs)


class QueryStringGenerator(BaseMacroGenerator):
    def __init__(self, template_str: str, context: Dict[str, Any]) -> None:
        super().__init__(context)
        self.template_str: str = template_str
        env = get_environment()
@@ -340,7 +352,7 @@ class QueryStringGenerator(BaseMacroGenerator):
        )

    def get_name(self) -> str:
        return "query_comment_macro"

    def get_template(self):
        """Don't use the template cache, we don't have a node"""
@@ -351,45 +363,41 @@ class QueryStringGenerator(BaseMacroGenerator):


class MaterializationExtension(jinja2.ext.Extension):
    tags = ["materialization"]

    def parse(self, parser):
        node = jinja2.nodes.Macro(lineno=next(parser.stream).lineno)
        materialization_name = parser.parse_assign_target(name_only=True).name

        adapter_name = "default"
        node.args = []
        node.defaults = []

        while parser.stream.skip_if("comma"):
            target = parser.parse_assign_target(name_only=True)

            if target.name == "default":
                pass

            elif target.name == "adapter":
                parser.stream.expect("assign")
                value = parser.parse_expression()
                adapter_name = value.value

            else:
                invalid_materialization_argument(materialization_name, target.name)

        node.name = get_materialization_macro_name(materialization_name, adapter_name)
        node.body = parser.parse_statements(
            ("name:endmaterialization",), drop_needle=True
        )

        return node


class DocumentationExtension(jinja2.ext.Extension):
    tags = ["docs"]

    def parse(self, parser):
        node = jinja2.nodes.Macro(lineno=next(parser.stream).lineno)
@@ -398,13 +406,12 @@ class DocumentationExtension(jinja2.ext.Extension):
        node.args = []
        node.defaults = []
        node.name = get_docs_macro_name(docs_name)
        node.body = parser.parse_statements(("name:enddocs",), drop_needle=True)
        return node


def _is_dunder_name(name):
    return name.startswith("__") and name.endswith("__")


def create_undefined(node=None):
@@ -425,10 +432,11 @@ def create_undefined(node=None):
            return self

        def __getattr__(self, name):
            if name == "name" or _is_dunder_name(name):
                raise AttributeError(
                    "'{}' object has no attribute '{}'".format(
                        type(self).__name__, name
                    )
                )

            self.name = name
@@ -439,24 +447,24 @@ def create_undefined(node=None):
            return self

        def __reduce__(self):
            raise_compiler_error(f"{self.name} is undefined", node=node)

    return Undefined


NATIVE_FILTERS: Dict[str, Callable[[Any], Any]] = {
    "as_text": TextMarker,
    "as_bool": BoolMarker,
    "as_native": NativeMarker,
    "as_number": NumberMarker,
}


TEXT_FILTERS: Dict[str, Callable[[Any], Any]] = {
    "as_text": lambda x: x,
    "as_bool": lambda x: x,
    "as_native": lambda x: x,
    "as_number": lambda x: x,
}


@@ -466,14 +474,14 @@ def get_environment(
    native: bool = False,
) -> jinja2.Environment:
    args: Dict[str, List[Union[str, Type[jinja2.ext.Extension]]]] = {
        "extensions": ["jinja2.ext.do"]
    }

    if capture_macros:
        args["undefined"] = create_undefined(node)

    args["extensions"].append(MaterializationExtension)
    args["extensions"].append(DocumentationExtension)

    env_cls: Type[jinja2.Environment]
    text_filter: Type
@@ -536,8 +544,8 @@ def _requote_result(raw_value: str, rendered: str) -> str:
    elif single_quoted:
        quote_char = "'"
    else:
        quote_char = ""
    return f"{quote_char}{rendered}{quote_char}"


# performance note: Local benmcharking (so take it with a big grain of salt!)
@@ -545,7 +553,7 @@ def _requote_result(raw_value: str, rendered: str) -> str:
# checking two separate patterns, but the standard deviation is smaller with
# one pattern. The time difference between the two was ~2 std deviations, which
# is small enough that I've just chosen the more readable option.
_HAS_RENDER_CHARS_PAT = re.compile(r"({[{%#]|[#}%]})")


def get_rendered(
@@ -562,9 +570,9 @@ def get_rendered(
    # native=True case by passing the input string to ast.literal_eval, like
    # the native renderer does.
    if (
        not native
        and isinstance(string, str)
        and _HAS_RENDER_CHARS_PAT.search(string) is None
    ):
        return string
    template = get_template(
@@ -601,12 +609,11 @@ def extract_toplevel_blocks(
    `collect_raw_data` is `True`) `BlockData` objects.
    """
    return BlockIterator(data).lex_for_blocks(
        allowed_blocks=allowed_blocks, collect_raw_data=collect_raw_data
    )


SCHEMA_TEST_KWARGS_NAME = "_dbt_schema_test_kwargs"


def add_rendered_test_kwargs(
@@ -618,24 +625,21 @@ def add_rendered_test_kwargs(
    renderer, then insert that value into the given context as the special test
    keyword arguments member.
    """
    looks_like_func = r"^\s*(env_var|ref|var|source|doc)\s*\(.+\)\s*$"

    def _convert_function(value: Any, keypath: Tuple[Union[str, int], ...]) -> Any:
        if isinstance(value, str):
            if keypath == ("column_name",):
                # special case: Don't render column names as native, make them
                # be strings
                return value

            if re.match(looks_like_func, value) is not None:
                # curly braces to make rendering happy
                value = f"{{{{ {value} }}}}"

            value = get_rendered(
                value, context, node, capture_macros=capture_macros, native=True
            )

        return value
@@ -6,17 +6,17 @@ from dbt.logger import GLOBAL_LOGGER as logger
 import os
 import time

-if os.getenv('DBT_PACKAGE_HUB_URL'):
+if os.getenv("DBT_PACKAGE_HUB_URL"):
-DEFAULT_REGISTRY_BASE_URL = os.getenv('DBT_PACKAGE_HUB_URL')
+DEFAULT_REGISTRY_BASE_URL = os.getenv("DBT_PACKAGE_HUB_URL")
 else:
-DEFAULT_REGISTRY_BASE_URL = 'https://hub.getdbt.com/'
+DEFAULT_REGISTRY_BASE_URL = "https://hub.getdbt.com/"


 def _get_url(url, registry_base_url=None):
 if registry_base_url is None:
 registry_base_url = DEFAULT_REGISTRY_BASE_URL

-return '{}{}'.format(registry_base_url, url)
+return "{}{}".format(registry_base_url, url)


 def _wrap_exceptions(fn):
@@ -33,42 +33,40 @@ def _wrap_exceptions(fn):
 time.sleep(1)
 continue

-raise RegistryException(
+raise RegistryException("Unable to connect to registry hub") from exc
-'Unable to connect to registry hub'
-) from exc
 return wrapper


 @_wrap_exceptions
 def _get(path, registry_base_url=None):
 url = _get_url(path, registry_base_url)
-logger.debug('Making package registry request: GET {}'.format(url))
+logger.debug("Making package registry request: GET {}".format(url))
 resp = requests.get(url)
-logger.debug('Response from registry: GET {} {}'.format(url,
+logger.debug("Response from registry: GET {} {}".format(url, resp.status_code))
-resp.status_code))
 resp.raise_for_status()
 return resp.json()


 def index(registry_base_url=None):
-return _get('api/v1/index.json', registry_base_url)
+return _get("api/v1/index.json", registry_base_url)


 index_cached = memoized(index)


 def packages(registry_base_url=None):
-return _get('api/v1/packages.json', registry_base_url)
+return _get("api/v1/packages.json", registry_base_url)


 def package(name, registry_base_url=None):
-return _get('api/v1/{}.json'.format(name), registry_base_url)
+return _get("api/v1/{}.json".format(name), registry_base_url)


 def package_version(name, version, registry_base_url=None):
-return _get('api/v1/{}/{}.json'.format(name, version), registry_base_url)
+return _get("api/v1/{}/{}.json".format(name, version), registry_base_url)


 def get_available_versions(name):
 response = package(name)
-return list(response['versions'])
+return list(response["versions"])
@@ -10,16 +10,14 @@ import sys
 import tarfile
 import requests
 import stat
-from typing import (
+from typing import Type, NoReturn, List, Optional, Dict, Any, Tuple, Callable, Union
-Type, NoReturn, List, Optional, Dict, Any, Tuple, Callable, Union
-)

 import dbt.exceptions
 import dbt.utils

 from dbt.logger import GLOBAL_LOGGER as logger

-if sys.platform == 'win32':
+if sys.platform == "win32":
 from ctypes import WinDLL, c_bool
 else:
 WinDLL = None
@@ -51,30 +49,29 @@ def find_matching(
 reobj = re.compile(regex, re.IGNORECASE)

 for relative_path_to_search in relative_paths_to_search:
-absolute_path_to_search = os.path.join(
+absolute_path_to_search = os.path.join(root_path, relative_path_to_search)
-root_path, relative_path_to_search)
 walk_results = os.walk(absolute_path_to_search)

 for current_path, subdirectories, local_files in walk_results:
 for local_file in local_files:
 absolute_path = os.path.join(current_path, local_file)
-relative_path = os.path.relpath(
+relative_path = os.path.relpath(absolute_path, absolute_path_to_search)
-absolute_path, absolute_path_to_search
-)
 if reobj.match(local_file):
-matching.append({
+matching.append(
-'searched_path': relative_path_to_search,
+{
-'absolute_path': absolute_path,
+"searched_path": relative_path_to_search,
-'relative_path': relative_path,
+"absolute_path": absolute_path,
-})
+"relative_path": relative_path,
+}
+)

 return matching


 def load_file_contents(path: str, strip: bool = True) -> str:
 path = convert_path(path)
-with open(path, 'rb') as handle:
+with open(path, "rb") as handle:
-to_return = handle.read().decode('utf-8')
+to_return = handle.read().decode("utf-8")

 if strip:
 to_return = to_return.strip()
@@ -101,14 +98,14 @@ def make_directory(path: str) -> None:
 raise e


-def make_file(path: str, contents: str = '', overwrite: bool = False) -> bool:
+def make_file(path: str, contents: str = "", overwrite: bool = False) -> bool:
 """
 Make a file at `path` assuming that the directory it resides in already
 exists. The file is saved with contents `contents`
 """
 if overwrite or not os.path.exists(path):
 path = convert_path(path)
-with open(path, 'w') as fh:
+with open(path, "w") as fh:
 fh.write(contents)
 return True

@@ -120,7 +117,7 @@ def make_symlink(source: str, link_path: str) -> None:
 Create a symlink at `link_path` referring to `source`.
 """
 if not supports_symlinks():
-dbt.exceptions.system_error('create a symbolic link')
+dbt.exceptions.system_error("create a symbolic link")

 os.symlink(source, link_path)

@@ -129,11 +126,11 @@ def supports_symlinks() -> bool:
 return getattr(os, "symlink", None) is not None


-def write_file(path: str, contents: str = '') -> bool:
+def write_file(path: str, contents: str = "") -> bool:
 path = convert_path(path)
 try:
 make_directory(os.path.dirname(path))
-with open(path, 'w', encoding='utf-8') as f:
+with open(path, "w", encoding="utf-8") as f:
 f.write(str(contents))
 except Exception as exc:
 # note that you can't just catch FileNotFound, because sometimes
@@ -142,20 +139,20 @@ def write_file(path: str, contents: str = '') -> bool:
 # sometimes windows fails to write paths that are less than the length
 # limit. So on windows, suppress all errors that happen from writing
 # to disk.
-if os.name == 'nt':
+if os.name == "nt":
 # sometimes we get a winerror of 3 which means the path was
 # definitely too long, but other times we don't and it means the
 # path was just probably too long. This is probably based on the
 # windows/python version.
-if getattr(exc, 'winerror', 0) == 3:
+if getattr(exc, "winerror", 0) == 3:
-reason = 'Path was too long'
+reason = "Path was too long"
 else:
-reason = 'Path was possibly too long'
+reason = "Path was possibly too long"
 # all our hard work and the path was still too long. Log and
 # continue.
 logger.debug(
-f'Could not write to path {path}({len(path)} characters): '
+f"Could not write to path {path}({len(path)} characters): "
-f'{reason}\nexception: {exc}'
+f"{reason}\nexception: {exc}"
 )
 else:
 raise
@@ -189,10 +186,7 @@ def resolve_path_from_base(path_to_resolve: str, base_path: str) -> str:
 If path_to_resolve is an absolute path or a user path (~), just
 resolve it to an absolute path and return.
 """
-return os.path.abspath(
+return os.path.abspath(os.path.join(base_path, os.path.expanduser(path_to_resolve)))
-os.path.join(
-base_path,
-os.path.expanduser(path_to_resolve)))


 def rmdir(path: str) -> None:
@@ -202,7 +196,7 @@ def rmdir(path: str) -> None:
 cloned via git) can cause rmtree to throw a PermissionError exception
 """
 path = convert_path(path)
-if sys.platform == 'win32':
+if sys.platform == "win32":
 onerror = _windows_rmdir_readonly
 else:
 onerror = None
@@ -221,7 +215,7 @@ def _win_prepare_path(path: str) -> str:
 # letter back in.
 # Unless it starts with '\\'. In that case, the path is a UNC mount point
 # and splitdrive will be fine.
-if not path.startswith('\\\\') and path.startswith('\\'):
+if not path.startswith("\\\\") and path.startswith("\\"):
 curdrive = os.path.splitdrive(os.getcwd())[0]
 path = curdrive + path

@@ -236,7 +230,7 @@ def _win_prepare_path(path: str) -> str:


 def _supports_long_paths() -> bool:
-if sys.platform != 'win32':
+if sys.platform != "win32":
 return True
 # Eryk Sun says to use `WinDLL('ntdll')` instead of `windll.ntdll` because
 # of pointer caching in a comment here:
@@ -244,11 +238,11 @@ def _supports_long_paths() -> bool:
 # I don't know exaclty what he means, but I am inclined to believe him as
 # he's pretty active on Python windows bugs!
 try:
-dll = WinDLL('ntdll')
+dll = WinDLL("ntdll")
 except OSError: # I don't think this happens? you need ntdll to run python
 return False
 # not all windows versions have it at all
-if not hasattr(dll, 'RtlAreLongPathsEnabled'):
+if not hasattr(dll, "RtlAreLongPathsEnabled"):
 return False
 # tell windows we want to get back a single unsigned byte (a bool).
 dll.RtlAreLongPathsEnabled.restype = c_bool
@@ -268,7 +262,7 @@ def convert_path(path: str) -> str:
 if _supports_long_paths():
 return path

-prefix = '\\\\?\\'
+prefix = "\\\\?\\"
 # Nothing to do
 if path.startswith(prefix):
 return path
@@ -299,39 +293,35 @@ def path_is_symlink(path: str) -> bool:

 def open_dir_cmd() -> str:
 # https://docs.python.org/2/library/sys.html#sys.platform
-if sys.platform == 'win32':
+if sys.platform == "win32":
-return 'start'
+return "start"

-elif sys.platform == 'darwin':
+elif sys.platform == "darwin":
-return 'open'
+return "open"

 else:
-return 'xdg-open'
+return "xdg-open"


-def _handle_posix_cwd_error(
+def _handle_posix_cwd_error(exc: OSError, cwd: str, cmd: List[str]) -> NoReturn:
-exc: OSError, cwd: str, cmd: List[str]
-) -> NoReturn:
 if exc.errno == errno.ENOENT:
-message = 'Directory does not exist'
+message = "Directory does not exist"
 elif exc.errno == errno.EACCES:
-message = 'Current user cannot access directory, check permissions'
+message = "Current user cannot access directory, check permissions"
 elif exc.errno == errno.ENOTDIR:
-message = 'Not a directory'
+message = "Not a directory"
 else:
-message = 'Unknown OSError: {} - cwd'.format(str(exc))
+message = "Unknown OSError: {} - cwd".format(str(exc))
 raise dbt.exceptions.WorkingDirectoryError(cwd, cmd, message)


-def _handle_posix_cmd_error(
+def _handle_posix_cmd_error(exc: OSError, cwd: str, cmd: List[str]) -> NoReturn:
-exc: OSError, cwd: str, cmd: List[str]
-) -> NoReturn:
 if exc.errno == errno.ENOENT:
 message = "Could not find command, ensure it is in the user's PATH"
 elif exc.errno == errno.EACCES:
-message = 'User does not have permissions for this command'
+message = "User does not have permissions for this command"
 else:
-message = 'Unknown OSError: {} - cmd'.format(str(exc))
+message = "Unknown OSError: {} - cmd".format(str(exc))
 raise dbt.exceptions.ExecutableError(cwd, cmd, message)


@@ -356,7 +346,7 @@ def _handle_posix_error(exc: OSError, cwd: str, cmd: List[str]) -> NoReturn:
 - exc.errno == EACCES
 - exc.filename == None(?)
 """
-if getattr(exc, 'filename', None) == cwd:
+if getattr(exc, "filename", None) == cwd:
 _handle_posix_cwd_error(exc, cwd, cmd)
 else:
 _handle_posix_cmd_error(exc, cwd, cmd)
@@ -365,46 +355,48 @@ def _handle_posix_error(exc: OSError, cwd: str, cmd: List[str]) -> NoReturn:
 def _handle_windows_error(exc: OSError, cwd: str, cmd: List[str]) -> NoReturn:
 cls: Type[dbt.exceptions.Exception] = dbt.exceptions.CommandError
 if exc.errno == errno.ENOENT:
-message = ("Could not find command, ensure it is in the user's PATH "
+message = (
-"and that the user has permissions to run it")
+"Could not find command, ensure it is in the user's PATH "
+"and that the user has permissions to run it"
+)
 cls = dbt.exceptions.ExecutableError
 elif exc.errno == errno.ENOEXEC:
-message = ('Command was not executable, ensure it is valid')
+message = "Command was not executable, ensure it is valid"
 cls = dbt.exceptions.ExecutableError
 elif exc.errno == errno.ENOTDIR:
-message = ('Unable to cd: path does not exist, user does not have'
+message = (
-' permissions, or not a directory')
+"Unable to cd: path does not exist, user does not have"
+" permissions, or not a directory"
+)
 cls = dbt.exceptions.WorkingDirectoryError
 else:
 message = 'Unknown error: {} (errno={}: "{}")'.format(
-str(exc), exc.errno, errno.errorcode.get(exc.errno, '<Unknown!>')
+str(exc), exc.errno, errno.errorcode.get(exc.errno, "<Unknown!>")
 )
 raise cls(cwd, cmd, message)


 def _interpret_oserror(exc: OSError, cwd: str, cmd: List[str]) -> NoReturn:
-"""Interpret an OSError exc and raise the appropriate dbt exception.
+"""Interpret an OSError exc and raise the appropriate dbt exception."""

-"""
 if len(cmd) == 0:
 raise dbt.exceptions.CommandError(cwd, cmd)

 # all of these functions raise unconditionally
-if os.name == 'nt':
+if os.name == "nt":
 _handle_windows_error(exc, cwd, cmd)
 else:
 _handle_posix_error(exc, cwd, cmd)

 # this should not be reachable, raise _something_ at least!
 raise dbt.exceptions.InternalException(
-'Unhandled exception in _interpret_oserror: {}'.format(exc)
+"Unhandled exception in _interpret_oserror: {}".format(exc)
 )


 def run_cmd(
 cwd: str, cmd: List[str], env: Optional[Dict[str, Any]] = None
 ) -> Tuple[bytes, bytes]:
-logger.debug('Executing "{}"'.format(' '.join(cmd)))
+logger.debug('Executing "{}"'.format(" ".join(cmd)))
 if len(cmd) == 0:
 raise dbt.exceptions.CommandError(cwd, cmd)

@@ -417,11 +409,8 @@ def run_cmd(

 try:
 proc = subprocess.Popen(
-cmd,
+cmd, cwd=cwd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=full_env
-cwd=cwd,
+)
-stdout=subprocess.PIPE,
-stderr=subprocess.PIPE,
-env=full_env)

 out, err = proc.communicate()
 except OSError as exc:
@@ -431,18 +420,19 @@ def run_cmd(
 logger.debug('STDERR: "{!s}"'.format(err))

 if proc.returncode != 0:
-logger.debug('command return code={}'.format(proc.returncode))
+logger.debug("command return code={}".format(proc.returncode))
-raise dbt.exceptions.CommandResultError(cwd, cmd, proc.returncode,
+raise dbt.exceptions.CommandResultError(cwd, cmd, proc.returncode, out, err)
-out, err)

 return out, err


-def download(url: str, path: str, timeout: Union[float, tuple] = None) -> None:
+def download(
+url: str, path: str, timeout: Optional[Union[float, tuple]] = None
+) -> None:
 path = convert_path(path)
-connection_timeout = timeout or float(os.getenv('DBT_HTTP_TIMEOUT', 10))
+connection_timeout = timeout or float(os.getenv("DBT_HTTP_TIMEOUT", 10))
 response = requests.get(url, timeout=connection_timeout)
-with open(path, 'wb') as handle:
+with open(path, "wb") as handle:
 for block in response.iter_content(1024 * 64):
 handle.write(block)

@@ -466,7 +456,7 @@ def untar_package(
 ) -> None:
 tar_path = convert_path(tar_path)
 tar_dir_name = None
-with tarfile.open(tar_path, 'r') as tarball:
+with tarfile.open(tar_path, "r") as tarball:
 tarball.extractall(dest_dir)
 tar_dir_name = os.path.commonprefix(tarball.getnames())
 if rename_to:
@@ -482,7 +472,7 @@ def chmod_and_retry(func, path, exc_info):
 We want to retry most operations here, but listdir is one that we know will
 be useless.
 """
-if func is os.listdir or os.name != 'nt':
+if func is os.listdir or os.name != "nt":
 raise
 os.chmod(path, stat.S_IREAD | stat.S_IWRITE)
 # on error,this will raise.
@@ -503,7 +493,7 @@ def move(src, dst):
 """
 src = convert_path(src)
 dst = convert_path(dst)
-if os.name != 'nt':
+if os.name != "nt":
 return shutil.move(src, dst)

 if os.path.isdir(dst):
@@ -511,7 +501,7 @@ def move(src, dst):
 os.rename(src, dst)
 return

-dst = os.path.join(dst, os.path.basename(src.rstrip('/\\')))
+dst = os.path.join(dst, os.path.basename(src.rstrip("/\\")))
 if os.path.exists(dst):
 raise EnvironmentError("Path '{}' already exists".format(dst))

@@ -520,11 +510,10 @@ def move(src, dst):
 except OSError:
 # probably different drives
 if os.path.isdir(src):
-if _absnorm(dst + '\\').startswith(_absnorm(src + '\\')):
+if _absnorm(dst + "\\").startswith(_absnorm(src + "\\")):
 # dst is inside src
 raise EnvironmentError(
-"Cannot move a directory '{}' into itself '{}'"
+"Cannot move a directory '{}' into itself '{}'".format(src, dst)
-.format(src, dst)
 )
 shutil.copytree(src, dst, symlinks=True)
 rmtree(src)
@@ -1,16 +1,13 @@
-from typing import Any

 import dbt.exceptions

 import yaml
 import yaml.scanner

 # the C version is faster, but it doesn't always exist
-YamlLoader: Any
 try:
-from yaml import CSafeLoader as YamlLoader
+from yaml import CLoader as Loader, CSafeLoader as SafeLoader, CDumper as Dumper
 except ImportError:
-from yaml import SafeLoader as YamlLoader
+from yaml import Loader, SafeLoader, Dumper # type: ignore # noqa: F401


 YAML_ERROR_MESSAGE = """
@@ -30,14 +27,14 @@ def line_no(i, line, width=3):


 def prefix_with_line_numbers(string, no_start, no_end):
-line_list = string.split('\n')
+line_list = string.split("\n")

 numbers = range(no_start, no_end)
 relevant_lines = line_list[no_start:no_end]

-return "\n".join([
+return "\n".join(
-line_no(i + 1, line) for (i, line) in zip(numbers, relevant_lines)
+[line_no(i + 1, line) for (i, line) in zip(numbers, relevant_lines)]
-])
+)


 def contextualized_yaml_error(raw_contents, error):
@@ -48,20 +45,20 @@ def contextualized_yaml_error(raw_contents, error):

 nice_error = prefix_with_line_numbers(raw_contents, min_line, max_line)

-return YAML_ERROR_MESSAGE.format(line_number=mark.line + 1,
+return YAML_ERROR_MESSAGE.format(
-nice_error=nice_error,
+line_number=mark.line + 1, nice_error=nice_error, raw_error=error
-raw_error=error)
+)


 def safe_load(contents):
-return yaml.load(contents, Loader=YamlLoader)
+return yaml.load(contents, Loader=SafeLoader)


 def load_yaml_text(contents):
 try:
 return safe_load(contents)
 except (yaml.scanner.ScannerError, yaml.YAMLError) as e:
-if hasattr(e, 'problem_mark'):
+if hasattr(e, "problem_mark"):
 error = contextualized_yaml_error(contents, e)
 else:
 error = str(e)
@@ -30,38 +30,43 @@ from dbt.graph import Graph
 from dbt.logger import GLOBAL_LOGGER as logger
 from dbt.node_types import NodeType
 from dbt.utils import pluralize
+import dbt.tracking

-graph_file_name = 'graph.gpickle'
+graph_file_name = "graph.gpickle"


 def _compiled_type_for(model: ParsedNode):
 if type(model) not in COMPILED_TYPES:
 raise InternalException(
-f'Asked to compile {type(model)} node, but it has no compiled form'
+f"Asked to compile {type(model)} node, but it has no compiled form"
 )
 return COMPILED_TYPES[type(model)]


 def print_compile_stats(stats):
 names = {
-NodeType.Model: 'model',
+NodeType.Model: "model",
-NodeType.Test: 'test',
+NodeType.Test: "test",
-NodeType.Snapshot: 'snapshot',
+NodeType.Snapshot: "snapshot",
-NodeType.Analysis: 'analysis',
+NodeType.Analysis: "analysis",
-NodeType.Macro: 'macro',
+NodeType.Macro: "macro",
-NodeType.Operation: 'operation',
+NodeType.Operation: "operation",
-NodeType.Seed: 'seed file',
+NodeType.Seed: "seed file",
-NodeType.Source: 'source',
+NodeType.Source: "source",
-NodeType.Exposure: 'exposure',
+NodeType.Exposure: "exposure",
 }

 results = {k: 0 for k in names.keys()}
 results.update(stats)

-stat_line = ", ".join([
+# create tracking event for resource_counts
-pluralize(ct, names.get(t)) for t, ct in results.items()
+if dbt.tracking.active_user is not None:
-if t in names
+resource_counts = {k.pluralize(): v for k, v in results.items()}
-])
+dbt.tracking.track_resource_counts(resource_counts)

+stat_line = ", ".join(
+[pluralize(ct, names.get(t)) for t, ct in results.items() if t in names]
+)

 logger.info("Found {}".format(stat_line))

@@ -138,7 +143,7 @@ class Linker:
 """
 out_graph = self.graph.copy()
 for node_id in self.graph.nodes():
-data = manifest.expect(node_id).to_dict()
+data = manifest.expect(node_id).to_dict(omit_none=True)
 out_graph.add_node(node_id, **data)
 nx.write_gpickle(out_graph, outfile)

@@ -160,9 +165,7 @@ class Compiler:
 extra_context: Dict[str, Any],
 ) -> Dict[str, Any]:

-context = generate_runtime_model(
+context = generate_runtime_model(node, self.config, manifest)
-node, self.config, manifest
-)
 context.update(extra_context)
 if isinstance(node, CompiledSchemaTestNode):
 # for test nodes, add a special keyword args value to the context
@@ -177,8 +180,7 @@ class Compiler:

 def _get_relation_name(self, node: ParsedNode):
 relation_name = None
-if (node.resource_type in NodeType.refable() and
+if node.resource_type in NodeType.refable() and not node.is_ephemeral_model:
-not node.is_ephemeral_model):
 adapter = get_adapter(self.config)
 relation_cls = adapter.Relation
 relation_name = str(relation_cls.create_from(self.config, node))
@@ -221,32 +223,29 @@ class Compiler:

 with_stmt = None
 for token in parsed.tokens:
-if token.is_keyword and token.normalized == 'WITH':
+if token.is_keyword and token.normalized == "WITH":
 with_stmt = token
 break

 if with_stmt is None:
 # no with stmt, add one, and inject CTEs right at the beginning
 first_token = parsed.token_first()
-with_stmt = sqlparse.sql.Token(sqlparse.tokens.Keyword, 'with')
+with_stmt = sqlparse.sql.Token(sqlparse.tokens.Keyword, "with")
 parsed.insert_before(first_token, with_stmt)
 else:
 # stmt exists, add a comma (which will come after injected CTEs)
-trailing_comma = sqlparse.sql.Token(
+trailing_comma = sqlparse.sql.Token(sqlparse.tokens.Punctuation, ",")
-sqlparse.tokens.Punctuation, ','
-)
 parsed.insert_after(with_stmt, trailing_comma)

 token = sqlparse.sql.Token(
-sqlparse.tokens.Keyword,
+sqlparse.tokens.Keyword, ", ".join(c.sql for c in ctes)
-", ".join(c.sql for c in ctes)
 )
 parsed.insert_after(with_stmt, token)

 return str(parsed)

 def _get_dbt_test_name(self) -> str:
-return 'dbt__cte__internal_test'
+return "dbt__cte__internal_test"

 # This method is called by the 'compile_node' method. Starting
 # from the node that it is passed in, it will recursively call
@@ -262,9 +261,7 @@ class Compiler:
 ) -> Tuple[NonSourceCompiledNode, List[InjectedCTE]]:

 if model.compiled_sql is None:
-raise RuntimeException(
+raise RuntimeException("Cannot inject ctes into an unparsed node", model)
-'Cannot inject ctes into an unparsed node', model
-)
 if model.extra_ctes_injected:
 return (model, model.extra_ctes)

@@ -290,19 +287,18 @@ class Compiler:
 else:
 if cte.id not in manifest.nodes:
 raise InternalException(
-f'During compilation, found a cte reference that '
+f"During compilation, found a cte reference that "
-f'could not be resolved: {cte.id}'
+f"could not be resolved: {cte.id}"
 )
 cte_model = manifest.nodes[cte.id]

 if not cte_model.is_ephemeral_model:
-raise InternalException(f'{cte.id} is not ephemeral')
+raise InternalException(f"{cte.id} is not ephemeral")

 # This model has already been compiled, so it's been
 # through here before
-if getattr(cte_model, 'compiled', False):
+if getattr(cte_model, "compiled", False):
-assert isinstance(cte_model,
+assert isinstance(cte_model, tuple(COMPILED_TYPES.values()))
-tuple(COMPILED_TYPES.values()))
 cte_model = cast(NonSourceCompiledNode, cte_model)
 new_prepended_ctes = cte_model.extra_ctes

@@ -310,13 +306,11 @@ class Compiler:
 else:
 # This is an ephemeral parsed model that we can compile.
 # Compile and update the node
-cte_model = self._compile_node(
+cte_model = self._compile_node(cte_model, manifest, extra_context)
-cte_model, manifest, extra_context)
 # recursively call this method
-cte_model, new_prepended_ctes = \
+cte_model, new_prepended_ctes = self._recursively_prepend_ctes(
-self._recursively_prepend_ctes(
+cte_model, manifest, extra_context
-cte_model, manifest, extra_context
+)
-)
 # Save compiled SQL file and sync manifest
 self._write_node(cte_model)
 manifest.sync_update_node(cte_model)
@@ -324,7 +318,7 @@ class Compiler:
 _extend_prepended_ctes(prepended_ctes, new_prepended_ctes)

 new_cte_name = self.add_ephemeral_prefix(cte_model.name)
-sql = f' {new_cte_name} as (\n{cte_model.compiled_sql}\n)'
+sql = f" {new_cte_name} as (\n{cte_model.compiled_sql}\n)"

 _add_prepended_cte(prepended_ctes, InjectedCTE(id=cte.id, sql=sql))

@@ -339,7 +333,7 @@ class Compiler:
 model.compiled_sql = injected_sql
 model.extra_ctes_injected = True
 model.extra_ctes = prepended_ctes
-model.validate(model.to_dict())
+model.validate(model.to_dict(omit_none=True))

 manifest.update_node(model)

@@ -365,11 +359,10 @@ class Compiler:
 # compiled_sql, and do the regular prepend logic from CTEs.
 name = self._get_dbt_test_name()
 cte = InjectedCTE(
-id=name,
+id=name, sql=f" {name} as (\n{compiled_node.compiled_sql}\n)"
-sql=f' {name} as (\n{compiled_node.compiled_sql}\n)'
 )
 compiled_node.extra_ctes.append(cte)
-compiled_node.compiled_sql = f'\nselect count(*) from {name}'
+compiled_node.compiled_sql = f"\nselect count(*) from {name}"

 return compiled_node

@@ -388,18 +381,18 @@ class Compiler:

 logger.debug("Compiling {}".format(node.unique_id))

-data = node.to_dict()
+data = node.to_dict(omit_none=True)
-data.update({
+data.update(
-'compiled': False,
+{
-'compiled_sql': None,
+"compiled": False,
-'extra_ctes_injected': False,
+"compiled_sql": None,
-'extra_ctes': [],
+"extra_ctes_injected": False,
-})
+"extra_ctes": [],
+}
+)
 compiled_node = _compiled_type_for(node).from_dict(data)

-context = self._create_node_context(
+context = self._create_node_context(compiled_node, manifest, extra_context)
-compiled_node, manifest, extra_context
-)

 compiled_node.compiled_sql = jinja.get_rendered(
 node.raw_sql,
@@ -413,9 +406,7 @@ class Compiler:

 # add ctes for specific test nodes, and also for
 # possible future use in adapters
-compiled_node = self._add_ctes(
+compiled_node = self._add_ctes(compiled_node, manifest, extra_context)
-compiled_node, manifest, extra_context
-)

 return compiled_node

@@ -425,21 +416,17 @@ class Compiler:
 if flags.WRITE_JSON:
 linker.write_graph(graph_path, manifest)

-def link_node(
+def link_node(self, linker: Linker, node: GraphMemberNode, manifest: Manifest):
-self, linker: Linker, node: GraphMemberNode, manifest: Manifest
-):
 linker.add_node(node.unique_id)

 for dependency in node.depends_on_nodes:
 if dependency in manifest.nodes:
 linker.dependency(
-node.unique_id,
+node.unique_id, (manifest.nodes[dependency].unique_id)
-(manifest.nodes[dependency].unique_id)
 )
 elif dependency in manifest.sources:
 linker.dependency(
-node.unique_id,
+node.unique_id, (manifest.sources[dependency].unique_id)
-(manifest.sources[dependency].unique_id)
 )
 else:
 dependency_not_found(node, dependency)
@@ -474,16 +461,13 @@ class Compiler:

 # writes the "compiled_sql" into the target/compiled directory
 def _write_node(self, node: NonSourceCompiledNode) -> ManifestNode:
-if (not node.extra_ctes_injected or
+if not node.extra_ctes_injected or node.resource_type == NodeType.Snapshot:
-node.resource_type == NodeType.Snapshot):
 return node
 logger.debug(f'Writing injected SQL for node "{node.unique_id}"')

 if node.compiled_sql:
 node.build_path = node.write_node(
-self.config.target_path,
+self.config.target_path, "compiled", node.compiled_sql
-'compiled',
-node.compiled_sql
 )
 return node

@@ -501,9 +485,7 @@ class Compiler:
 ) -> NonSourceCompiledNode:
 node = self._compile_node(node, manifest, extra_context)

-node, _ = self._recursively_prepend_ctes(
+node, _ = self._recursively_prepend_ctes(node, manifest, extra_context)
-node, manifest, extra_context
-)
 if write:
 self._write_node(node)
 return node
@@ -2,7 +2,7 @@ from dataclasses import dataclass
|
|||||||
from typing import Any, Dict, Optional, Tuple
|
from typing import Any, Dict, Optional, Tuple
|
||||||
import os
|
import os
|
||||||
|
|
||||||
from hologram import ValidationError
|
from dbt.dataclass_schema import ValidationError
|
||||||
|
|
||||||
from dbt.clients.system import load_file_contents
|
from dbt.clients.system import load_file_contents
|
||||||
from dbt.clients.yaml_helper import load_yaml_text
|
from dbt.clients.yaml_helper import load_yaml_text
|
||||||
@@ -20,10 +20,8 @@ from dbt.utils import coerce_dict_str
|
|||||||
from .renderer import ProfileRenderer
|
from .renderer import ProfileRenderer
|
||||||
|
|
||||||
DEFAULT_THREADS = 1
|
DEFAULT_THREADS = 1
|
||||||
DEFAULT_PROFILES_DIR = os.path.join(os.path.expanduser('~'), '.dbt')
|
DEFAULT_PROFILES_DIR = os.path.join(os.path.expanduser("~"), ".dbt")
|
||||||
PROFILES_DIR = os.path.expanduser(
|
PROFILES_DIR = os.path.expanduser(os.getenv("DBT_PROFILES_DIR", DEFAULT_PROFILES_DIR))
|
||||||
os.getenv('DBT_PROFILES_DIR', DEFAULT_PROFILES_DIR)
|
|
||||||
)
|
|
||||||
|
|
||||||
INVALID_PROFILE_MESSAGE = """
|
INVALID_PROFILE_MESSAGE = """
|
||||||
dbt encountered an error while trying to read your profiles.yml file.
|
dbt encountered an error while trying to read your profiles.yml file.
|
||||||
@@ -43,11 +41,13 @@ Here, [profile name] should be replaced with a profile name
|
|||||||
defined in your profiles.yml file. You can find profiles.yml here:
|
defined in your profiles.yml file. You can find profiles.yml here:
|
||||||
|
|
||||||
{profiles_file}/profiles.yml
|
{profiles_file}/profiles.yml
|
||||||
""".format(profiles_file=PROFILES_DIR)
|
""".format(
|
||||||
|
profiles_file=PROFILES_DIR
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def read_profile(profiles_dir: str) -> Dict[str, Any]:
|
def read_profile(profiles_dir: str) -> Dict[str, Any]:
|
||||||
path = os.path.join(profiles_dir, 'profiles.yml')
|
path = os.path.join(profiles_dir, "profiles.yml")
|
||||||
|
|
||||||
contents = None
|
contents = None
|
||||||
if os.path.isfile(path):
|
if os.path.isfile(path):
|
||||||
@@ -55,12 +55,8 @@ def read_profile(profiles_dir: str) -> Dict[str, Any]:
|
|||||||
contents = load_file_contents(path, strip=False)
|
contents = load_file_contents(path, strip=False)
|
||||||
yaml_content = load_yaml_text(contents)
|
yaml_content = load_yaml_text(contents)
|
||||||
if not yaml_content:
|
if not yaml_content:
|
||||||
msg = f'The profiles.yml file at {path} is empty'
|
msg = f"The profiles.yml file at {path} is empty"
|
||||||
raise DbtProfileError(
|
raise DbtProfileError(INVALID_PROFILE_MESSAGE.format(error_string=msg))
|
||||||
INVALID_PROFILE_MESSAGE.format(
|
|
||||||
error_string=msg
|
|
||||||
)
|
|
||||||
)
|
|
||||||
return yaml_content
|
return yaml_content
|
||||||
except ValidationException as e:
|
except ValidationException as e:
|
||||||
msg = INVALID_PROFILE_MESSAGE.format(error_string=e)
|
msg = INVALID_PROFILE_MESSAGE.format(error_string=e)
|
||||||
@@ -73,8 +69,9 @@ def read_user_config(directory: str) -> UserConfig:
|
|||||||
try:
|
try:
|
||||||
profile = read_profile(directory)
|
profile = read_profile(directory)
|
||||||
if profile:
|
if profile:
|
||||||
user_cfg = coerce_dict_str(profile.get('config', {}))
|
user_cfg = coerce_dict_str(profile.get("config", {}))
|
||||||
if user_cfg is not None:
|
if user_cfg is not None:
|
||||||
|
UserConfig.validate(user_cfg)
|
||||||
return UserConfig.from_dict(user_cfg)
|
return UserConfig.from_dict(user_cfg)
|
||||||
except (RuntimeException, ValidationError):
|
except (RuntimeException, ValidationError):
|
||||||
pass
|
pass
|
||||||
@@ -91,9 +88,7 @@ class Profile(HasCredentials):
|
|||||||
threads: int
|
threads: int
|
||||||
credentials: Credentials
|
credentials: Credentials
|
||||||
|
|
||||||
def to_profile_info(
|
def to_profile_info(self, serialize_credentials: bool = False) -> Dict[str, Any]:
|
||||||
self, serialize_credentials: bool = False
|
|
||||||
) -> Dict[str, Any]:
|
|
||||||
"""Unlike to_project_config, this dict is not a mirror of any existing
|
"""Unlike to_project_config, this dict is not a mirror of any existing
|
||||||
on-disk data structure. It's used when creating a new profile from an
|
on-disk data structure. It's used when creating a new profile from an
|
||||||
existing one.
|
existing one.
|
||||||
@@ -103,44 +98,45 @@ class Profile(HasCredentials):
|
|||||||
:returns dict: The serialized profile.
|
:returns dict: The serialized profile.
|
||||||
"""
|
"""
|
||||||
result = {
|
result = {
|
||||||
'profile_name': self.profile_name,
|
"profile_name": self.profile_name,
|
||||||
'target_name': self.target_name,
|
"target_name": self.target_name,
|
||||||
'config': self.config,
|
"config": self.config,
|
||||||
'threads': self.threads,
|
"threads": self.threads,
|
||||||
'credentials': self.credentials,
|
"credentials": self.credentials,
|
||||||
}
|
}
|
||||||
if serialize_credentials:
|
if serialize_credentials:
|
||||||
result['config'] = self.config.to_dict()
|
result["config"] = self.config.to_dict(omit_none=True)
|
||||||
result['credentials'] = self.credentials.to_dict()
|
result["credentials"] = self.credentials.to_dict(omit_none=True)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def to_target_dict(self) -> Dict[str, Any]:
|
def to_target_dict(self) -> Dict[str, Any]:
|
||||||
target = dict(
|
target = dict(self.credentials.connection_info(with_aliases=True))
|
||||||
self.credentials.connection_info(with_aliases=True)
|
target.update(
|
||||||
|
{
|
||||||
|
"type": self.credentials.type,
|
||||||
|
"threads": self.threads,
|
||||||
|
"name": self.target_name,
|
||||||
|
"target_name": self.target_name,
|
||||||
|
"profile_name": self.profile_name,
|
||||||
|
"config": self.config.to_dict(omit_none=True),
|
||||||
|
}
|
||||||
)
|
)
|
||||||
target.update({
|
|
||||||
'type': self.credentials.type,
|
|
||||||
'threads': self.threads,
|
|
||||||
'name': self.target_name,
|
|
||||||
'target_name': self.target_name,
|
|
||||||
'profile_name': self.profile_name,
|
|
||||||
'config': self.config.to_dict(),
|
|
||||||
})
|
|
||||||
return target
|
return target
|
||||||
|
|
||||||
def __eq__(self, other: object) -> bool:
|
def __eq__(self, other: object) -> bool:
|
||||||
if not (isinstance(other, self.__class__) and
|
if not (
|
||||||
isinstance(self, other.__class__)):
|
isinstance(other, self.__class__) and isinstance(self, other.__class__)
|
||||||
|
):
|
||||||
return NotImplemented
|
return NotImplemented
|
||||||
return self.to_profile_info() == other.to_profile_info()
|
return self.to_profile_info() == other.to_profile_info()
|
||||||
|
|
||||||
def validate(self):
|
def validate(self):
|
||||||
try:
|
try:
|
||||||
if self.credentials:
|
if self.credentials:
|
||||||
self.credentials.to_dict(validate=True)
|
dct = self.credentials.to_dict(omit_none=True)
|
||||||
ProfileConfig.from_dict(
|
self.credentials.validate(dct)
|
||||||
self.to_profile_info(serialize_credentials=True)
|
dct = self.to_profile_info(serialize_credentials=True)
|
||||||
)
|
ProfileConfig.validate(dct)
|
||||||
except ValidationError as exc:
|
except ValidationError as exc:
|
||||||
raise DbtProfileError(validator_error_message(exc)) from exc
|
raise DbtProfileError(validator_error_message(exc)) from exc
|
||||||
|
|
||||||
@@ -150,22 +146,28 @@ class Profile(HasCredentials):
|
|||||||
) -> Credentials:
|
) -> Credentials:
|
||||||
# avoid an import cycle
|
# avoid an import cycle
|
||||||
from dbt.adapters.factory import load_plugin
|
from dbt.adapters.factory import load_plugin
|
||||||
|
|
||||||
# credentials carry their 'type' in their actual type, not their
|
# credentials carry their 'type' in their actual type, not their
|
||||||
# attributes. We do want this in order to pick our Credentials class.
|
# attributes. We do want this in order to pick our Credentials class.
|
||||||
if 'type' not in profile:
|
if "type" not in profile:
|
||||||
raise DbtProfileError(
|
raise DbtProfileError(
|
||||||
'required field "type" not found in profile {} and target {}'
|
'required field "type" not found in profile {} and target {}'.format(
|
||||||
.format(profile_name, target_name))
|
profile_name, target_name
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
typename = profile.pop('type')
|
typename = profile.pop("type")
|
||||||
try:
|
try:
|
||||||
cls = load_plugin(typename)
|
cls = load_plugin(typename)
|
||||||
credentials = cls.from_dict(profile)
|
data = cls.translate_aliases(profile)
|
||||||
|
cls.validate(data)
|
||||||
|
credentials = cls.from_dict(data)
|
||||||
except (RuntimeException, ValidationError) as e:
|
except (RuntimeException, ValidationError) as e:
|
||||||
msg = str(e) if isinstance(e, RuntimeException) else e.message
|
msg = str(e) if isinstance(e, RuntimeException) else e.message
|
||||||
raise DbtProfileError(
|
raise DbtProfileError(
|
||||||
'Credentials in profile "{}", target "{}" invalid: {}'
|
'Credentials in profile "{}", target "{}" invalid: {}'.format(
|
||||||
.format(profile_name, target_name, msg)
|
profile_name, target_name, msg
|
||||||
|
)
|
||||||
) from e
|
) from e
|
||||||
|
|
||||||
return credentials
|
return credentials
|
||||||
@@ -186,19 +188,21 @@ class Profile(HasCredentials):
|
|||||||
def _get_profile_data(
|
def _get_profile_data(
|
||||||
profile: Dict[str, Any], profile_name: str, target_name: str
|
profile: Dict[str, Any], profile_name: str, target_name: str
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
if 'outputs' not in profile:
|
if "outputs" not in profile:
|
||||||
raise DbtProfileError(
|
raise DbtProfileError(
|
||||||
"outputs not specified in profile '{}'".format(profile_name)
|
"outputs not specified in profile '{}'".format(profile_name)
|
||||||
)
|
)
|
||||||
outputs = profile['outputs']
|
outputs = profile["outputs"]
|
||||||
|
|
||||||
if target_name not in outputs:
|
if target_name not in outputs:
|
||||||
outputs = '\n'.join(' - {}'.format(output)
|
outputs = "\n".join(" - {}".format(output) for output in outputs)
|
||||||
for output in outputs)
|
msg = (
|
||||||
msg = ("The profile '{}' does not have a target named '{}'. The "
|
"The profile '{}' does not have a target named '{}'. The "
|
||||||
"valid target names for this profile are:\n{}"
|
"valid target names for this profile are:\n{}".format(
|
||||||
.format(profile_name, target_name, outputs))
|
profile_name, target_name, outputs
|
||||||
raise DbtProfileError(msg, result_type='invalid_target')
|
)
|
||||||
|
)
|
||||||
|
raise DbtProfileError(msg, result_type="invalid_target")
|
||||||
profile_data = outputs[target_name]
|
profile_data = outputs[target_name]
|
||||||
|
|
||||||
if not isinstance(profile_data, dict):
|
if not isinstance(profile_data, dict):
|
||||||
@@ -206,7 +210,7 @@ class Profile(HasCredentials):
|
|||||||
f"output '{target_name}' of profile '{profile_name}' is "
|
f"output '{target_name}' of profile '{profile_name}' is "
|
||||||
f"misconfigured in profiles.yml"
|
f"misconfigured in profiles.yml"
|
||||||
)
|
)
|
||||||
raise DbtProfileError(msg, result_type='invalid_target')
|
raise DbtProfileError(msg, result_type="invalid_target")
|
||||||
|
|
||||||
return profile_data
|
return profile_data
|
||||||
|
|
||||||
@@ -217,8 +221,8 @@ class Profile(HasCredentials):
|
|||||||
threads: int,
|
threads: int,
|
||||||
profile_name: str,
|
profile_name: str,
|
||||||
target_name: str,
|
target_name: str,
|
||||||
user_cfg: Optional[Dict[str, Any]] = None
|
user_cfg: Optional[Dict[str, Any]] = None,
|
||||||
) -> 'Profile':
|
) -> "Profile":
|
||||||
"""Create a profile from an existing set of Credentials and the
|
"""Create a profile from an existing set of Credentials and the
|
||||||
remaining information.
|
remaining information.
|
||||||
|
|
||||||
@@ -233,6 +237,7 @@ class Profile(HasCredentials):
|
|||||||
"""
|
"""
|
||||||
if user_cfg is None:
|
if user_cfg is None:
|
||||||
user_cfg = {}
|
user_cfg = {}
|
||||||
|
UserConfig.validate(user_cfg)
|
||||||
config = UserConfig.from_dict(user_cfg)
|
config = UserConfig.from_dict(user_cfg)
|
||||||
|
|
||||||
profile = cls(
|
profile = cls(
|
||||||
@@ -240,7 +245,7 @@ class Profile(HasCredentials):
|
|||||||
target_name=target_name,
|
target_name=target_name,
|
||||||
config=config,
|
config=config,
|
||||||
threads=threads,
|
threads=threads,
|
||||||
credentials=credentials
|
credentials=credentials,
|
||||||
)
|
)
|
||||||
profile.validate()
|
profile.validate()
|
||||||
return profile
|
return profile
|
||||||
@@ -265,19 +270,18 @@ class Profile(HasCredentials):
|
|||||||
# name to extract a profile that we can render.
|
# name to extract a profile that we can render.
|
||||||
if target_override is not None:
|
if target_override is not None:
|
||||||
target_name = target_override
|
target_name = target_override
|
||||||
elif 'target' in raw_profile:
|
elif "target" in raw_profile:
|
||||||
# render the target if it was parsed from yaml
|
# render the target if it was parsed from yaml
|
||||||
target_name = renderer.render_value(raw_profile['target'])
|
target_name = renderer.render_value(raw_profile["target"])
|
||||||
else:
|
else:
|
||||||
target_name = 'default'
|
target_name = "default"
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"target not specified in profile '{}', using '{}'"
|
"target not specified in profile '{}', using '{}'".format(
|
||||||
.format(profile_name, target_name)
|
profile_name, target_name
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
raw_profile_data = cls._get_profile_data(
|
raw_profile_data = cls._get_profile_data(raw_profile, profile_name, target_name)
|
||||||
raw_profile, profile_name, target_name
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
profile_data = renderer.render_data(raw_profile_data)
|
profile_data = renderer.render_data(raw_profile_data)
|
||||||
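Note: the hunk above reflows the target-name fallback (CLI override first, then the profile's rendered `target:` key, then `default`), and the `threads` handling a few hunks below follows the same pattern, popping the per-output value unless a command-line override is present. A rough sketch of that precedence; the helper name, default constant, and error type here are illustrative, not dbt's actual API:

```python
from typing import Any, Dict, Optional

DEFAULT_THREADS = 1  # assumed default for the sketch


def resolve_target_and_threads(
    raw_profile: Dict[str, Any],
    target_override: Optional[str] = None,
    threads_override: Optional[int] = None,
):
    # CLI override wins, then the profile's own "target" key, then a default
    if target_override is not None:
        target_name = target_override
    elif "target" in raw_profile:
        target_name = raw_profile["target"]  # dbt renders this value first
    else:
        target_name = "default"

    outputs = raw_profile.get("outputs", {})
    if target_name not in outputs:
        valid = "\n".join(" - {}".format(name) for name in outputs)
        raise KeyError(f"no target named '{target_name}'; valid targets are:\n{valid}")

    target = dict(outputs[target_name])
    # threads live on the selected output, but a command-line value overrides them
    threads = target.pop("threads", DEFAULT_THREADS)
    if threads_override is not None:
        threads = threads_override
    return target_name, target, threads


profile = {"target": "dev", "outputs": {"dev": {"type": "postgres", "threads": 4}}}
print(resolve_target_and_threads(profile, threads_override=8))
```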
@@ -294,7 +298,7 @@ class Profile(HasCredentials):
|
|||||||
user_cfg: Optional[Dict[str, Any]] = None,
|
user_cfg: Optional[Dict[str, Any]] = None,
|
||||||
target_override: Optional[str] = None,
|
target_override: Optional[str] = None,
|
||||||
threads_override: Optional[int] = None,
|
threads_override: Optional[int] = None,
|
||||||
) -> 'Profile':
|
) -> "Profile":
|
||||||
"""Create a profile from its raw profile information.
|
"""Create a profile from its raw profile information.
|
||||||
|
|
||||||
(this is an intermediate step, mostly useful for unit testing)
|
(this is an intermediate step, mostly useful for unit testing)
|
||||||
@@ -315,7 +319,7 @@ class Profile(HasCredentials):
|
|||||||
"""
|
"""
|
||||||
# user_cfg is not rendered.
|
# user_cfg is not rendered.
|
||||||
if user_cfg is None:
|
if user_cfg is None:
|
||||||
user_cfg = raw_profile.get('config')
|
user_cfg = raw_profile.get("config")
|
||||||
# TODO: should it be, and the values coerced to bool?
|
# TODO: should it be, and the values coerced to bool?
|
||||||
target_name, profile_data = cls.render_profile(
|
target_name, profile_data = cls.render_profile(
|
||||||
raw_profile, profile_name, target_override, renderer
|
raw_profile, profile_name, target_override, renderer
|
||||||
@@ -323,7 +327,7 @@ class Profile(HasCredentials):
|
|||||||
|
|
||||||
# valid connections never include the number of threads, but it's
|
# valid connections never include the number of threads, but it's
|
||||||
# stored on a per-connection level in the raw configs
|
# stored on a per-connection level in the raw configs
|
||||||
threads = profile_data.pop('threads', DEFAULT_THREADS)
|
threads = profile_data.pop("threads", DEFAULT_THREADS)
|
||||||
if threads_override is not None:
|
if threads_override is not None:
|
||||||
threads = threads_override
|
threads = threads_override
|
||||||
|
|
||||||
@@ -336,7 +340,7 @@ class Profile(HasCredentials):
|
|||||||
profile_name=profile_name,
|
profile_name=profile_name,
|
||||||
target_name=target_name,
|
target_name=target_name,
|
||||||
threads=threads,
|
threads=threads,
|
||||||
user_cfg=user_cfg
|
user_cfg=user_cfg,
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -347,7 +351,7 @@ class Profile(HasCredentials):
|
|||||||
renderer: ProfileRenderer,
|
renderer: ProfileRenderer,
|
||||||
target_override: Optional[str] = None,
|
target_override: Optional[str] = None,
|
||||||
threads_override: Optional[int] = None,
|
threads_override: Optional[int] = None,
|
||||||
) -> 'Profile':
|
) -> "Profile":
|
||||||
"""
|
"""
|
||||||
:param raw_profiles: The profile data, from disk as yaml.
|
:param raw_profiles: The profile data, from disk as yaml.
|
||||||
:param profile_name: The profile name to use.
|
:param profile_name: The profile name to use.
|
||||||
@@ -371,15 +375,9 @@ class Profile(HasCredentials):
|
|||||||
# don't render keys, so we can pluck that out
|
# don't render keys, so we can pluck that out
|
||||||
raw_profile = raw_profiles[profile_name]
|
raw_profile = raw_profiles[profile_name]
|
||||||
if not raw_profile:
|
if not raw_profile:
|
||||||
msg = (
|
msg = f"Profile {profile_name} in profiles.yml is empty"
|
||||||
f'Profile {profile_name} in profiles.yml is empty'
|
raise DbtProfileError(INVALID_PROFILE_MESSAGE.format(error_string=msg))
|
||||||
)
|
user_cfg = raw_profiles.get("config")
|
||||||
raise DbtProfileError(
|
|
||||||
INVALID_PROFILE_MESSAGE.format(
|
|
||||||
error_string=msg
|
|
||||||
)
|
|
||||||
)
|
|
||||||
user_cfg = raw_profiles.get('config')
|
|
||||||
|
|
||||||
return cls.from_raw_profile_info(
|
return cls.from_raw_profile_info(
|
||||||
raw_profile=raw_profile,
|
raw_profile=raw_profile,
|
||||||
@@ -396,7 +394,7 @@ class Profile(HasCredentials):
|
|||||||
args: Any,
|
args: Any,
|
||||||
renderer: ProfileRenderer,
|
renderer: ProfileRenderer,
|
||||||
project_profile_name: Optional[str],
|
project_profile_name: Optional[str],
|
||||||
) -> 'Profile':
|
) -> "Profile":
|
||||||
"""Given the raw profiles as read from disk and the name of the desired
|
"""Given the raw profiles as read from disk and the name of the desired
|
||||||
profile if specified, return the profile component of the runtime
|
profile if specified, return the profile component of the runtime
|
||||||
config.
|
config.
|
||||||
@@ -411,15 +409,16 @@ class Profile(HasCredentials):
|
|||||||
target could not be found.
|
target could not be found.
|
||||||
:returns Profile: The new Profile object.
|
:returns Profile: The new Profile object.
|
||||||
"""
|
"""
|
||||||
threads_override = getattr(args, 'threads', None)
|
threads_override = getattr(args, "threads", None)
|
||||||
target_override = getattr(args, 'target', None)
|
target_override = getattr(args, "target", None)
|
||||||
raw_profiles = read_profile(args.profiles_dir)
|
raw_profiles = read_profile(args.profiles_dir)
|
||||||
profile_name = cls.pick_profile_name(getattr(args, 'profile', None),
|
profile_name = cls.pick_profile_name(
|
||||||
project_profile_name)
|
getattr(args, "profile", None), project_profile_name
|
||||||
|
)
|
||||||
return cls.from_raw_profiles(
|
return cls.from_raw_profiles(
|
||||||
raw_profiles=raw_profiles,
|
raw_profiles=raw_profiles,
|
||||||
profile_name=profile_name,
|
profile_name=profile_name,
|
||||||
renderer=renderer,
|
renderer=renderer,
|
||||||
target_override=target_override,
|
target_override=target_override,
|
||||||
threads_override=threads_override
|
threads_override=threads_override,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -2,7 +2,13 @@ from copy import deepcopy
|
|||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from itertools import chain
|
from itertools import chain
|
||||||
from typing import (
|
from typing import (
|
||||||
List, Dict, Any, Optional, TypeVar, Union, Mapping,
|
List,
|
||||||
|
Dict,
|
||||||
|
Any,
|
||||||
|
Optional,
|
||||||
|
TypeVar,
|
||||||
|
Union,
|
||||||
|
Mapping,
|
||||||
)
|
)
|
||||||
from typing_extensions import Protocol, runtime_checkable
|
from typing_extensions import Protocol, runtime_checkable
|
||||||
|
|
||||||
@@ -26,15 +32,12 @@ from dbt.version import get_installed_version
|
|||||||
from dbt.utils import MultiDict
|
from dbt.utils import MultiDict
|
||||||
from dbt.node_types import NodeType
|
from dbt.node_types import NodeType
|
||||||
from dbt.config.selectors import SelectorDict
|
from dbt.config.selectors import SelectorDict
|
||||||
|
|
||||||
from dbt.contracts.project import (
|
from dbt.contracts.project import (
|
||||||
Project as ProjectContract,
|
Project as ProjectContract,
|
||||||
SemverString,
|
SemverString,
|
||||||
)
|
)
|
||||||
from dbt.contracts.project import PackageConfig
|
from dbt.contracts.project import PackageConfig
|
||||||
|
from dbt.dataclass_schema import ValidationError
|
||||||
from hologram import ValidationError
|
|
||||||
|
|
||||||
from .renderer import DbtProjectYamlRenderer
|
from .renderer import DbtProjectYamlRenderer
|
||||||
from .selectors import (
|
from .selectors import (
|
||||||
selector_config_from_data,
|
selector_config_from_data,
|
||||||
@@ -85,9 +88,7 @@ def _load_yaml(path):
|
|||||||
|
|
||||||
|
|
||||||
def package_data_from_root(project_root):
|
def package_data_from_root(project_root):
|
||||||
package_filepath = resolve_path_from_base(
|
package_filepath = resolve_path_from_base("packages.yml", project_root)
|
||||||
'packages.yml', project_root
|
|
||||||
)
|
|
||||||
|
|
||||||
if path_exists(package_filepath):
|
if path_exists(package_filepath):
|
||||||
packages_dict = _load_yaml(package_filepath)
|
packages_dict = _load_yaml(package_filepath)
|
||||||
@@ -98,9 +99,10 @@ def package_data_from_root(project_root):
|
|||||||
|
|
||||||
def package_config_from_data(packages_data: Dict[str, Any]):
|
def package_config_from_data(packages_data: Dict[str, Any]):
|
||||||
if not packages_data:
|
if not packages_data:
|
||||||
packages_data = {'packages': []}
|
packages_data = {"packages": []}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
PackageConfig.validate(packages_data)
|
||||||
packages = PackageConfig.from_dict(packages_data)
|
packages = PackageConfig.from_dict(packages_data)
|
||||||
except ValidationError as e:
|
except ValidationError as e:
|
||||||
raise DbtProjectError(
|
raise DbtProjectError(
|
||||||
@@ -120,7 +122,7 @@ def _parse_versions(versions: Union[List[str], str]) -> List[VersionSpecifier]:
|
|||||||
Regardless, this will return a list of VersionSpecifiers
|
Regardless, this will return a list of VersionSpecifiers
|
||||||
"""
|
"""
|
||||||
if isinstance(versions, str):
|
if isinstance(versions, str):
|
||||||
versions = versions.split(',')
|
versions = versions.split(",")
|
||||||
return [VersionSpecifier.from_version_string(v) for v in versions]
|
return [VersionSpecifier.from_version_string(v) for v in versions]
|
||||||
|
|
||||||
|
|
||||||
@@ -131,11 +133,12 @@ def _all_source_paths(
|
|||||||
analysis_paths: List[str],
|
analysis_paths: List[str],
|
||||||
macro_paths: List[str],
|
macro_paths: List[str],
|
||||||
) -> List[str]:
|
) -> List[str]:
|
||||||
return list(chain(source_paths, data_paths, snapshot_paths, analysis_paths,
|
return list(
|
||||||
macro_paths))
|
chain(source_paths, data_paths, snapshot_paths, analysis_paths, macro_paths)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
T = TypeVar('T')
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
def value_or(value: Optional[T], default: T) -> T:
|
def value_or(value: Optional[T], default: T) -> T:
|
||||||
@@ -148,30 +151,27 @@ def value_or(value: Optional[T], default: T) -> T:
|
|||||||
def _raw_project_from(project_root: str) -> Dict[str, Any]:
|
def _raw_project_from(project_root: str) -> Dict[str, Any]:
|
||||||
|
|
||||||
project_root = os.path.normpath(project_root)
|
project_root = os.path.normpath(project_root)
|
||||||
project_yaml_filepath = os.path.join(project_root, 'dbt_project.yml')
|
project_yaml_filepath = os.path.join(project_root, "dbt_project.yml")
|
||||||
|
|
||||||
# get the project.yml contents
|
# get the project.yml contents
|
||||||
if not path_exists(project_yaml_filepath):
|
if not path_exists(project_yaml_filepath):
|
||||||
raise DbtProjectError(
|
raise DbtProjectError(
|
||||||
'no dbt_project.yml found at expected path {}'
|
"no dbt_project.yml found at expected path {}".format(project_yaml_filepath)
|
||||||
.format(project_yaml_filepath)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
project_dict = _load_yaml(project_yaml_filepath)
|
project_dict = _load_yaml(project_yaml_filepath)
|
||||||
|
|
||||||
if not isinstance(project_dict, dict):
|
if not isinstance(project_dict, dict):
|
||||||
raise DbtProjectError(
|
raise DbtProjectError("dbt_project.yml does not parse to a dictionary")
|
||||||
'dbt_project.yml does not parse to a dictionary'
|
|
||||||
)
|
|
||||||
|
|
||||||
return project_dict
|
return project_dict
|
||||||
|
|
||||||
|
|
||||||
def _query_comment_from_cfg(
|
def _query_comment_from_cfg(
|
||||||
cfg_query_comment: Union[QueryComment, NoValue, str, None]
|
cfg_query_comment: Union[QueryComment, NoValue, str, None]
|
||||||
) -> QueryComment:
|
) -> QueryComment:
|
||||||
if not cfg_query_comment:
|
if not cfg_query_comment:
|
||||||
return QueryComment(comment='')
|
return QueryComment(comment="")
|
||||||
|
|
||||||
if isinstance(cfg_query_comment, str):
|
if isinstance(cfg_query_comment, str):
|
||||||
return QueryComment(comment=cfg_query_comment)
|
return QueryComment(comment=cfg_query_comment)
|
||||||
@@ -188,9 +188,7 @@ def validate_version(dbt_version: List[VersionSpecifier], project_name: str):
|
|||||||
if not versions_compatible(*dbt_version):
|
if not versions_compatible(*dbt_version):
|
||||||
msg = IMPOSSIBLE_VERSION_ERROR.format(
|
msg = IMPOSSIBLE_VERSION_ERROR.format(
|
||||||
package=project_name,
|
package=project_name,
|
||||||
version_spec=[
|
version_spec=[x.to_version_string() for x in dbt_version],
|
||||||
x.to_version_string() for x in dbt_version
|
|
||||||
]
|
|
||||||
)
|
)
|
||||||
raise DbtProjectError(msg)
|
raise DbtProjectError(msg)
|
||||||
|
|
||||||
@@ -198,9 +196,7 @@ def validate_version(dbt_version: List[VersionSpecifier], project_name: str):
|
|||||||
msg = INVALID_VERSION_ERROR.format(
|
msg = INVALID_VERSION_ERROR.format(
|
||||||
package=project_name,
|
package=project_name,
|
||||||
installed=installed.to_version_string(),
|
installed=installed.to_version_string(),
|
||||||
version_spec=[
|
version_spec=[x.to_version_string() for x in dbt_version],
|
||||||
x.to_version_string() for x in dbt_version
|
|
||||||
]
|
|
||||||
)
|
)
|
||||||
raise DbtProjectError(msg)
|
raise DbtProjectError(msg)
|
||||||
|
|
||||||
@@ -209,8 +205,8 @@ def _get_required_version(
|
|||||||
project_dict: Dict[str, Any],
|
project_dict: Dict[str, Any],
|
||||||
verify_version: bool,
|
verify_version: bool,
|
||||||
) -> List[VersionSpecifier]:
|
) -> List[VersionSpecifier]:
|
||||||
dbt_raw_version: Union[List[str], str] = '>=0.0.0'
|
dbt_raw_version: Union[List[str], str] = ">=0.0.0"
|
||||||
required = project_dict.get('require-dbt-version')
|
required = project_dict.get("require-dbt-version")
|
||||||
if required is not None:
|
if required is not None:
|
||||||
dbt_raw_version = required
|
dbt_raw_version = required
|
||||||
|
|
||||||
@@ -221,11 +217,11 @@ def _get_required_version(
|
|||||||
|
|
||||||
if verify_version:
|
if verify_version:
|
||||||
# no name is also an error that we want to raise
|
# no name is also an error that we want to raise
|
||||||
if 'name' not in project_dict:
|
if "name" not in project_dict:
|
||||||
raise DbtProjectError(
|
raise DbtProjectError(
|
||||||
'Required "name" field not present in project',
|
'Required "name" field not present in project',
|
||||||
)
|
)
|
||||||
validate_version(dbt_version, project_dict['name'])
|
validate_version(dbt_version, project_dict["name"])
|
||||||
|
|
||||||
return dbt_version
|
return dbt_version
|
||||||
|
|
||||||
@@ -233,34 +229,36 @@ def _get_required_version(
|
|||||||
@dataclass
|
@dataclass
|
||||||
class RenderComponents:
|
class RenderComponents:
|
||||||
project_dict: Dict[str, Any] = field(
|
project_dict: Dict[str, Any] = field(
|
||||||
metadata=dict(description='The project dictionary')
|
metadata=dict(description="The project dictionary")
|
||||||
)
|
)
|
||||||
packages_dict: Dict[str, Any] = field(
|
packages_dict: Dict[str, Any] = field(
|
||||||
metadata=dict(description='The packages dictionary')
|
metadata=dict(description="The packages dictionary")
|
||||||
)
|
)
|
||||||
selectors_dict: Dict[str, Any] = field(
|
selectors_dict: Dict[str, Any] = field(
|
||||||
metadata=dict(description='The selectors dictionary')
|
metadata=dict(description="The selectors dictionary")
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class PartialProject(RenderComponents):
|
class PartialProject(RenderComponents):
|
||||||
profile_name: Optional[str] = field(metadata=dict(
|
profile_name: Optional[str] = field(
|
||||||
description='The unrendered profile name in the project, if set'
|
metadata=dict(description="The unrendered profile name in the project, if set")
|
||||||
))
|
)
|
||||||
project_name: Optional[str] = field(metadata=dict(
|
project_name: Optional[str] = field(
|
||||||
description=(
|
metadata=dict(
|
||||||
'The name of the project. This should always be set and will not '
|
description=(
|
||||||
'be rendered'
|
"The name of the project. This should always be set and will not "
|
||||||
|
"be rendered"
|
||||||
|
)
|
||||||
)
|
)
|
||||||
))
|
)
|
||||||
project_root: str = field(
|
project_root: str = field(
|
||||||
metadata=dict(description='The root directory of the project'),
|
metadata=dict(description="The root directory of the project"),
|
||||||
)
|
)
|
||||||
verify_version: bool = field(
|
verify_version: bool = field(
|
||||||
metadata=dict(description=(
|
metadata=dict(
|
||||||
'If True, verify the dbt version matches the required version'
|
description=("If True, verify the dbt version matches the required version")
|
||||||
))
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
def render_profile_name(self, renderer) -> Optional[str]:
|
def render_profile_name(self, renderer) -> Optional[str]:
|
||||||
@@ -273,9 +271,7 @@ class PartialProject(RenderComponents):
|
|||||||
renderer: DbtProjectYamlRenderer,
|
renderer: DbtProjectYamlRenderer,
|
||||||
) -> RenderComponents:
|
) -> RenderComponents:
|
||||||
|
|
||||||
rendered_project = renderer.render_project(
|
rendered_project = renderer.render_project(self.project_dict, self.project_root)
|
||||||
self.project_dict, self.project_root
|
|
||||||
)
|
|
||||||
rendered_packages = renderer.render_packages(self.packages_dict)
|
rendered_packages = renderer.render_packages(self.packages_dict)
|
||||||
rendered_selectors = renderer.render_selectors(self.selectors_dict)
|
rendered_selectors = renderer.render_selectors(self.selectors_dict)
|
||||||
|
|
||||||
@@ -285,16 +281,16 @@ class PartialProject(RenderComponents):
|
|||||||
selectors_dict=rendered_selectors,
|
selectors_dict=rendered_selectors,
|
||||||
)
|
)
|
||||||
|
|
||||||
def render(self, renderer: DbtProjectYamlRenderer) -> 'Project':
|
def render(self, renderer: DbtProjectYamlRenderer) -> "Project":
|
||||||
try:
|
try:
|
||||||
rendered = self.get_rendered(renderer)
|
rendered = self.get_rendered(renderer)
|
||||||
return self.create_project(rendered)
|
return self.create_project(rendered)
|
||||||
except DbtProjectError as exc:
|
except DbtProjectError as exc:
|
||||||
if exc.path is None:
|
if exc.path is None:
|
||||||
exc.path = os.path.join(self.project_root, 'dbt_project.yml')
|
exc.path = os.path.join(self.project_root, "dbt_project.yml")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
def create_project(self, rendered: RenderComponents) -> 'Project':
|
def create_project(self, rendered: RenderComponents) -> "Project":
|
||||||
unrendered = RenderComponents(
|
unrendered = RenderComponents(
|
||||||
project_dict=self.project_dict,
|
project_dict=self.project_dict,
|
||||||
packages_dict=self.packages_dict,
|
packages_dict=self.packages_dict,
|
||||||
@@ -306,6 +302,7 @@ class PartialProject(RenderComponents):
|
|||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
ProjectContract.validate(rendered.project_dict)
|
||||||
cfg = ProjectContract.from_dict(rendered.project_dict)
|
cfg = ProjectContract.from_dict(rendered.project_dict)
|
||||||
except ValidationError as e:
|
except ValidationError as e:
|
||||||
raise DbtProjectError(validator_error_message(e)) from e
|
raise DbtProjectError(validator_error_message(e)) from e
|
||||||
@@ -316,37 +313,36 @@ class PartialProject(RenderComponents):
|
|||||||
# this is added at project_dict parse time and should always be here
|
# this is added at project_dict parse time and should always be here
|
||||||
# once we see it.
|
# once we see it.
|
||||||
if cfg.project_root is None:
|
if cfg.project_root is None:
|
||||||
raise DbtProjectError('cfg must have a project root!')
|
raise DbtProjectError("cfg must have a project root!")
|
||||||
else:
|
else:
|
||||||
project_root = cfg.project_root
|
project_root = cfg.project_root
|
||||||
# this is only optional in the sense that if it's not present, it needs
|
# this is only optional in the sense that if it's not present, it needs
|
||||||
# to have been a cli argument.
|
# to have been a cli argument.
|
||||||
profile_name = cfg.profile
|
profile_name = cfg.profile
|
||||||
# these are all the defaults
|
# these are all the defaults
|
||||||
source_paths: List[str] = value_or(cfg.source_paths, ['models'])
|
source_paths: List[str] = value_or(cfg.source_paths, ["models"])
|
||||||
macro_paths: List[str] = value_or(cfg.macro_paths, ['macros'])
|
macro_paths: List[str] = value_or(cfg.macro_paths, ["macros"])
|
||||||
data_paths: List[str] = value_or(cfg.data_paths, ['data'])
|
data_paths: List[str] = value_or(cfg.data_paths, ["data"])
|
||||||
test_paths: List[str] = value_or(cfg.test_paths, ['test'])
|
test_paths: List[str] = value_or(cfg.test_paths, ["test"])
|
||||||
analysis_paths: List[str] = value_or(cfg.analysis_paths, [])
|
analysis_paths: List[str] = value_or(cfg.analysis_paths, [])
|
||||||
snapshot_paths: List[str] = value_or(cfg.snapshot_paths, ['snapshots'])
|
snapshot_paths: List[str] = value_or(cfg.snapshot_paths, ["snapshots"])
|
||||||
|
|
||||||
all_source_paths: List[str] = _all_source_paths(
|
all_source_paths: List[str] = _all_source_paths(
|
||||||
source_paths, data_paths, snapshot_paths, analysis_paths,
|
source_paths, data_paths, snapshot_paths, analysis_paths, macro_paths
|
||||||
macro_paths
|
|
||||||
)
|
)
|
||||||
|
|
||||||
docs_paths: List[str] = value_or(cfg.docs_paths, all_source_paths)
|
docs_paths: List[str] = value_or(cfg.docs_paths, all_source_paths)
|
||||||
asset_paths: List[str] = value_or(cfg.asset_paths, [])
|
asset_paths: List[str] = value_or(cfg.asset_paths, [])
|
||||||
target_path: str = value_or(cfg.target_path, 'target')
|
target_path: str = value_or(cfg.target_path, "target")
|
||||||
clean_targets: List[str] = value_or(cfg.clean_targets, [target_path])
|
clean_targets: List[str] = value_or(cfg.clean_targets, [target_path])
|
||||||
log_path: str = value_or(cfg.log_path, 'logs')
|
log_path: str = value_or(cfg.log_path, "logs")
|
||||||
modules_path: str = value_or(cfg.modules_path, 'dbt_modules')
|
modules_path: str = value_or(cfg.modules_path, "dbt_modules")
|
||||||
# in the default case we'll populate this once we know the adapter type
|
# in the default case we'll populate this once we know the adapter type
|
||||||
# It would be nice to just pass along a Quoting here, but that would
|
# It would be nice to just pass along a Quoting here, but that would
|
||||||
# break many things
|
# break many things
|
||||||
quoting: Dict[str, Any] = {}
|
quoting: Dict[str, Any] = {}
|
||||||
if cfg.quoting is not None:
|
if cfg.quoting is not None:
|
||||||
quoting = cfg.quoting.to_dict()
|
quoting = cfg.quoting.to_dict(omit_none=True)
|
||||||
|
|
||||||
models: Dict[str, Any]
|
models: Dict[str, Any]
|
||||||
seeds: Dict[str, Any]
|
seeds: Dict[str, Any]
|
||||||
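Note: besides the reflowing, the hunk above is a useful summary of the path defaults a project gets when `dbt_project.yml` leaves them unset (`models`, `macros`, `data`, `test`, `snapshots`, `target`, `logs`, `dbt_modules`), with `docs_paths` defaulting to the concatenation of every source-bearing path list. A small sketch of that defaulting, with a plain dict standing in for the parsed project contract:

```python
from itertools import chain
from typing import List, Optional, TypeVar

T = TypeVar("T")


def value_or(value: Optional[T], default: T) -> T:
    # mirror of the helper used above: keep explicit values, fall back otherwise
    return default if value is None else value


def resolve_paths(cfg: dict) -> dict:
    source_paths: List[str] = value_or(cfg.get("source-paths"), ["models"])
    macro_paths: List[str] = value_or(cfg.get("macro-paths"), ["macros"])
    data_paths: List[str] = value_or(cfg.get("data-paths"), ["data"])
    snapshot_paths: List[str] = value_or(cfg.get("snapshot-paths"), ["snapshots"])
    analysis_paths: List[str] = value_or(cfg.get("analysis-paths"), [])
    # everything that can hold documented resources, concatenated in order
    all_source_paths = list(
        chain(source_paths, data_paths, snapshot_paths, analysis_paths, macro_paths)
    )
    return {
        "all_source_paths": all_source_paths,
        "docs_paths": value_or(cfg.get("docs-paths"), all_source_paths),
        "target_path": value_or(cfg.get("target-path"), "target"),
    }


print(resolve_paths({}))  # a bare project falls back to the documented defaults
```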
@@ -372,11 +368,12 @@ class PartialProject(RenderComponents):
|
|||||||
packages = package_config_from_data(rendered.packages_dict)
|
packages = package_config_from_data(rendered.packages_dict)
|
||||||
selectors = selector_config_from_data(rendered.selectors_dict)
|
selectors = selector_config_from_data(rendered.selectors_dict)
|
||||||
manifest_selectors: Dict[str, Any] = {}
|
manifest_selectors: Dict[str, Any] = {}
|
||||||
if rendered.selectors_dict and rendered.selectors_dict['selectors']:
|
if rendered.selectors_dict and rendered.selectors_dict["selectors"]:
|
||||||
# this is a dict with a single key 'selectors' pointing to a list
|
# this is a dict with a single key 'selectors' pointing to a list
|
||||||
# of dicts.
|
# of dicts.
|
||||||
manifest_selectors = SelectorDict.parse_from_selectors_list(
|
manifest_selectors = SelectorDict.parse_from_selectors_list(
|
||||||
rendered.selectors_dict['selectors'])
|
rendered.selectors_dict["selectors"]
|
||||||
|
)
|
||||||
|
|
||||||
project = Project(
|
project = Project(
|
||||||
project_name=name,
|
project_name=name,
|
||||||
@@ -425,10 +422,9 @@ class PartialProject(RenderComponents):
|
|||||||
*,
|
*,
|
||||||
verify_version: bool = False,
|
verify_version: bool = False,
|
||||||
):
|
):
|
||||||
"""Construct a partial project from its constituent dicts.
|
"""Construct a partial project from its constituent dicts."""
|
||||||
"""
|
project_name = project_dict.get("name")
|
||||||
project_name = project_dict.get('name')
|
profile_name = project_dict.get("profile")
|
||||||
profile_name = project_dict.get('profile')
|
|
||||||
|
|
||||||
return cls(
|
return cls(
|
||||||
profile_name=profile_name,
|
profile_name=profile_name,
|
||||||
@@ -443,14 +439,14 @@ class PartialProject(RenderComponents):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def from_project_root(
|
def from_project_root(
|
||||||
cls, project_root: str, *, verify_version: bool = False
|
cls, project_root: str, *, verify_version: bool = False
|
||||||
) -> 'PartialProject':
|
) -> "PartialProject":
|
||||||
project_root = os.path.normpath(project_root)
|
project_root = os.path.normpath(project_root)
|
||||||
project_dict = _raw_project_from(project_root)
|
project_dict = _raw_project_from(project_root)
|
||||||
config_version = project_dict.get('config-version', 1)
|
config_version = project_dict.get("config-version", 1)
|
||||||
if config_version != 2:
|
if config_version != 2:
|
||||||
raise DbtProjectError(
|
raise DbtProjectError(
|
||||||
f'Invalid config version: {config_version}, expected 2',
|
f"Invalid config version: {config_version}, expected 2",
|
||||||
path=os.path.join(project_root, 'dbt_project.yml')
|
path=os.path.join(project_root, "dbt_project.yml"),
|
||||||
)
|
)
|
||||||
|
|
||||||
packages_dict = package_data_from_root(project_root)
|
packages_dict = package_data_from_root(project_root)
|
||||||
@@ -467,15 +463,10 @@ class PartialProject(RenderComponents):
|
|||||||
class VarProvider:
|
class VarProvider:
|
||||||
"""Var providers are tied to a particular Project."""
|
"""Var providers are tied to a particular Project."""
|
||||||
|
|
||||||
def __init__(
|
def __init__(self, vars: Dict[str, Dict[str, Any]]) -> None:
|
||||||
self,
|
|
||||||
vars: Dict[str, Dict[str, Any]]
|
|
||||||
) -> None:
|
|
||||||
self.vars = vars
|
self.vars = vars
|
||||||
|
|
||||||
def vars_for(
|
def vars_for(self, node: IsFQNResource, adapter_type: str) -> Mapping[str, Any]:
|
||||||
self, node: IsFQNResource, adapter_type: str
|
|
||||||
) -> Mapping[str, Any]:
|
|
||||||
# in v2, vars are only either project or globally scoped
|
# in v2, vars are only either project or globally scoped
|
||||||
merged = MultiDict([self.vars])
|
merged = MultiDict([self.vars])
|
||||||
merged.add(self.vars.get(node.package_name, {}))
|
merged.add(self.vars.get(node.package_name, {}))
|
||||||
@@ -524,8 +515,11 @@ class Project:
|
|||||||
@property
|
@property
|
||||||
def all_source_paths(self) -> List[str]:
|
def all_source_paths(self) -> List[str]:
|
||||||
return _all_source_paths(
|
return _all_source_paths(
|
||||||
self.source_paths, self.data_paths, self.snapshot_paths,
|
self.source_paths,
|
||||||
self.analysis_paths, self.macro_paths
|
self.data_paths,
|
||||||
|
self.snapshot_paths,
|
||||||
|
self.analysis_paths,
|
||||||
|
self.macro_paths,
|
||||||
)
|
)
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
@@ -533,11 +527,13 @@ class Project:
|
|||||||
return str(cfg)
|
return str(cfg)
|
||||||
|
|
||||||
def __eq__(self, other):
|
def __eq__(self, other):
|
||||||
if not (isinstance(other, self.__class__) and
|
if not (
|
||||||
isinstance(self, other.__class__)):
|
isinstance(other, self.__class__) and isinstance(self, other.__class__)
|
||||||
|
):
|
||||||
return False
|
return False
|
||||||
return self.to_project_config(with_packages=True) == \
|
return self.to_project_config(with_packages=True) == other.to_project_config(
|
||||||
other.to_project_config(with_packages=True)
|
with_packages=True
|
||||||
|
)
|
||||||
|
|
||||||
def to_project_config(self, with_packages=False):
|
def to_project_config(self, with_packages=False):
|
||||||
"""Return a dict representation of the config that could be written to
|
"""Return a dict representation of the config that could be written to
|
||||||
@@ -547,46 +543,48 @@ class Project:
|
|||||||
file in the root.
|
file in the root.
|
||||||
:returns dict: The serialized profile.
|
:returns dict: The serialized profile.
|
||||||
"""
|
"""
|
||||||
result = deepcopy({
|
result = deepcopy(
|
||||||
'name': self.project_name,
|
{
|
||||||
'version': self.version,
|
"name": self.project_name,
|
||||||
'project-root': self.project_root,
|
"version": self.version,
|
||||||
'profile': self.profile_name,
|
"project-root": self.project_root,
|
||||||
'source-paths': self.source_paths,
|
"profile": self.profile_name,
|
||||||
'macro-paths': self.macro_paths,
|
"source-paths": self.source_paths,
|
||||||
'data-paths': self.data_paths,
|
"macro-paths": self.macro_paths,
|
||||||
'test-paths': self.test_paths,
|
"data-paths": self.data_paths,
|
||||||
'analysis-paths': self.analysis_paths,
|
"test-paths": self.test_paths,
|
||||||
'docs-paths': self.docs_paths,
|
"analysis-paths": self.analysis_paths,
|
||||||
'asset-paths': self.asset_paths,
|
"docs-paths": self.docs_paths,
|
||||||
'target-path': self.target_path,
|
"asset-paths": self.asset_paths,
|
||||||
'snapshot-paths': self.snapshot_paths,
|
"target-path": self.target_path,
|
||||||
'clean-targets': self.clean_targets,
|
"snapshot-paths": self.snapshot_paths,
|
||||||
'log-path': self.log_path,
|
"clean-targets": self.clean_targets,
|
||||||
'quoting': self.quoting,
|
"log-path": self.log_path,
|
||||||
'models': self.models,
|
"quoting": self.quoting,
|
||||||
'on-run-start': self.on_run_start,
|
"models": self.models,
|
||||||
'on-run-end': self.on_run_end,
|
"on-run-start": self.on_run_start,
|
||||||
'seeds': self.seeds,
|
"on-run-end": self.on_run_end,
|
||||||
'snapshots': self.snapshots,
|
"seeds": self.seeds,
|
||||||
'sources': self.sources,
|
"snapshots": self.snapshots,
|
||||||
'vars': self.vars.to_dict(),
|
"sources": self.sources,
|
||||||
'require-dbt-version': [
|
"vars": self.vars.to_dict(),
|
||||||
v.to_version_string() for v in self.dbt_version
|
"require-dbt-version": [
|
||||||
],
|
v.to_version_string() for v in self.dbt_version
|
||||||
'config-version': self.config_version,
|
],
|
||||||
})
|
"config-version": self.config_version,
|
||||||
|
}
|
||||||
|
)
|
||||||
if self.query_comment:
|
if self.query_comment:
|
||||||
result['query-comment'] = self.query_comment.to_dict()
|
result["query-comment"] = self.query_comment.to_dict(omit_none=True)
|
||||||
|
|
||||||
if with_packages:
|
if with_packages:
|
||||||
result.update(self.packages.to_dict())
|
result.update(self.packages.to_dict(omit_none=True))
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def validate(self):
|
def validate(self):
|
||||||
try:
|
try:
|
||||||
ProjectContract.from_dict(self.to_project_config())
|
ProjectContract.validate(self.to_project_config())
|
||||||
except ValidationError as e:
|
except ValidationError as e:
|
||||||
raise DbtProjectError(validator_error_message(e)) from e
|
raise DbtProjectError(validator_error_message(e)) from e
|
||||||
|
|
||||||
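Note: two recurring changes meet in the hunk above: serialization call sites now pass `omit_none=True` so unset optional fields are dropped from the emitted dict, and validation is an explicit `validate(...)` call on the plain dict rather than a side effect of `from_dict`. A small self-contained sketch of the `omit_none` behaviour, using a hypothetical dataclass in place of dbt's real contracts:

```python
# Sketch of why call sites above pass omit_none=True: the plain to_dict() keeps
# None-valued fields unless asked to drop them. ExampleQueryComment is hypothetical.
from dataclasses import asdict, dataclass
from typing import Any, Dict, Optional


@dataclass
class ExampleQueryComment:
    comment: str = ""
    job_label: Optional[str] = None

    def to_dict(self, omit_none: bool = False) -> Dict[str, Any]:
        data = asdict(self)
        if omit_none:
            data = {k: v for k, v in data.items() if v is not None}
        return data


qc = ExampleQueryComment(comment="run by dbt")
print(qc.to_dict())                # {'comment': 'run by dbt', 'job_label': None}
print(qc.to_dict(omit_none=True))  # {'comment': 'run by dbt'}
```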
@@ -608,8 +606,8 @@ class Project:
|
|||||||
selectors_dict: Dict[str, Any],
|
selectors_dict: Dict[str, Any],
|
||||||
renderer: DbtProjectYamlRenderer,
|
renderer: DbtProjectYamlRenderer,
|
||||||
*,
|
*,
|
||||||
verify_version: bool = False
|
verify_version: bool = False,
|
||||||
) -> 'Project':
|
) -> "Project":
|
||||||
partial = PartialProject.from_dicts(
|
partial = PartialProject.from_dicts(
|
||||||
project_root=project_root,
|
project_root=project_root,
|
||||||
project_dict=project_dict,
|
project_dict=project_dict,
|
||||||
@@ -626,17 +624,17 @@ class Project:
|
|||||||
renderer: DbtProjectYamlRenderer,
|
renderer: DbtProjectYamlRenderer,
|
||||||
*,
|
*,
|
||||||
verify_version: bool = False,
|
verify_version: bool = False,
|
||||||
) -> 'Project':
|
) -> "Project":
|
||||||
partial = cls.partial_load(project_root, verify_version=verify_version)
|
partial = cls.partial_load(project_root, verify_version=verify_version)
|
||||||
return partial.render(renderer)
|
return partial.render(renderer)
|
||||||
|
|
||||||
def hashed_name(self):
|
def hashed_name(self):
|
||||||
return hashlib.md5(self.project_name.encode('utf-8')).hexdigest()
|
return hashlib.md5(self.project_name.encode("utf-8")).hexdigest()
|
||||||
|
|
||||||
def get_selector(self, name: str) -> SelectionSpec:
|
def get_selector(self, name: str) -> SelectionSpec:
|
||||||
if name not in self.selectors:
|
if name not in self.selectors:
|
||||||
raise RuntimeException(
|
raise RuntimeException(
|
||||||
f'Could not find selector named {name}, expected one of '
|
f"Could not find selector named {name}, expected one of "
|
||||||
f'{list(self.selectors)}'
|
f"{list(self.selectors)}"
|
||||||
)
|
)
|
||||||
return self.selectors[name]
|
return self.selectors[name]
|
||||||
|
|||||||
@@ -2,9 +2,7 @@ from typing import Dict, Any, Tuple, Optional, Union, Callable
|
|||||||
|
|
||||||
from dbt.clients.jinja import get_rendered, catch_jinja
|
from dbt.clients.jinja import get_rendered, catch_jinja
|
||||||
|
|
||||||
from dbt.exceptions import (
|
from dbt.exceptions import DbtProjectError, CompilationException, RecursionException
|
||||||
DbtProjectError, CompilationException, RecursionException
|
|
||||||
)
|
|
||||||
from dbt.node_types import NodeType
|
from dbt.node_types import NodeType
|
||||||
from dbt.utils import deep_map
|
from dbt.utils import deep_map
|
||||||
|
|
||||||
@@ -18,7 +16,7 @@ class BaseRenderer:
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def name(self):
|
def name(self):
|
||||||
return 'Rendering'
|
return "Rendering"
|
||||||
|
|
||||||
def should_render_keypath(self, keypath: Keypath) -> bool:
|
def should_render_keypath(self, keypath: Keypath) -> bool:
|
||||||
return True
|
return True
|
||||||
@@ -29,9 +27,7 @@ class BaseRenderer:
|
|||||||
|
|
||||||
return self.render_value(value, keypath)
|
return self.render_value(value, keypath)
|
||||||
|
|
||||||
def render_value(
|
def render_value(self, value: Any, keypath: Optional[Keypath] = None) -> Any:
|
||||||
self, value: Any, keypath: Optional[Keypath] = None
|
|
||||||
) -> Any:
|
|
||||||
# keypath is ignored.
|
# keypath is ignored.
|
||||||
# if it wasn't read as a string, ignore it
|
# if it wasn't read as a string, ignore it
|
||||||
if not isinstance(value, str):
|
if not isinstance(value, str):
|
||||||
@@ -40,18 +36,16 @@ class BaseRenderer:
|
|||||||
with catch_jinja():
|
with catch_jinja():
|
||||||
return get_rendered(value, self.context, native=True)
|
return get_rendered(value, self.context, native=True)
|
||||||
except CompilationException as exc:
|
except CompilationException as exc:
|
||||||
msg = f'Could not render {value}: {exc.msg}'
|
msg = f"Could not render {value}: {exc.msg}"
|
||||||
raise CompilationException(msg) from exc
|
raise CompilationException(msg) from exc
|
||||||
|
|
||||||
def render_data(
|
def render_data(self, data: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
self, data: Dict[str, Any]
|
|
||||||
) -> Dict[str, Any]:
|
|
||||||
try:
|
try:
|
||||||
return deep_map(self.render_entry, data)
|
return deep_map(self.render_entry, data)
|
||||||
except RecursionException:
|
except RecursionException:
|
||||||
raise DbtProjectError(
|
raise DbtProjectError(
|
||||||
f'Cycle detected: {self.name} input has a reference to itself',
|
f"Cycle detected: {self.name} input has a reference to itself",
|
||||||
project=data
|
project=data,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -78,15 +72,15 @@ class ProjectPostprocessor(Dict[Keypath, Callable[[Any], Any]]):
|
|||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self[('on-run-start',)] = _list_if_none_or_string
|
self[("on-run-start",)] = _list_if_none_or_string
|
||||||
self[('on-run-end',)] = _list_if_none_or_string
|
self[("on-run-end",)] = _list_if_none_or_string
|
||||||
|
|
||||||
for k in ('models', 'seeds', 'snapshots'):
|
for k in ("models", "seeds", "snapshots"):
|
||||||
self[(k,)] = _dict_if_none
|
self[(k,)] = _dict_if_none
|
||||||
self[(k, 'vars')] = _dict_if_none
|
self[(k, "vars")] = _dict_if_none
|
||||||
self[(k, 'pre-hook')] = _list_if_none_or_string
|
self[(k, "pre-hook")] = _list_if_none_or_string
|
||||||
self[(k, 'post-hook')] = _list_if_none_or_string
|
self[(k, "post-hook")] = _list_if_none_or_string
|
||||||
self[('seeds', 'column_types')] = _dict_if_none
|
self[("seeds", "column_types")] = _dict_if_none
|
||||||
|
|
||||||
def postprocess(self, value: Any, key: Keypath) -> Any:
|
def postprocess(self, value: Any, key: Keypath) -> Any:
|
||||||
if key in self:
|
if key in self:
|
||||||
@@ -101,7 +95,7 @@ class DbtProjectYamlRenderer(BaseRenderer):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def name(self):
|
def name(self):
|
||||||
'Project config'
|
"Project config"
|
||||||
|
|
||||||
def get_package_renderer(self) -> BaseRenderer:
|
def get_package_renderer(self) -> BaseRenderer:
|
||||||
return PackageRenderer(self.context)
|
return PackageRenderer(self.context)
|
||||||
@@ -116,7 +110,7 @@ class DbtProjectYamlRenderer(BaseRenderer):
|
|||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""Render the project and insert the project root after rendering."""
|
"""Render the project and insert the project root after rendering."""
|
||||||
rendered_project = self.render_data(project)
|
rendered_project = self.render_data(project)
|
||||||
rendered_project['project-root'] = project_root
|
rendered_project["project-root"] = project_root
|
||||||
return rendered_project
|
return rendered_project
|
||||||
|
|
||||||
def render_packages(self, packages: Dict[str, Any]):
|
def render_packages(self, packages: Dict[str, Any]):
|
||||||
@@ -138,20 +132,19 @@ class DbtProjectYamlRenderer(BaseRenderer):
|
|||||||
|
|
||||||
first = keypath[0]
|
first = keypath[0]
|
||||||
# run hooks are not rendered
|
# run hooks are not rendered
|
||||||
if first in {'on-run-start', 'on-run-end', 'query-comment'}:
|
if first in {"on-run-start", "on-run-end", "query-comment"}:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# don't render vars blocks until runtime
|
# don't render vars blocks until runtime
|
||||||
if first == 'vars':
|
if first == "vars":
|
||||||
return False
|
return False
|
||||||
|
|
||||||
if first in {'seeds', 'models', 'snapshots', 'seeds'}:
|
if first in {"seeds", "models", "snapshots", "seeds"}:
|
||||||
keypath_parts = {
|
keypath_parts = {
|
||||||
(k.lstrip('+') if isinstance(k, str) else k)
|
(k.lstrip("+") if isinstance(k, str) else k) for k in keypath
|
||||||
for k in keypath
|
|
||||||
}
|
}
|
||||||
# model-level hooks
|
# model-level hooks
|
||||||
if 'pre-hook' in keypath_parts or 'post-hook' in keypath_parts:
|
if "pre-hook" in keypath_parts or "post-hook" in keypath_parts:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
return True
|
return True
|
||||||
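Note: the hunk above only re-quotes and reflows `should_render_keypath`, but the rules it encodes are easy to miss in the diff noise: run hooks, `query-comment`, and `vars` blocks are left unrendered at parse time, and `pre-hook`/`post-hook` entries under `models`/`seeds`/`snapshots` are skipped after stripping a leading `+` from config keys. A condensed sketch of the same decision logic over a keypath tuple:

```python
from typing import Tuple, Union

Keypath = Tuple[Union[str, int], ...]


def should_render(keypath: Keypath) -> bool:
    # sketch of the parse-time rendering filter shown above
    if not keypath:
        return True
    first = keypath[0]
    if first in {"on-run-start", "on-run-end", "query-comment"}:
        return False  # run hooks and query comments render at runtime
    if first == "vars":
        return False  # vars blocks are also deferred
    if first in {"seeds", "models", "snapshots"}:
        parts = {k.lstrip("+") if isinstance(k, str) else k for k in keypath}
        if "pre-hook" in parts or "post-hook" in parts:
            return False  # model-level hooks stay unrendered too
    return True


print(should_render(("models", "my_project", "+pre-hook")))      # False
print(should_render(("models", "my_project", "materialized")))   # True
```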
@@ -160,17 +153,15 @@ class DbtProjectYamlRenderer(BaseRenderer):
|
|||||||
class ProfileRenderer(BaseRenderer):
|
class ProfileRenderer(BaseRenderer):
|
||||||
@property
|
@property
|
||||||
def name(self):
|
def name(self):
|
||||||
'Profile'
|
"Profile"
|
||||||
|
|
||||||
|
|
||||||
class SchemaYamlRenderer(BaseRenderer):
|
class SchemaYamlRenderer(BaseRenderer):
|
||||||
DOCUMENTABLE_NODES = frozenset(
|
DOCUMENTABLE_NODES = frozenset(n.pluralize() for n in NodeType.documentable())
|
||||||
n.pluralize() for n in NodeType.documentable()
|
|
||||||
)
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def name(self):
|
def name(self):
|
||||||
return 'Rendering yaml'
|
return "Rendering yaml"
|
||||||
|
|
||||||
def _is_norender_key(self, keypath: Keypath) -> bool:
|
def _is_norender_key(self, keypath: Keypath) -> bool:
|
||||||
"""
|
"""
|
||||||
@@ -185,13 +176,13 @@ class SchemaYamlRenderer(BaseRenderer):
|
|||||||
|
|
||||||
Return True if it's tests or description - those aren't rendered
|
Return True if it's tests or description - those aren't rendered
|
||||||
"""
|
"""
|
||||||
if len(keypath) >= 2 and keypath[1] in ('tests', 'description'):
|
if len(keypath) >= 2 and keypath[1] in ("tests", "description"):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
if (
|
if (
|
||||||
len(keypath) >= 4 and
|
len(keypath) >= 4
|
||||||
keypath[1] == 'columns' and
|
and keypath[1] == "columns"
|
||||||
keypath[3] in ('tests', 'description')
|
and keypath[3] in ("tests", "description")
|
||||||
):
|
):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
@@ -209,13 +200,13 @@ class SchemaYamlRenderer(BaseRenderer):
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
if keypath[0] == NodeType.Source.pluralize():
|
if keypath[0] == NodeType.Source.pluralize():
|
||||||
if keypath[2] == 'description':
|
if keypath[2] == "description":
|
||||||
return False
|
return False
|
||||||
if keypath[2] == 'tables':
|
if keypath[2] == "tables":
|
||||||
if self._is_norender_key(keypath[3:]):
|
if self._is_norender_key(keypath[3:]):
|
||||||
return False
|
return False
|
||||||
elif keypath[0] == NodeType.Macro.pluralize():
|
elif keypath[0] == NodeType.Macro.pluralize():
|
||||||
if keypath[2] == 'arguments':
|
if keypath[2] == "arguments":
|
||||||
if self._is_norender_key(keypath[3:]):
|
if self._is_norender_key(keypath[3:]):
|
||||||
return False
|
return False
|
||||||
elif self._is_norender_key(keypath[1:]):
|
elif self._is_norender_key(keypath[1:]):
|
||||||
@@ -229,10 +220,10 @@ class SchemaYamlRenderer(BaseRenderer):
|
|||||||
class PackageRenderer(BaseRenderer):
|
class PackageRenderer(BaseRenderer):
|
||||||
@property
|
@property
|
||||||
def name(self):
|
def name(self):
|
||||||
return 'Packages config'
|
return "Packages config"
|
||||||
|
|
||||||
|
|
||||||
class SelectorRenderer(BaseRenderer):
|
class SelectorRenderer(BaseRenderer):
|
||||||
@property
|
@property
|
||||||
def name(self):
|
def name(self):
|
||||||
return 'Selector config'
|
return "Selector config"
|
||||||
|
|||||||
@@ -4,8 +4,16 @@ from copy import deepcopy
|
|||||||
from dataclasses import dataclass, fields
|
from dataclasses import dataclass, fields
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import (
|
from typing import (
|
||||||
Dict, Any, Optional, Mapping, Iterator, Iterable, Tuple, List, MutableSet,
|
Dict,
|
||||||
Type
|
Any,
|
||||||
|
Optional,
|
||||||
|
Mapping,
|
||||||
|
Iterator,
|
||||||
|
Iterable,
|
||||||
|
Tuple,
|
||||||
|
List,
|
||||||
|
MutableSet,
|
||||||
|
Type,
|
||||||
)
|
)
|
||||||
|
|
||||||
from .profile import Profile
|
from .profile import Profile
|
||||||
@@ -15,7 +23,7 @@ from .utils import parse_cli_vars
|
|||||||
from dbt import tracking
|
from dbt import tracking
|
||||||
from dbt.adapters.factory import get_relation_class_by_name, get_include_paths
|
from dbt.adapters.factory import get_relation_class_by_name, get_include_paths
|
||||||
from dbt.helper_types import FQNPath, PathSet
|
from dbt.helper_types import FQNPath, PathSet
|
||||||
from dbt.context.base import generate_base_context
|
from dbt.context import generate_base_context
|
||||||
from dbt.context.target import generate_target_context
|
from dbt.context.target import generate_target_context
|
||||||
from dbt.contracts.connection import AdapterRequiredConfig, Credentials
|
from dbt.contracts.connection import AdapterRequiredConfig, Credentials
|
||||||
from dbt.contracts.graph.manifest import ManifestMetadata
|
from dbt.contracts.graph.manifest import ManifestMetadata
|
||||||
@@ -30,15 +38,13 @@ from dbt.exceptions import (
|
|||||||
DbtProjectError,
|
DbtProjectError,
|
||||||
validator_error_message,
|
validator_error_message,
|
||||||
warn_or_error,
|
warn_or_error,
|
||||||
raise_compiler_error
|
raise_compiler_error,
|
||||||
)
|
)
|
||||||
|
|
||||||
from hologram import ValidationError
|
from dbt.dataclass_schema import ValidationError
|
||||||
|
|
||||||
|
|
||||||
def _project_quoting_dict(
|
def _project_quoting_dict(proj: Project, profile: Profile) -> Dict[ComponentName, bool]:
|
||||||
proj: Project, profile: Profile
|
|
||||||
) -> Dict[ComponentName, bool]:
|
|
||||||
src: Dict[str, Any] = profile.credentials.translate_aliases(proj.quoting)
|
src: Dict[str, Any] = profile.credentials.translate_aliases(proj.quoting)
|
||||||
result: Dict[ComponentName, bool] = {}
|
result: Dict[ComponentName, bool] = {}
|
||||||
for key in ComponentName:
|
for key in ComponentName:
|
||||||
@@ -54,7 +60,7 @@ class RuntimeConfig(Project, Profile, AdapterRequiredConfig):
|
|||||||
args: Any
|
args: Any
|
||||||
profile_name: str
|
profile_name: str
|
||||||
cli_vars: Dict[str, Any]
|
cli_vars: Dict[str, Any]
|
||||||
dependencies: Optional[Mapping[str, 'RuntimeConfig']] = None
|
dependencies: Optional[Mapping[str, "RuntimeConfig"]] = None
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
self.validate()
|
self.validate()
|
||||||
@@ -65,8 +71,8 @@ class RuntimeConfig(Project, Profile, AdapterRequiredConfig):
|
|||||||
project: Project,
|
project: Project,
|
||||||
profile: Profile,
|
profile: Profile,
|
||||||
args: Any,
|
args: Any,
|
||||||
dependencies: Optional[Mapping[str, 'RuntimeConfig']] = None,
|
dependencies: Optional[Mapping[str, "RuntimeConfig"]] = None,
|
||||||
) -> 'RuntimeConfig':
|
) -> "RuntimeConfig":
|
||||||
"""Instantiate a RuntimeConfig from its components.
|
"""Instantiate a RuntimeConfig from its components.
|
||||||
|
|
||||||
:param profile: A parsed dbt Profile.
|
:param profile: A parsed dbt Profile.
|
||||||
@@ -78,9 +84,9 @@ class RuntimeConfig(Project, Profile, AdapterRequiredConfig):
|
|||||||
get_relation_class_by_name(profile.credentials.type)
|
get_relation_class_by_name(profile.credentials.type)
|
||||||
.get_default_quote_policy()
|
.get_default_quote_policy()
|
||||||
.replace_dict(_project_quoting_dict(project, profile))
|
.replace_dict(_project_quoting_dict(project, profile))
|
||||||
).to_dict()
|
).to_dict(omit_none=True)
|
||||||
|
|
||||||
cli_vars: Dict[str, Any] = parse_cli_vars(getattr(args, 'vars', '{}'))
|
cli_vars: Dict[str, Any] = parse_cli_vars(getattr(args, "vars", "{}"))
|
||||||
|
|
||||||
return cls(
|
return cls(
|
||||||
project_name=project.project_name,
|
project_name=project.project_name,
|
||||||
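Note: the hunk above builds the effective quoting config by starting from the adapter's default quote policy, overlaying per-component values taken from the project (after credential alias translation), and serializing the result with `omit_none=True`. A toy version of the overlay step, with made-up adapter defaults rather than a real quote policy:

```python
from typing import Dict, Optional

COMPONENTS = ("database", "schema", "identifier")


def effective_quoting(
    adapter_defaults: Dict[str, bool],
    project_quoting: Dict[str, Optional[bool]],
) -> Dict[str, bool]:
    # start from the adapter's defaults, then apply explicit project overrides
    result = dict(adapter_defaults)
    for key in COMPONENTS:
        value = project_quoting.get(key)
        if isinstance(value, bool):
            result[key] = value
    return result


# hypothetical defaults: an adapter that quotes identifiers only
defaults = {"database": False, "schema": False, "identifier": True}
print(effective_quoting(defaults, {"schema": True}))
# {'database': False, 'schema': True, 'identifier': True}
```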
@@ -123,7 +129,7 @@ class RuntimeConfig(Project, Profile, AdapterRequiredConfig):
|
|||||||
dependencies=dependencies,
|
dependencies=dependencies,
|
||||||
)
|
)
|
||||||
|
|
||||||
def new_project(self, project_root: str) -> 'RuntimeConfig':
|
def new_project(self, project_root: str) -> "RuntimeConfig":
|
||||||
"""Given a new project root, read in its project dictionary, supply the
|
"""Given a new project root, read in its project dictionary, supply the
|
||||||
existing project's profile info, and create a new project file.
|
existing project's profile info, and create a new project file.
|
||||||
|
|
||||||
@@ -142,7 +148,7 @@ class RuntimeConfig(Project, Profile, AdapterRequiredConfig):
|
|||||||
project = Project.from_project_root(
|
project = Project.from_project_root(
|
||||||
project_root,
|
project_root,
|
||||||
renderer,
|
renderer,
|
||||||
verify_version=getattr(self.args, 'version_check', False),
|
verify_version=getattr(self.args, "version_check", False),
|
||||||
)
|
)
|
||||||
|
|
||||||
cfg = self.from_parts(
|
cfg = self.from_parts(
|
||||||
@@ -165,7 +171,7 @@ class RuntimeConfig(Project, Profile, AdapterRequiredConfig):
|
|||||||
"""
|
"""
|
||||||
result = self.to_project_config(with_packages=True)
|
result = self.to_project_config(with_packages=True)
|
||||||
result.update(self.to_profile_info(serialize_credentials=True))
|
result.update(self.to_profile_info(serialize_credentials=True))
|
||||||
result['cli_vars'] = deepcopy(self.cli_vars)
|
result["cli_vars"] = deepcopy(self.cli_vars)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def validate(self):
|
def validate(self):
|
||||||
@@ -174,7 +180,7 @@ class RuntimeConfig(Project, Profile, AdapterRequiredConfig):
|
|||||||
:raises DbtProjectError: If the configuration fails validation.
|
:raises DbtProjectError: If the configuration fails validation.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
Configuration.from_dict(self.serialize())
|
Configuration.validate(self.serialize())
|
||||||
except ValidationError as e:
|
except ValidationError as e:
|
||||||
raise DbtProjectError(validator_error_message(e)) from e
|
raise DbtProjectError(validator_error_message(e)) from e
|
||||||
|
|
||||||
@@ -185,30 +191,21 @@ class RuntimeConfig(Project, Profile, AdapterRequiredConfig):
|
|||||||
profile_renderer: ProfileRenderer,
|
profile_renderer: ProfileRenderer,
|
||||||
profile_name: Optional[str],
|
profile_name: Optional[str],
|
||||||
) -> Profile:
|
) -> Profile:
|
||||||
return Profile.render_from_args(
|
return Profile.render_from_args(args, profile_renderer, profile_name)
|
||||||
args, profile_renderer, profile_name
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def collect_parts(
|
def collect_parts(cls: Type["RuntimeConfig"], args: Any) -> Tuple[Project, Profile]:
|
||||||
cls: Type['RuntimeConfig'], args: Any
|
|
||||||
) -> Tuple[Project, Profile]:
|
|
||||||
# profile_name from the project
|
# profile_name from the project
|
||||||
project_root = args.project_dir if args.project_dir else os.getcwd()
|
project_root = args.project_dir if args.project_dir else os.getcwd()
|
||||||
version_check = getattr(args, 'version_check', False)
|
version_check = getattr(args, "version_check", False)
|
||||||
partial = Project.partial_load(
|
partial = Project.partial_load(project_root, verify_version=version_check)
|
||||||
project_root,
|
|
||||||
verify_version=version_check
|
|
||||||
)
|
|
||||||
|
|
||||||
# build the profile using the base renderer and the one fact we know
|
# build the profile using the base renderer and the one fact we know
|
||||||
cli_vars: Dict[str, Any] = parse_cli_vars(getattr(args, 'vars', '{}'))
|
cli_vars: Dict[str, Any] = parse_cli_vars(getattr(args, "vars", "{}"))
|
||||||
profile_renderer = ProfileRenderer(generate_base_context(cli_vars))
|
profile_renderer = ProfileRenderer(generate_base_context(cli_vars))
|
||||||
profile_name = partial.render_profile_name(profile_renderer)
|
profile_name = partial.render_profile_name(profile_renderer)
|
||||||
|
|
||||||
profile = cls._get_rendered_profile(
|
profile = cls._get_rendered_profile(args, profile_renderer, profile_name)
|
||||||
args, profile_renderer, profile_name
|
|
||||||
)
|
|
||||||
|
|
||||||
# get a new renderer using our target information and render the
|
# get a new renderer using our target information and render the
|
||||||
# project
|
# project
|
||||||
@@ -218,7 +215,7 @@ class RuntimeConfig(Project, Profile, AdapterRequiredConfig):
|
|||||||
return (project, profile)
|
return (project, profile)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_args(cls, args: Any) -> 'RuntimeConfig':
|
def from_args(cls, args: Any) -> "RuntimeConfig":
|
||||||
"""Given arguments, read in dbt_project.yml from the current directory,
|
"""Given arguments, read in dbt_project.yml from the current directory,
|
||||||
read in packages.yml if it exists, and use them to find the profile to
|
read in packages.yml if it exists, and use them to find the profile to
|
||||||
load.
|
load.
|
||||||
@@ -238,8 +235,7 @@ class RuntimeConfig(Project, Profile, AdapterRequiredConfig):
|
|||||||
|
|
||||||
def get_metadata(self) -> ManifestMetadata:
|
def get_metadata(self) -> ManifestMetadata:
|
||||||
return ManifestMetadata(
|
return ManifestMetadata(
|
||||||
project_id=self.hashed_name(),
|
project_id=self.hashed_name(), adapter_type=self.credentials.type
|
||||||
adapter_type=self.credentials.type
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def _get_v2_config_paths(
|
def _get_v2_config_paths(
|
||||||
@@ -249,7 +245,7 @@ class RuntimeConfig(Project, Profile, AdapterRequiredConfig):
|
|||||||
paths: MutableSet[FQNPath],
|
paths: MutableSet[FQNPath],
|
||||||
) -> PathSet:
|
) -> PathSet:
|
||||||
for key, value in config.items():
|
for key, value in config.items():
|
||||||
if isinstance(value, dict) and not key.startswith('+'):
|
if isinstance(value, dict) and not key.startswith("+"):
|
||||||
self._get_v2_config_paths(value, path + (key,), paths)
|
self._get_v2_config_paths(value, path + (key,), paths)
|
||||||
else:
|
else:
|
||||||
paths.add(path)
|
paths.add(path)
|
||||||
@@ -265,7 +261,7 @@ class RuntimeConfig(Project, Profile, AdapterRequiredConfig):
|
|||||||
paths = set()
|
paths = set()
|
||||||
|
|
||||||
for key, value in config.items():
|
for key, value in config.items():
|
||||||
if isinstance(value, dict) and not key.startswith('+'):
|
if isinstance(value, dict) and not key.startswith("+"):
|
||||||
self._get_v2_config_paths(value, path + (key,), paths)
|
self._get_v2_config_paths(value, path + (key,), paths)
|
||||||
else:
|
else:
|
||||||
paths.add(path)
|
paths.add(path)
|
||||||
@@ -277,10 +273,10 @@ class RuntimeConfig(Project, Profile, AdapterRequiredConfig):
|
|||||||
a configured path in the resource.
|
a configured path in the resource.
|
||||||
"""
|
"""
|
||||||
return {
|
return {
|
||||||
'models': self._get_config_paths(self.models),
|
"models": self._get_config_paths(self.models),
|
||||||
'seeds': self._get_config_paths(self.seeds),
|
"seeds": self._get_config_paths(self.seeds),
|
||||||
'snapshots': self._get_config_paths(self.snapshots),
|
"snapshots": self._get_config_paths(self.snapshots),
|
||||||
'sources': self._get_config_paths(self.sources),
|
"sources": self._get_config_paths(self.sources),
|
||||||
}
|
}
|
||||||
|
|
||||||
def get_unused_resource_config_paths(
|
def get_unused_resource_config_paths(
|
||||||
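Note: the hunks above walk the `models`/`seeds`/`snapshots`/`sources` config trees to collect every path that carries configuration; nested dicts whose keys start with `+` are config values in the v2 syntax, so they end the recursion instead of extending the path. A compact, self-contained sketch of that traversal:

```python
from typing import Any, Dict, Set, Tuple

FQNPath = Tuple[str, ...]


def config_paths(config: Dict[str, Any], path: FQNPath = ()) -> Set[FQNPath]:
    # collect every keypath with configuration attached; "+key" children are
    # config values in the v2 syntax, so they terminate the recursion
    paths: Set[FQNPath] = set()
    for key, value in config.items():
        if isinstance(value, dict) and not key.startswith("+"):
            paths |= config_paths(value, path + (key,))
        else:
            paths.add(path)
    return paths


models = {"my_project": {"staging": {"+materialized": "view"}, "enabled": True}}
print(sorted(config_paths(models)))
# [('my_project',), ('my_project', 'staging')]
```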
@@ -301,9 +297,7 @@ class RuntimeConfig(Project, Profile, AdapterRequiredConfig):
|
|||||||
|
|
||||||
for config_path in config_paths:
|
for config_path in config_paths:
|
||||||
if not _is_config_used(config_path, fqns):
|
if not _is_config_used(config_path, fqns):
|
||||||
unused_resource_config_paths.append(
|
unused_resource_config_paths.append((resource_type,) + config_path)
|
||||||
(resource_type,) + config_path
|
|
||||||
)
|
|
||||||
return unused_resource_config_paths
|
return unused_resource_config_paths
|
||||||
|
|
||||||
def warn_for_unused_resource_config_paths(
|
def warn_for_unused_resource_config_paths(
|
||||||
@@ -316,27 +310,25 @@ class RuntimeConfig(Project, Profile, AdapterRequiredConfig):
|
|||||||
return
|
return
|
||||||
|
|
||||||
msg = UNUSED_RESOURCE_CONFIGURATION_PATH_MESSAGE.format(
|
msg = UNUSED_RESOURCE_CONFIGURATION_PATH_MESSAGE.format(
|
||||||
len(unused),
|
len(unused), "\n".join("- {}".format(".".join(u)) for u in unused)
|
||||||
'\n'.join('- {}'.format('.'.join(u)) for u in unused)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
warn_or_error(msg, log_fmt=warning_tag('{}'))
|
warn_or_error(msg, log_fmt=warning_tag("{}"))
|
||||||
|
|
||||||
def load_dependencies(self) -> Mapping[str, 'RuntimeConfig']:
|
def load_dependencies(self) -> Mapping[str, "RuntimeConfig"]:
|
||||||
if self.dependencies is None:
|
if self.dependencies is None:
|
||||||
all_projects = {self.project_name: self}
|
all_projects = {self.project_name: self}
|
||||||
internal_packages = get_include_paths(self.credentials.type)
|
internal_packages = get_include_paths(self.credentials.type)
|
||||||
project_paths = itertools.chain(
|
project_paths = itertools.chain(
|
||||||
internal_packages,
|
internal_packages, self._get_project_directories()
|
||||||
self._get_project_directories()
|
|
||||||
)
|
)
|
||||||
for project_name, project in self.load_projects(project_paths):
|
for project_name, project in self.load_projects(project_paths):
|
||||||
if project_name in all_projects:
|
if project_name in all_projects:
|
||||||
raise_compiler_error(
|
raise_compiler_error(
|
||||||
f'dbt found more than one package with the name '
|
f"dbt found more than one package with the name "
|
||||||
f'"{project_name}" included in this project. Package '
|
f'"{project_name}" included in this project. Package '
|
||||||
f'names must be unique in a project. Please rename '
|
f"names must be unique in a project. Please rename "
|
||||||
f'one of these packages.'
|
f"one of these packages."
|
||||||
)
|
)
|
||||||
all_projects[project_name] = project
|
all_projects[project_name] = project
|
||||||
self.dependencies = all_projects
|
self.dependencies = all_projects
|
||||||
@@ -347,14 +339,14 @@ class RuntimeConfig(Project, Profile, AdapterRequiredConfig):
|
|||||||
|
|
||||||
def load_projects(
|
def load_projects(
|
||||||
self, paths: Iterable[Path]
|
self, paths: Iterable[Path]
|
||||||
) -> Iterator[Tuple[str, 'RuntimeConfig']]:
|
) -> Iterator[Tuple[str, "RuntimeConfig"]]:
|
||||||
for path in paths:
|
for path in paths:
|
||||||
try:
|
try:
|
||||||
project = self.new_project(str(path))
|
project = self.new_project(str(path))
|
||||||
except DbtProjectError as e:
|
except DbtProjectError as e:
|
||||||
raise DbtProjectError(
|
raise DbtProjectError(
|
||||||
f'Failed to read package: {e}',
|
f"Failed to read package: {e}",
|
||||||
result_type='invalid_project',
|
result_type="invalid_project",
|
||||||
path=path,
|
path=path,
|
||||||
) from e
|
) from e
|
||||||
else:
|
else:
|
||||||
@@ -365,13 +357,13 @@ class RuntimeConfig(Project, Profile, AdapterRequiredConfig):
|
|||||||
|
|
||||||
if root.exists():
|
if root.exists():
|
||||||
for path in root.iterdir():
|
for path in root.iterdir():
|
||||||
if path.is_dir() and not path.name.startswith('__'):
|
if path.is_dir() and not path.name.startswith("__"):
|
||||||
yield path
|
yield path
|
||||||
|
|
||||||
|
|
||||||
class UnsetCredentials(Credentials):
|
class UnsetCredentials(Credentials):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__('', '')
|
super().__init__("", "")
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def type(self):
|
def type(self):
|
||||||
@@ -387,11 +379,9 @@ class UnsetCredentials(Credentials):
|
|||||||
class UnsetConfig(UserConfig):
|
class UnsetConfig(UserConfig):
|
||||||
def __getattribute__(self, name):
|
def __getattribute__(self, name):
|
||||||
if name in {f.name for f in fields(UserConfig)}:
|
if name in {f.name for f in fields(UserConfig)}:
|
||||||
raise AttributeError(
|
raise AttributeError(f"'UnsetConfig' object has no attribute {name}")
|
||||||
f"'UnsetConfig' object has no attribute {name}"
|
|
||||||
)
|
|
||||||
|
|
||||||
def to_dict(self):
|
def __post_serialize__(self, dct):
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
|
|
||||||
@@ -399,15 +389,15 @@ class UnsetProfile(Profile):
|
|||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.credentials = UnsetCredentials()
|
self.credentials = UnsetCredentials()
|
||||||
self.config = UnsetConfig()
|
self.config = UnsetConfig()
|
||||||
self.profile_name = ''
|
self.profile_name = ""
|
||||||
self.target_name = ''
|
self.target_name = ""
|
||||||
self.threads = -1
|
self.threads = -1
|
||||||
|
|
||||||
def to_target_dict(self):
|
def to_target_dict(self):
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
def __getattribute__(self, name):
|
def __getattribute__(self, name):
|
||||||
if name in {'profile_name', 'target_name', 'threads'}:
|
if name in {"profile_name", "target_name", "threads"}:
|
||||||
raise RuntimeException(
|
raise RuntimeException(
|
||||||
f'Error: disallowed attribute "{name}" - no profile!'
|
f'Error: disallowed attribute "{name}" - no profile!'
|
||||||
)
|
)
|
||||||
@@ -431,7 +421,7 @@ class UnsetProfileConfig(RuntimeConfig):
|
|||||||
|
|
||||||
def __getattribute__(self, name):
|
def __getattribute__(self, name):
|
||||||
# Override __getattribute__ to check that the attribute isn't 'banned'.
|
# Override __getattribute__ to check that the attribute isn't 'banned'.
|
||||||
if name in {'profile_name', 'target_name'}:
|
if name in {"profile_name", "target_name"}:
|
||||||
raise RuntimeException(
|
raise RuntimeException(
|
||||||
f'Error: disallowed attribute "{name}" - no profile!'
|
f'Error: disallowed attribute "{name}" - no profile!'
|
||||||
)
|
)
|
||||||
@@ -449,8 +439,8 @@ class UnsetProfileConfig(RuntimeConfig):
|
|||||||
project: Project,
|
project: Project,
|
||||||
profile: Profile,
|
profile: Profile,
|
||||||
args: Any,
|
args: Any,
|
||||||
dependencies: Optional[Mapping[str, 'RuntimeConfig']] = None,
|
dependencies: Optional[Mapping[str, "RuntimeConfig"]] = None,
|
||||||
) -> 'RuntimeConfig':
|
) -> "RuntimeConfig":
|
||||||
"""Instantiate a RuntimeConfig from its components.
|
"""Instantiate a RuntimeConfig from its components.
|
||||||
|
|
||||||
:param profile: Ignored.
|
:param profile: Ignored.
|
||||||
@@ -458,7 +448,7 @@ class UnsetProfileConfig(RuntimeConfig):
|
|||||||
:param args: The parsed command-line arguments.
|
:param args: The parsed command-line arguments.
|
||||||
:returns RuntimeConfig: The new configuration.
|
:returns RuntimeConfig: The new configuration.
|
||||||
"""
|
"""
|
||||||
cli_vars: Dict[str, Any] = parse_cli_vars(getattr(args, 'vars', '{}'))
|
cli_vars: Dict[str, Any] = parse_cli_vars(getattr(args, "vars", "{}"))
|
||||||
|
|
||||||
return cls(
|
return cls(
|
||||||
project_name=project.project_name,
|
project_name=project.project_name,
|
||||||
@@ -491,10 +481,10 @@ class UnsetProfileConfig(RuntimeConfig):
|
|||||||
vars=project.vars,
|
vars=project.vars,
|
||||||
config_version=project.config_version,
|
config_version=project.config_version,
|
||||||
unrendered=project.unrendered,
|
unrendered=project.unrendered,
|
||||||
profile_name='',
|
profile_name="",
|
||||||
target_name='',
|
target_name="",
|
||||||
config=UnsetConfig(),
|
config=UnsetConfig(),
|
||||||
threads=getattr(args, 'threads', 1),
|
threads=getattr(args, "threads", 1),
|
||||||
credentials=UnsetCredentials(),
|
credentials=UnsetCredentials(),
|
||||||
args=args,
|
args=args,
|
||||||
cli_vars=cli_vars,
|
cli_vars=cli_vars,
|
||||||
@@ -509,16 +499,11 @@ class UnsetProfileConfig(RuntimeConfig):
|
|||||||
profile_name: Optional[str],
|
profile_name: Optional[str],
|
||||||
) -> Profile:
|
) -> Profile:
|
||||||
try:
|
try:
|
||||||
profile = Profile.render_from_args(
|
profile = Profile.render_from_args(args, profile_renderer, profile_name)
|
||||||
args, profile_renderer, profile_name
|
|
||||||
)
|
|
||||||
except (DbtProjectError, DbtProfileError) as exc:
|
except (DbtProjectError, DbtProfileError) as exc:
|
||||||
logger.debug(
|
logger.debug("Profile not loaded due to error: {}", exc, exc_info=True)
|
||||||
'Profile not loaded due to error: {}', exc, exc_info=True
|
|
||||||
)
|
|
||||||
logger.info(
|
logger.info(
|
||||||
'No profile "{}" found, continuing with no target',
|
'No profile "{}" found, continuing with no target', profile_name
|
||||||
profile_name
|
|
||||||
)
|
)
|
||||||
# return the poisoned form
|
# return the poisoned form
|
||||||
profile = UnsetProfile()
|
profile = UnsetProfile()
|
||||||
@@ -527,7 +512,7 @@ class UnsetProfileConfig(RuntimeConfig):
|
|||||||
return profile
|
return profile
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_args(cls: Type[RuntimeConfig], args: Any) -> 'RuntimeConfig':
|
def from_args(cls: Type[RuntimeConfig], args: Any) -> "RuntimeConfig":
|
||||||
"""Given arguments, read in dbt_project.yml from the current directory,
|
"""Given arguments, read in dbt_project.yml from the current directory,
|
||||||
read in packages.yml if it exists, and use them to find the profile to
|
read in packages.yml if it exists, and use them to find the profile to
|
||||||
load.
|
load.
|
||||||
@@ -542,11 +527,7 @@ class UnsetProfileConfig(RuntimeConfig):
|
|||||||
# if it's a real profile, return a real config
|
# if it's a real profile, return a real config
|
||||||
cls = RuntimeConfig
|
cls = RuntimeConfig
|
||||||
|
|
||||||
return cls.from_parts(
|
return cls.from_parts(project=project, profile=profile, args=args)
|
||||||
project=project,
|
|
||||||
profile=profile,
|
|
||||||
args=args
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
UNUSED_RESOURCE_CONFIGURATION_PATH_MESSAGE = """\
|
UNUSED_RESOURCE_CONFIGURATION_PATH_MESSAGE = """\
|
||||||
@@ -560,6 +541,6 @@ There are {} unused configuration paths:
|
|||||||
def _is_config_used(path, fqns):
|
def _is_config_used(path, fqns):
|
||||||
if fqns:
|
if fqns:
|
||||||
for fqn in fqns:
|
for fqn in fqns:
|
||||||
if len(path) <= len(fqn) and fqn[:len(path)] == path:
|
if len(path) <= len(fqn) and fqn[: len(path)] == path:
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|||||||
@@ -1,8 +1,7 @@
from pathlib import Path
from typing import Dict, Any

-import yaml
+from dbt.clients.yaml_helper import yaml, Loader, Dumper, load_yaml_text  # noqa: F401
+from dbt.dataclass_schema import ValidationError
-from hologram import ValidationError

from .renderer import SelectorRenderer

@@ -11,7 +10,6 @@ from dbt.clients.system import (
    path_exists,
    resolve_path_from_base,
)
-from dbt.clients.yaml_helper import load_yaml_text
from dbt.contracts.selection import SelectorFile
from dbt.exceptions import DbtSelectorsError, RuntimeException
from dbt.graph import parse_from_selectors_definition, SelectionSpec
@@ -31,8 +29,9 @@ Validator Error:

class SelectorConfig(Dict[str, SelectionSpec]):
    @classmethod
-    def from_dict(cls, data: Dict[str, Any]) -> 'SelectorConfig':
+    def selectors_from_dict(cls, data: Dict[str, Any]) -> "SelectorConfig":
        try:
+            SelectorFile.validate(data)
            selector_file = SelectorFile.from_dict(data)
            selectors = parse_from_selectors_definition(selector_file)
        except ValidationError as exc:
@@ -43,12 +42,12 @@ class SelectorConfig(Dict[str, SelectionSpec]):
                f"union, intersection, string, dictionary. No lists. "
                f"\nhttps://docs.getdbt.com/reference/node-selection/"
                f"yaml-selectors",
-                result_type='invalid_selector'
+                result_type="invalid_selector",
            ) from exc
        except RuntimeException as exc:
            raise DbtSelectorsError(
-                f'Could not read selector file data: {exc}',
-                result_type='invalid_selector',
+                f"Could not read selector file data: {exc}",
+                result_type="invalid_selector",
            ) from exc

        return cls(selectors)
@@ -58,26 +57,28 @@ class SelectorConfig(Dict[str, SelectionSpec]):
        cls,
        data: Dict[str, Any],
        renderer: SelectorRenderer,
-    ) -> 'SelectorConfig':
+    ) -> "SelectorConfig":
        try:
            rendered = renderer.render_data(data)
        except (ValidationError, RuntimeException) as exc:
            raise DbtSelectorsError(
-                f'Could not render selector data: {exc}',
-                result_type='invalid_selector',
+                f"Could not render selector data: {exc}",
+                result_type="invalid_selector",
            ) from exc
-        return cls.from_dict(rendered)
+        return cls.selectors_from_dict(rendered)

    @classmethod
    def from_path(
-        cls, path: Path, renderer: SelectorRenderer,
-    ) -> 'SelectorConfig':
+        cls,
+        path: Path,
+        renderer: SelectorRenderer,
+    ) -> "SelectorConfig":
        try:
            data = load_yaml_text(load_file_contents(str(path)))
        except (ValidationError, RuntimeException) as exc:
            raise DbtSelectorsError(
-                f'Could not read selector file: {exc}',
-                result_type='invalid_selector',
+                f"Could not read selector file: {exc}",
+                result_type="invalid_selector",
                path=path,
            ) from exc

@@ -89,9 +90,7 @@ class SelectorConfig(Dict[str, SelectionSpec]):


def selector_data_from_root(project_root: str) -> Dict[str, Any]:
-    selector_filepath = resolve_path_from_base(
-        'selectors.yml', project_root
-    )
+    selector_filepath = resolve_path_from_base("selectors.yml", project_root)

    if path_exists(selector_filepath):
        selectors_dict = load_yaml_text(load_file_contents(selector_filepath))
@@ -100,18 +99,16 @@ def selector_data_from_root(project_root: str) -> Dict[str, Any]:
    return selectors_dict


-def selector_config_from_data(
-    selectors_data: Dict[str, Any]
-) -> SelectorConfig:
+def selector_config_from_data(selectors_data: Dict[str, Any]) -> SelectorConfig:
    if not selectors_data:
-        selectors_data = {'selectors': []}
+        selectors_data = {"selectors": []}

    try:
-        selectors = SelectorConfig.from_dict(selectors_data)
+        selectors = SelectorConfig.selectors_from_dict(selectors_data)
    except ValidationError as e:
        raise DbtSelectorsError(
            MALFORMED_SELECTOR_ERROR.format(error=str(e.message)),
-            result_type='invalid_selector',
+            result_type="invalid_selector",
        ) from e
    return selectors

@@ -123,7 +120,6 @@ def selector_config_from_data(
# be necessary to make changes here. Ideally it would be
# good to combine the two flows into one at some point.
class SelectorDict:
-
    @classmethod
    def parse_dict_definition(cls, definition):
        key = list(definition)[0]
@@ -134,10 +130,10 @@ class SelectorDict:
                new_value = cls.parse_from_definition(sel_def)
                new_values.append(new_value)
            value = new_values
-        if key == 'exclude':
+        if key == "exclude":
            definition = {key: value}
        elif len(definition) == 1:
-            definition = {'method': key, 'value': value}
+            definition = {"method": key, "value": value}
        return definition

    @classmethod
@@ -159,10 +155,10 @@ class SelectorDict:
    def parse_from_definition(cls, definition):
        if isinstance(definition, str):
            definition = SelectionCriteria.dict_from_single_spec(definition)
-        elif 'union' in definition:
-            definition = cls.parse_a_definition('union', definition)
-        elif 'intersection' in definition:
-            definition = cls.parse_a_definition('intersection', definition)
+        elif "union" in definition:
+            definition = cls.parse_a_definition("union", definition)
+        elif "intersection" in definition:
+            definition = cls.parse_a_definition("intersection", definition)
        elif isinstance(definition, dict):
            definition = cls.parse_dict_definition(definition)
        return definition
@@ -173,8 +169,8 @@ class SelectorDict:
    def parse_from_selectors_list(cls, selectors):
        selector_dict = {}
        for selector in selectors:
-            sel_name = selector['name']
+            sel_name = selector["name"]
            selector_dict[sel_name] = selector
-            definition = cls.parse_from_definition(selector['definition'])
-            selector_dict[sel_name]['definition'] = definition
+            definition = cls.parse_from_definition(selector["definition"])
+            selector_dict[sel_name]["definition"] = definition
        return selector_dict

@@ -15,9 +15,8 @@ def parse_cli_vars(var_string: str) -> Dict[str, Any]:
            type_name = var_type.__name__
            raise_compiler_error(
                "The --vars argument must be a YAML dictionary, but was "
-                "of type '{}'".format(type_name))
+                "of type '{}'".format(type_name)
+            )
    except ValidationException:
-        logger.error(
-            "The YAML provided in the --vars argument is not valid.\n"
-        )
+        logger.error("The YAML provided in the --vars argument is not valid.\n")
        raise

@@ -1,19 +1,22 @@
import json
import os
-from typing import (
-    Any, Dict, NoReturn, Optional, Mapping
-)
+from typing import Any, Dict, NoReturn, Optional, Mapping

from dbt import flags
from dbt import tracking
from dbt.clients.jinja import undefined_error, get_rendered
-from dbt.clients import yaml_helper
+from dbt.clients.yaml_helper import (  # noqa: F401
+    yaml,
+    safe_load,
+    SafeLoader,
+    Loader,
+    Dumper,
+)
from dbt.contracts.graph.compiled import CompiledResource
from dbt.exceptions import raise_compiler_error, MacroReturn
from dbt.logger import GLOBAL_LOGGER as logger
from dbt.version import __version__ as dbt_version

-import yaml
# These modules are added to the context. Consider alternative
# approaches which will extend well to potentially many modules
import pytz
@@ -24,38 +27,26 @@ import re
def get_pytz_module_context() -> Dict[str, Any]:
    context_exports = pytz.__all__  # type: ignore

-    return {
-        name: getattr(pytz, name) for name in context_exports
-    }
+    return {name: getattr(pytz, name) for name in context_exports}


def get_datetime_module_context() -> Dict[str, Any]:
-    context_exports = [
-        'date',
-        'datetime',
-        'time',
-        'timedelta',
-        'tzinfo'
-    ]
+    context_exports = ["date", "datetime", "time", "timedelta", "tzinfo"]

-    return {
-        name: getattr(datetime, name) for name in context_exports
-    }
+    return {name: getattr(datetime, name) for name in context_exports}


def get_re_module_context() -> Dict[str, Any]:
    context_exports = re.__all__

-    return {
-        name: getattr(re, name) for name in context_exports
-    }
+    return {name: getattr(re, name) for name in context_exports}


def get_context_modules() -> Dict[str, Dict[str, Any]]:
    return {
-        'pytz': get_pytz_module_context(),
-        'datetime': get_datetime_module_context(),
-        're': get_re_module_context(),
+        "pytz": get_pytz_module_context(),
+        "datetime": get_datetime_module_context(),
+        "re": get_re_module_context(),
    }


@@ -89,8 +80,8 @@ class ContextMeta(type):
        new_dct = {}

        for base in bases:
-            context_members.update(getattr(base, '_context_members_', {}))
-            context_attrs.update(getattr(base, '_context_attrs_', {}))
+            context_members.update(getattr(base, "_context_members_", {}))
+            context_attrs.update(getattr(base, "_context_attrs_", {}))

        for key, value in dct.items():
            if isinstance(value, ContextMember):
@@ -99,21 +90,22 @@ class ContextMeta(type):
                context_attrs[context_key] = key
                value = value.inner
            new_dct[key] = value
-        new_dct['_context_members_'] = context_members
-        new_dct['_context_attrs_'] = context_attrs
+        new_dct["_context_members_"] = context_members
+        new_dct["_context_attrs_"] = context_attrs
        return type.__new__(mcls, name, bases, new_dct)


class Var:
-    UndefinedVarError = "Required var '{}' not found in config:\nVars "\
-        "supplied to {} = {}"
+    UndefinedVarError = (
+        "Required var '{}' not found in config:\nVars " "supplied to {} = {}"
+    )
    _VAR_NOTSET = object()

    def __init__(
        self,
        context: Mapping[str, Any],
        cli_vars: Mapping[str, Any],
-        node: Optional[CompiledResource] = None
+        node: Optional[CompiledResource] = None,
    ) -> None:
        self._context: Mapping[str, Any] = context
        self._cli_vars: Mapping[str, Any] = cli_vars
@@ -128,14 +120,12 @@ class Var:
        if self._node is not None:
            return self._node.name
        else:
-            return '<Configuration>'
+            return "<Configuration>"

    def get_missing_var(self, var_name):
        dct = {k: self._merged[k] for k in self._merged}
        pretty_vars = json.dumps(dct, sort_keys=True, indent=4)
-        msg = self.UndefinedVarError.format(
-            var_name, self.node_name, pretty_vars
-        )
+        msg = self.UndefinedVarError.format(var_name, self.node_name, pretty_vars)
        raise_compiler_error(msg, self._node)

    def has_var(self, var_name: str):
@@ -166,16 +156,17 @@ class BaseContext(metaclass=ContextMeta):
    def generate_builtins(self):
        builtins: Dict[str, Any] = {}
        for key, value in self._context_members_.items():
-            if hasattr(value, '__get__'):
+            if hasattr(value, "__get__"):
                # handle properties, bound methods, etc
                value = value.__get__(self)
            builtins[key] = value
        return builtins

+    # no dbtClassMixin so this is not an actual override
    def to_dict(self):
-        self._ctx['context'] = self._ctx
+        self._ctx["context"] = self._ctx
        builtins = self.generate_builtins()
-        self._ctx['builtins'] = builtins
+        self._ctx["builtins"] = builtins
        self._ctx.update(builtins)
        return self._ctx

@@ -284,18 +275,20 @@ class BaseContext(metaclass=ContextMeta):
            msg = f"Env var required but not provided: '{var}'"
            undefined_error(msg)

-    if os.environ.get('DBT_MACRO_DEBUGGING'):
+    if os.environ.get("DBT_MACRO_DEBUGGING"):

        @contextmember
        @staticmethod
        def debug():
            """Enter a debugger at this line in the compiled jinja code."""
            import sys
            import ipdb  # type: ignore

            frame = sys._getframe(3)
            ipdb.set_trace(frame)
-            return ''
+            return ""

-    @contextmember('return')
+    @contextmember("return")
    @staticmethod
    def _return(data: Any) -> NoReturn:
        """The `return` function can be used in macros to return data to the
@@ -346,9 +339,7 @@ class BaseContext(metaclass=ContextMeta):

    @contextmember
    @staticmethod
-    def tojson(
-        value: Any, default: Any = None, sort_keys: bool = False
-    ) -> Any:
+    def tojson(value: Any, default: Any = None, sort_keys: bool = False) -> Any:
        """The `tojson` context method can be used to serialize a Python
        object primitive, eg. a `dict` or `list` to a json string.

@@ -394,7 +385,7 @@ class BaseContext(metaclass=ContextMeta):
        -- ["good"]
        """
        try:
-            return yaml_helper.safe_load(value)
+            return safe_load(value)
        except (AttributeError, ValueError, yaml.YAMLError):
            return default

@@ -444,7 +435,7 @@ class BaseContext(metaclass=ContextMeta):
            logger.info(msg)
        else:
            logger.debug(msg)
-        return ''
+        return ""

    @contextproperty
    def run_started_at(self) -> Optional[datetime.datetime]:
@@ -536,4 +527,5 @@ class BaseContext(metaclass=ContextMeta):

def generate_base_context(cli_vars: Dict[str, Any]) -> Dict[str, Any]:
    ctx = BaseContext(cli_vars)
+    # This is not a Mashumaro to_dict call
    return ctx.to_dict()
@@ -4,16 +4,14 @@ from dbt.contracts.connection import AdapterRequiredConfig
from dbt.node_types import NodeType
from dbt.utils import MultiDict

-from dbt.context.base import contextproperty, Var
+from dbt.context import contextproperty, Var
from dbt.context.target import TargetContext


class ConfiguredContext(TargetContext):
    config: AdapterRequiredConfig

-    def __init__(
-        self, config: AdapterRequiredConfig
-    ) -> None:
+    def __init__(self, config: AdapterRequiredConfig) -> None:
        super().__init__(config, config.cli_vars)

    @contextproperty
@@ -70,9 +68,7 @@ class SchemaYamlContext(ConfiguredContext):

    @contextproperty
    def var(self) -> ConfiguredVar:
-        return ConfiguredVar(
-            self._ctx, self.config, self._project_name
-        )
+        return ConfiguredVar(self._ctx, self.config, self._project_name)


def generate_schema_yml(

@@ -17,8 +17,8 @@ class ModelParts(IsFQNResource):
    package_name: str


-T = TypeVar('T')  # any old type
-C = TypeVar('C', bound=BaseConfig)
+T = TypeVar("T")  # any old type
+C = TypeVar("C", bound=BaseConfig)


class ConfigSource:
@@ -36,13 +36,13 @@ class UnrenderedConfig(ConfigSource):
    def get_config_dict(self, resource_type: NodeType) -> Dict[str, Any]:
        unrendered = self.project.unrendered.project_dict
        if resource_type == NodeType.Seed:
-            model_configs = unrendered.get('seeds')
+            model_configs = unrendered.get("seeds")
        elif resource_type == NodeType.Snapshot:
-            model_configs = unrendered.get('snapshots')
+            model_configs = unrendered.get("snapshots")
        elif resource_type == NodeType.Source:
-            model_configs = unrendered.get('sources')
+            model_configs = unrendered.get("sources")
        else:
-            model_configs = unrendered.get('models')
+            model_configs = unrendered.get("models")

        if model_configs is None:
            return {}
@@ -79,8 +79,8 @@ class BaseContextConfigGenerator(Generic[T]):
        dependencies = self._active_project.load_dependencies()
        if project_name not in dependencies:
            raise InternalException(
-                f'Project name {project_name} not found in dependencies '
-                f'(found {list(dependencies)})'
+                f"Project name {project_name} not found in dependencies "
+                f"(found {list(dependencies)})"
            )
        return dependencies[project_name]

@@ -92,7 +92,7 @@ class BaseContextConfigGenerator(Generic[T]):
        for level_config in fqn_search(model_configs, fqn):
            result = {}
            for key, value in level_config.items():
-                if key.startswith('+'):
+                if key.startswith("+"):
                    result[key[1:]] = deepcopy(value)
                elif not isinstance(value, dict):
                    result[key] = deepcopy(value)
@@ -165,19 +165,15 @@ class ContextConfigGenerator(BaseContextConfigGenerator[C]):
        # Calculate the defaults. We don't want to validate the defaults,
        # because it might be invalid in the case of required config members
        # (such as on snapshots!)
-        result = config_cls.from_dict({}, validate=False)
+        result = config_cls.from_dict({})
        return result

    def _update_from_config(
        self, result: C, partial: Dict[str, Any], validate: bool = False
    ) -> C:
-        translated = self._active_project.credentials.translate_aliases(
-            partial
-        )
+        translated = self._active_project.credentials.translate_aliases(partial)
        return result.update_from(
-            translated,
-            self._active_project.credentials.type,
-            validate=validate
+            translated, self._active_project.credentials.type, validate=validate
        )

    def calculate_node_config_dict(
@@ -196,7 +192,7 @@ class ContextConfigGenerator(BaseContextConfigGenerator[C]):
            base=base,
        )
        finalized = config.finalize_and_validate()
-        return finalized.to_dict()
+        return finalized.to_dict(omit_none=True)


class UnrenderedConfigGenerator(BaseContextConfigGenerator[Dict[str, Any]]):
@@ -219,11 +215,7 @@ class UnrenderedConfigGenerator(BaseContextConfigGenerator[Dict[str, Any]]):
            base=base,
        )

-    def initial_result(
-        self,
-        resource_type: NodeType,
-        base: bool
-    ) -> Dict[str, Any]:
+    def initial_result(self, resource_type: NodeType, base: bool) -> Dict[str, Any]:
        return {}

    def _update_from_config(
@@ -232,9 +224,7 @@ class UnrenderedConfigGenerator(BaseContextConfigGenerator[Dict[str, Any]]):
        partial: Dict[str, Any],
        validate: bool = False,
    ) -> Dict[str, Any]:
-        translated = self._active_project.credentials.translate_aliases(
-            partial
-        )
+        translated = self._active_project.credentials.translate_aliases(partial)
        result.update(translated)
        return result

@@ -1,6 +1,4 @@
-from typing import (
-    Any, Dict, Union
-)
+from typing import Any, Dict, Union

from dbt.exceptions import (
    doc_invalid_args,
@@ -11,7 +9,7 @@ from dbt.contracts.graph.compiled import CompileResultNode
from dbt.contracts.graph.manifest import Manifest
from dbt.contracts.graph.parsed import ParsedMacro

-from dbt.context.base import contextmember
+from dbt.context import contextmember
from dbt.context.configured import SchemaYamlContext


@@ -77,4 +75,5 @@ def generate_runtime_docs(
    current_project: str,
) -> Dict[str, Any]:
    ctx = DocsRuntimeContext(config, target, manifest, current_project)
+    # This is not a Mashumaro to_dict call
    return ctx.to_dict()
core/dbt/context/macro_resolver.py (new file, 147 lines)
@@ -0,0 +1,147 @@
+from typing import Dict, MutableMapping, Optional
+from dbt.contracts.graph.parsed import ParsedMacro
+from dbt.exceptions import raise_duplicate_macro_name, raise_compiler_error
+from dbt.include.global_project import PROJECT_NAME as GLOBAL_PROJECT_NAME
+from dbt.clients.jinja import MacroGenerator
+
+MacroNamespace = Dict[str, ParsedMacro]
+
+
+# This class builds the MacroResolver by adding macros
+# to various categories for finding macros in the right order,
+# so that higher precedence macros are found first.
+# This functionality is also provided by the MacroNamespace,
+# but the intention is to eventually replace that class.
+# This enables us to get the macro unique_id without
+# processing every macro in the project.
+class MacroResolver:
+    def __init__(
+        self,
+        macros: MutableMapping[str, ParsedMacro],
+        root_project_name: str,
+        internal_package_names,
+    ) -> None:
+        self.root_project_name = root_project_name
+        self.macros = macros
+        # internal packages comes from get_adapter_package_names
+        self.internal_package_names = internal_package_names
+
+        # To be filled in from macros.
+        self.internal_packages: Dict[str, MacroNamespace] = {}
+        self.packages: Dict[str, MacroNamespace] = {}
+        self.root_package_macros: MacroNamespace = {}
+
+        # add the macros to internal_packages, packages, and root packages
+        self.add_macros()
+        self._build_internal_packages_namespace()
+        self._build_macros_by_name()
+
+    def _build_internal_packages_namespace(self):
+        # Iterate in reverse-order and overwrite: the packages that are first
+        # in the list are the ones we want to "win".
+        self.internal_packages_namespace: MacroNamespace = {}
+        for pkg in reversed(self.internal_package_names):
+            if pkg in self.internal_packages:
+                # Turn the internal packages into a flat namespace
+                self.internal_packages_namespace.update(self.internal_packages[pkg])
+
+    def _build_macros_by_name(self):
+        macros_by_name = {}
+        # search root package macros
+        for macro in self.root_package_macros.values():
+            macros_by_name[macro.name] = macro
+        # search miscellaneous non-internal packages
+        for fnamespace in self.packages.values():
+            for macro in fnamespace.values():
+                macros_by_name[macro.name] = macro
+        # search all internal packages
+        for macro in self.internal_packages_namespace.values():
+            macros_by_name[macro.name] = macro
+        self.macros_by_name = macros_by_name
+
+    def _add_macro_to(
+        self,
+        package_namespaces: Dict[str, MacroNamespace],
+        macro: ParsedMacro,
+    ):
+        if macro.package_name in package_namespaces:
+            namespace = package_namespaces[macro.package_name]
+        else:
+            namespace = {}
+            package_namespaces[macro.package_name] = namespace
+
+        if macro.name in namespace:
+            raise_duplicate_macro_name(macro, macro, macro.package_name)
+        package_namespaces[macro.package_name][macro.name] = macro
+
+    def add_macro(self, macro: ParsedMacro):
+        macro_name: str = macro.name
+
+        # internal macros (from plugins) will be processed separately from
+        # project macros, so store them in a different place
+        if macro.package_name in self.internal_package_names:
+            self._add_macro_to(self.internal_packages, macro)
+        else:
+            # if it's not an internal package
+            self._add_macro_to(self.packages, macro)
+            # add to root_package_macros if it's in the root package
+            if macro.package_name == self.root_project_name:
+                self.root_package_macros[macro_name] = macro
+
+    def add_macros(self):
+        for macro in self.macros.values():
+            self.add_macro(macro)
+
+    def get_macro_id(self, local_package, macro_name):
+        local_package_macros = {}
+        if (
+            local_package not in self.internal_package_names
+            and local_package in self.packages
+        ):
+            local_package_macros = self.packages[local_package]
+        # First: search the local packages for this macro
+        if macro_name in local_package_macros:
+            return local_package_macros[macro_name].unique_id
+        if macro_name in self.macros_by_name:
+            return self.macros_by_name[macro_name].unique_id
+        return None
+
+
+# Currently this is just used by test processing in the schema
+# parser (in connection with the MacroResolver). Future work
+# will extend the use of these classes to other parsing areas.
+# One of the features of this class compared to the MacroNamespace
+# is that you can limit the number of macros provided to the
+# context dictionary in the 'to_dict' manifest method.
+class TestMacroNamespace:
+    def __init__(self, macro_resolver, ctx, node, thread_ctx, depends_on_macros):
+        self.macro_resolver = macro_resolver
+        self.ctx = ctx
+        self.node = node
+        self.thread_ctx = thread_ctx
+        local_namespace = {}
+        if depends_on_macros:
+            for macro_unique_id in depends_on_macros:
+                macro = self.manifest.macros[macro_unique_id]
+                local_namespace[macro.name] = MacroGenerator(
+                    macro,
+                    self.ctx,
+                    self.node,
+                    self.thread_ctx,
+                )
+        self.local_namespace = local_namespace
+
+    def get_from_package(
+        self, package_name: Optional[str], name: str
+    ) -> Optional[MacroGenerator]:
+        macro = None
+        if package_name is None:
+            macro = self.macro_resolver.macros_by_name.get(name)
+        elif package_name == GLOBAL_PROJECT_NAME:
+            macro = self.macro_resolver.internal_packages_namespace.get(name)
+        elif package_name in self.resolver.packages:
+            macro = self.macro_resolver.packages[package_name].get(name)
+        else:
+            raise_compiler_error(f"Could not find package '{package_name}'")
+        macro_func = MacroGenerator(macro, self.ctx, self.node, self.thread_ctx)
+        return macro_func
@@ -1,13 +1,9 @@
-from typing import (
-    Any, Dict, Iterable, Union, Optional, List, Iterator, Mapping, Set
-)
+from typing import Any, Dict, Iterable, Union, Optional, List, Iterator, Mapping, Set

from dbt.clients.jinja import MacroGenerator, MacroStack
from dbt.contracts.graph.parsed import ParsedMacro
from dbt.include.global_project import PROJECT_NAME as GLOBAL_PROJECT_NAME
-from dbt.exceptions import (
-    raise_duplicate_macro_name, raise_compiler_error
-)
+from dbt.exceptions import raise_duplicate_macro_name, raise_compiler_error


FlatNamespace = Dict[str, MacroGenerator]
@@ -15,6 +11,10 @@ NamespaceMember = Union[FlatNamespace, MacroGenerator]
FullNamespace = Dict[str, NamespaceMember]


+# The point of this class is to collect the various macros
+# and provide the ability to flatten them into the ManifestContexts
+# that are created for jinja, so that macro calls can be resolved.
+# Creates special iterators and _keys methods to flatten the lists.
class MacroNamespace(Mapping):
    def __init__(
        self,
@@ -37,12 +37,16 @@ class MacroNamespace(Mapping):
        }
        yield self.global_project_namespace

+    # provides special keys method for MacroNamespace iterator
+    # returns keys from local_namespace, global_namespace, packages,
+    # global_project_namespace
    def _keys(self) -> Set[str]:
        keys: Set[str] = set()
        for search in self._search_order():
            keys.update(search)
        return keys

+    # special iterator using special keys
    def __iter__(self) -> Iterator[str]:
        for key in self._keys():
            yield key
@@ -67,11 +71,13 @@ class MacroNamespace(Mapping):
        elif package_name in self.packages:
            return self.packages[package_name].get(name)
        else:
-            raise_compiler_error(
-                f"Could not find package '{package_name}'"
-            )
+            raise_compiler_error(f"Could not find package '{package_name}'")


+# This class builds the MacroNamespace by adding macros to
+# internal_packages or packages, and locals/globals.
+# Call 'build_namespace' to return a MacroNamespace.
+# This is used by ManifestContext (and subclasses)
class MacroNamespaceBuilder:
    def __init__(
        self,
@@ -83,10 +89,15 @@ class MacroNamespaceBuilder:
    ) -> None:
        self.root_package = root_package
        self.search_package = search_package
+        # internal packages comes from get_adapter_package_names
        self.internal_package_names = set(internal_packages)
        self.internal_package_names_order = internal_packages
+        # macro_func is added here if in root package
        self.globals: FlatNamespace = {}
+        # macro_func is added here if it's the package for this node
        self.locals: FlatNamespace = {}
+        # Create a dictionary of [package name][macro name] =
+        # MacroGenerator object which acts like a function
        self.internal_packages: Dict[str, FlatNamespace] = {}
        self.packages: Dict[str, FlatNamespace] = {}
        self.thread_ctx = thread_ctx
@@ -94,25 +105,26 @@ class MacroNamespaceBuilder:

    def _add_macro_to(
        self,
-        heirarchy: Dict[str, FlatNamespace],
+        hierarchy: Dict[str, FlatNamespace],
        macro: ParsedMacro,
        macro_func: MacroGenerator,
    ):
-        if macro.package_name in heirarchy:
-            namespace = heirarchy[macro.package_name]
+        if macro.package_name in hierarchy:
+            namespace = hierarchy[macro.package_name]
        else:
            namespace = {}
-            heirarchy[macro.package_name] = namespace
+            hierarchy[macro.package_name] = namespace

        if macro.name in namespace:
-            raise_duplicate_macro_name(
-                macro_func.macro, macro, macro.package_name
-            )
-        heirarchy[macro.package_name][macro.name] = macro_func
+            raise_duplicate_macro_name(macro_func.macro, macro, macro.package_name)
+        hierarchy[macro.package_name][macro.name] = macro_func

    def add_macro(self, macro: ParsedMacro, ctx: Dict[str, Any]):
        macro_name: str = macro.name

+        # MacroGenerator is in clients/jinja.py
+        # a MacroGenerator object is a callable object that will
+        # execute the MacroGenerator.__call__ function
        macro_func: MacroGenerator = MacroGenerator(
            macro, ctx, self.node, self.thread_ctx
        )
@@ -122,10 +134,12 @@ class MacroNamespaceBuilder:
        if macro.package_name in self.internal_package_names:
            self._add_macro_to(self.internal_packages, macro, macro_func)
        else:
+            # if it's not an internal package
            self._add_macro_to(self.packages, macro, macro_func)
+            # add to locals if it's the package this node is in
            if macro.package_name == self.search_package:
                self.locals[macro_name] = macro_func
+            # add to globals if it's in the root package
            elif macro.package_name == self.root_package:
                self.globals[macro_name] = macro_func

@@ -143,6 +157,7 @@ class MacroNamespaceBuilder:
        global_project_namespace: FlatNamespace = {}
        for pkg in reversed(self.internal_package_names_order):
            if pkg in self.internal_packages:
+                # add the macros pointed to by this package name
                global_project_namespace.update(self.internal_packages[pkg])

        return MacroNamespace(

@@ -2,7 +2,8 @@ from typing import List

from dbt.clients.jinja import MacroStack
from dbt.contracts.connection import AdapterRequiredConfig
-from dbt.contracts.graph.manifest import Manifest
+from dbt.contracts.graph.manifest import Manifest, AnyManifest
+from dbt.context.macro_resolver import TestMacroNamespace


from .configured import ConfiguredContext
@@ -16,25 +17,33 @@ class ManifestContext(ConfiguredContext):
    The given macros can override any previous context values, which will be
    available as if they were accessed relative to the package name.
    """

    def __init__(
        self,
        config: AdapterRequiredConfig,
-        manifest: Manifest,
+        manifest: AnyManifest,
        search_package: str,
    ) -> None:
        super().__init__(config)
        self.manifest = manifest
+        # this is the package of the node for which this context was built
        self.search_package = search_package
        self.macro_stack = MacroStack()
+        # This namespace is used by the BaseDatabaseWrapper in jinja rendering.
+        # The namespace is passed to it when it's constructed. It expects
+        # to be able to do: namespace.get_from_package(..)
+        self.namespace = self._build_namespace()
+
+    def _build_namespace(self):
+        # this takes all the macros in the manifest and adds them
+        # to the MacroNamespaceBuilder stored in self.namespace
        builder = self._get_namespace_builder()
-        self.namespace = builder.build_namespace(
-            self.manifest.macros.values(),
-            self._ctx,
-        )
+        return builder.build_namespace(self.manifest.macros.values(), self._ctx)

    def _get_namespace_builder(self) -> MacroNamespaceBuilder:
        # avoid an import loop
        from dbt.adapters.factory import get_adapter_package_names

        internal_packages: List[str] = get_adapter_package_names(
            self.config.credentials.type
        )
@@ -46,21 +55,23 @@ class ManifestContext(ConfiguredContext):
            None,
        )

+    # This does not use the Mashumaro code
    def to_dict(self):
        dct = super().to_dict()
-        dct.update(self.namespace)
+        # This moves all of the macros in the 'namespace' into top level
+        # keys in the manifest dictionary
+        if isinstance(self.namespace, TestMacroNamespace):
+            dct.update(self.namespace.local_namespace)
+        else:
+            dct.update(self.namespace)
        return dct


class QueryHeaderContext(ManifestContext):
-    def __init__(
-        self, config: AdapterRequiredConfig, manifest: Manifest
-    ) -> None:
+    def __init__(self, config: AdapterRequiredConfig, manifest: Manifest) -> None:
        super().__init__(config, manifest, config.project_name)


-def generate_query_header_context(
-    config: AdapterRequiredConfig, manifest: Manifest
-):
+def generate_query_header_context(config: AdapterRequiredConfig, manifest: Manifest):
    ctx = QueryHeaderContext(config, manifest)
    return ctx.to_dict()

@@ -1,7 +1,15 @@
 import abc
 import os
 from typing import (
-    Callable, Any, Dict, Optional, Union, List, TypeVar, Type, Iterable,
+    Callable,
+    Any,
+    Dict,
+    Optional,
+    Union,
+    List,
+    TypeVar,
+    Type,
+    Iterable,
     Mapping,
 )
 from typing_extensions import Protocol
@@ -10,15 +18,16 @@ from dbt import deprecations
 from dbt.adapters.base.column import Column
 from dbt.adapters.factory import get_adapter, get_adapter_package_names
 from dbt.clients import agate_helper
-from dbt.clients.jinja import get_rendered, MacroGenerator
+from dbt.clients.jinja import get_rendered, MacroGenerator, MacroStack
 from dbt.config import RuntimeConfig, Project
-from .base import contextmember, contextproperty, Var
+from dbt.context import contextmember, contextproperty, Var
 from .configured import FQNLookup
 from .context_config import ContextConfig
+from dbt.context.macro_resolver import MacroResolver, TestMacroNamespace
 from .macros import MacroNamespaceBuilder, MacroNamespace
 from .manifest import ManifestContext
-from dbt.contracts.graph.manifest import Manifest, Disabled
 from dbt.contracts.connection import AdapterResponse
+from dbt.contracts.graph.manifest import Manifest, AnyManifest, Disabled, MacroManifest
 from dbt.contracts.graph.compiled import (
     CompiledResource,
     CompiledSeedNode,
@@ -47,9 +56,7 @@ from dbt.config import IsFQNResource
 from dbt.logger import GLOBAL_LOGGER as logger  # noqa
 from dbt.node_types import NodeType

-from dbt.utils import (
-    merge, AttrDict, MultiDict
-)
+from dbt.utils import merge, AttrDict, MultiDict

 import agate

@@ -72,9 +79,8 @@ class RelationProxy:
         return self._relation_type.create_from_source(*args, **kwargs)

     def create(self, *args, **kwargs):
-        kwargs['quote_policy'] = merge(
-            self._quoting_config,
-            kwargs.pop('quote_policy', {})
+        kwargs["quote_policy"] = merge(
+            self._quoting_config, kwargs.pop("quote_policy", {})
         )
         return self._relation_type.create(*args, **kwargs)

@@ -91,7 +97,7 @@ class BaseDatabaseWrapper:
         self._namespace = namespace

     def __getattr__(self, name):
-        raise NotImplementedError('subclasses need to implement this')
+        raise NotImplementedError("subclasses need to implement this")

     @property
     def config(self):
@@ -107,7 +113,7 @@ class BaseDatabaseWrapper:
         # a future version of this could have plugins automatically call fall
         # back to their dependencies' dependencies by using
         # `get_adapter_type_names` instead of `[self.config.credentials.type]`
-        search_prefixes = [self._adapter.type(), 'default']
+        search_prefixes = [self._adapter.type(), "default"]
         return search_prefixes

     def dispatch(
@@ -115,8 +121,8 @@ class BaseDatabaseWrapper:
     ) -> MacroGenerator:
         search_packages: List[Optional[str]]

-        if '.' in macro_name:
-            suggest_package, suggest_macro_name = macro_name.split('.', 1)
+        if "." in macro_name:
+            suggest_package, suggest_macro_name = macro_name.split(".", 1)
             msg = (
                 f'In adapter.dispatch, got a macro name of "{macro_name}", '
                 f'but "." is not a valid macro name component. Did you mean '
@@ -129,7 +135,7 @@ class BaseDatabaseWrapper:
             search_packages = [None]
         elif isinstance(packages, str):
             raise CompilationException(
-                f'In adapter.dispatch, got a string packages argument '
+                f"In adapter.dispatch, got a string packages argument "
                 f'("{packages}"), but packages should be None or a list.'
             )
         else:
@@ -139,25 +145,24 @@ class BaseDatabaseWrapper:

         for package_name in search_packages:
             for prefix in self._get_adapter_macro_prefixes():
-                search_name = f'{prefix}__{macro_name}'
+                search_name = f"{prefix}__{macro_name}"
                 try:
-                    macro = self._namespace.get_from_package(
-                        package_name, search_name
-                    )
+                    # this uses the namespace from the context
+                    macro = self._namespace.get_from_package(package_name, search_name)
                 except CompilationException as exc:
                     raise CompilationException(
-                        f'In dispatch: {exc.msg}',
+                        f"In dispatch: {exc.msg}",
                     ) from exc

                 if package_name is None:
                     attempts.append(search_name)
                 else:
-                    attempts.append(f'{package_name}.{search_name}')
+                    attempts.append(f"{package_name}.{search_name}")

                 if macro is not None:
                     return macro

-        searched = ', '.join(repr(a) for a in attempts)
+        searched = ", ".join(repr(a) for a in attempts)
         msg = (
             f"In dispatch: No macro named '{macro_name}' found\n"
             f"    Searched for: {searched}"
@@ -187,14 +192,10 @@ class BaseResolver(metaclass=abc.ABCMeta):

 class BaseRefResolver(BaseResolver):
     @abc.abstractmethod
-    def resolve(
-        self, name: str, package: Optional[str] = None
-    ) -> RelationProxy:
+    def resolve(self, name: str, package: Optional[str] = None) -> RelationProxy:
         ...

-    def _repack_args(
-        self, name: str, package: Optional[str]
-    ) -> List[str]:
+    def _repack_args(self, name: str, package: Optional[str]) -> List[str]:
         if package is None:
             return [name]
         else:
@@ -203,14 +204,13 @@ class BaseRefResolver(BaseResolver):
     def validate_args(self, name: str, package: Optional[str]):
         if not isinstance(name, str):
             raise CompilationException(
-                f'The name argument to ref() must be a string, got '
-                f'{type(name)}'
+                f"The name argument to ref() must be a string, got " f"{type(name)}"
             )

         if package is not None and not isinstance(package, str):
             raise CompilationException(
-                f'The package argument to ref() must be a string or None, got '
-                f'{type(package)}'
+                f"The package argument to ref() must be a string or None, got "
+                f"{type(package)}"
             )

     def __call__(self, *args: str) -> RelationProxy:
@@ -235,20 +235,19 @@ class BaseSourceResolver(BaseResolver):
     def validate_args(self, source_name: str, table_name: str):
         if not isinstance(source_name, str):
             raise CompilationException(
-                f'The source name (first) argument to source() must be a '
-                f'string, got {type(source_name)}'
+                f"The source name (first) argument to source() must be a "
+                f"string, got {type(source_name)}"
             )
         if not isinstance(table_name, str):
             raise CompilationException(
-                f'The table name (second) argument to source() must be a '
-                f'string, got {type(table_name)}'
+                f"The table name (second) argument to source() must be a "
+                f"string, got {type(table_name)}"
             )

     def __call__(self, *args: str) -> RelationProxy:
         if len(args) != 2:
             raise_compiler_error(
-                f"source() takes exactly two arguments ({len(args)} given)",
-                self.model
+                f"source() takes exactly two arguments ({len(args)} given)", self.model
             )
         self.validate_args(args[0], args[1])
         return self.resolve(args[0], args[1])
@@ -266,14 +265,15 @@ class ParseConfigObject(Config):
         self.context_config = context_config

     def _transform_config(self, config):
-        for oldkey in ('pre_hook', 'post_hook'):
+        for oldkey in ("pre_hook", "post_hook"):
             if oldkey in config:
-                newkey = oldkey.replace('_', '-')
+                newkey = oldkey.replace("_", "-")
                 if newkey in config:
                     raise_compiler_error(
-                        'Invalid config, has conflicting keys "{}" and "{}"'
-                        .format(oldkey, newkey),
-                        self.model
+                        'Invalid config, has conflicting keys "{}" and "{}"'.format(
+                            oldkey, newkey
+                        ),
+                        self.model,
                     )
                 config[newkey] = config.pop(oldkey)
         return config
@@ -284,29 +284,25 @@ class ParseConfigObject(Config):
         elif len(args) == 0 and len(kwargs) > 0:
             opts = kwargs
         else:
-            raise_compiler_error(
-                "Invalid inline model config",
-                self.model)
+            raise_compiler_error("Invalid inline model config", self.model)

         opts = self._transform_config(opts)

         # it's ok to have a parse context with no context config, but you must
         # not call it!
         if self.context_config is None:
-            raise RuntimeException(
-                'At parse time, did not receive a context config'
-            )
+            raise RuntimeException("At parse time, did not receive a context config")
         self.context_config.update_in_model_config(opts)
-        return ''
+        return ""

     def set(self, name, value):
         return self.__call__({name: value})

     def require(self, name, validator=None):
-        return ''
+        return ""

     def get(self, name, validator=None, default=None):
-        return ''
+        return ""

     def persist_relation_docs(self) -> bool:
         return False
@@ -316,14 +312,12 @@ class ParseConfigObject(Config):


 class RuntimeConfigObject(Config):
-    def __init__(
-        self, model, context_config: Optional[ContextConfig] = None
-    ):
+    def __init__(self, model, context_config: Optional[ContextConfig] = None):
         self.model = model
         # we never use or get a config, only the parser cares

     def __call__(self, *args, **kwargs):
-        return ''
+        return ""

     def set(self, name, value):
         return self.__call__({name: value})
@@ -333,7 +327,7 @@ class RuntimeConfigObject(Config):

     def _lookup(self, name, default=_MISSING):
         # if this is a macro, there might be no `model.config`.
-        if not hasattr(self.model, 'config'):
+        if not hasattr(self.model, "config"):
             result = default
         else:
             result = self.model.config.get(name, default)
@@ -358,22 +352,24 @@ class RuntimeConfigObject(Config):
         return to_return

     def persist_relation_docs(self) -> bool:
-        persist_docs = self.get('persist_docs', default={})
+        persist_docs = self.get("persist_docs", default={})
         if not isinstance(persist_docs, dict):
             raise_compiler_error(
                 f"Invalid value provided for 'persist_docs'. Expected dict "
-                f"but received {type(persist_docs)}")
+                f"but received {type(persist_docs)}"
+            )

-        return persist_docs.get('relation', False)
+        return persist_docs.get("relation", False)

     def persist_column_docs(self) -> bool:
-        persist_docs = self.get('persist_docs', default={})
+        persist_docs = self.get("persist_docs", default={})
         if not isinstance(persist_docs, dict):
             raise_compiler_error(
                 f"Invalid value provided for 'persist_docs'. Expected dict "
-                f"but received {type(persist_docs)}")
+                f"but received {type(persist_docs)}"
+            )

-        return persist_docs.get('columns', False)
+        return persist_docs.get("columns", False)


 # `adapter` implementations
@@ -383,8 +379,10 @@ class ParseDatabaseWrapper(BaseDatabaseWrapper):
     """

     def __getattr__(self, name):
-        override = (name in self._adapter._available_ and
-                    name in self._adapter._parse_replacements_)
+        override = (
+            name in self._adapter._available_
+            and name in self._adapter._parse_replacements_
+        )

         if override:
             return self._adapter._parse_replacements_[name]
@@ -416,9 +414,7 @@ class RuntimeDatabaseWrapper(BaseDatabaseWrapper):

 # `ref` implementations
 class ParseRefResolver(BaseRefResolver):
-    def resolve(
-        self, name: str, package: Optional[str] = None
-    ) -> RelationProxy:
+    def resolve(self, name: str, package: Optional[str] = None) -> RelationProxy:
         self.model.refs.append(self._repack_args(name, package))

         return self.Relation.create_from(self.config, self.model)
@@ -448,22 +444,15 @@ class RuntimeRefResolver(BaseRefResolver):
         self.validate(target_model, target_name, target_package)
         return self.create_relation(target_model, target_name)

-    def create_relation(
-        self, target_model: ManifestNode, name: str
-    ) -> RelationProxy:
+    def create_relation(self, target_model: ManifestNode, name: str) -> RelationProxy:
         if target_model.is_ephemeral_model:
             self.model.set_cte(target_model.unique_id, None)
-            return self.Relation.create_ephemeral_from_node(
-                self.config, target_model
-            )
+            return self.Relation.create_ephemeral_from_node(self.config, target_model)
         else:
             return self.Relation.create_from(self.config, target_model)

     def validate(
-        self,
-        resolved: ManifestNode,
-        target_name: str,
-        target_package: Optional[str]
+        self, resolved: ManifestNode, target_name: str, target_package: Optional[str]
     ) -> None:
         if resolved.unique_id not in self.model.depends_on.nodes:
             args = self._repack_args(target_name, target_package)
@@ -479,16 +468,15 @@ class OperationRefResolver(RuntimeRefResolver):
     ) -> None:
         pass

-    def create_relation(
-        self, target_model: ManifestNode, name: str
-    ) -> RelationProxy:
+    def create_relation(self, target_model: ManifestNode, name: str) -> RelationProxy:
         if target_model.is_ephemeral_model:
             # In operations, we can't ref() ephemeral nodes, because
             # ParsedMacros do not support set_cte
             raise_compiler_error(
-                'Operations can not ref() ephemeral nodes, but {} is ephemeral'
-                .format(target_model.name),
-                self.model
+                "Operations can not ref() ephemeral nodes, but {} is ephemeral".format(
+                    target_model.name
+                ),
+                self.model,
             )
         else:
             return super().create_relation(target_model, name)
@@ -540,8 +528,7 @@ class ModelConfiguredVar(Var):
         if package_name not in dependencies:
             # I don't think this is actually reachable
             raise_compiler_error(
-                f'Node package named {package_name} not found!',
-                self._node
+                f"Node package named {package_name} not found!", self._node
             )
         yield dependencies[package_name]
         yield self._config
@@ -613,7 +600,7 @@ class OperationProvider(RuntimeProvider):
     ref = OperationRefResolver


-T = TypeVar('T')
+T = TypeVar("T")


 # Base context collection, used for parsing configs.
@@ -627,9 +614,7 @@ class ProviderContext(ManifestContext):
         context_config: Optional[ContextConfig],
     ) -> None:
         if provider is None:
-            raise InternalException(
-                f"Invalid provider given to context: {provider}"
-            )
+            raise InternalException(f"Invalid provider given to context: {provider}")
         # mypy appeasement - we know it'll be a RuntimeConfig
         self.config: RuntimeConfig
         self.model: Union[ParsedMacro, ManifestNode] = model
@@ -638,14 +623,13 @@ class ProviderContext(ManifestContext):
         self.context_config: Optional[ContextConfig] = context_config
         self.provider: Provider = provider
         self.adapter = get_adapter(self.config)
-        self.db_wrapper = self.provider.DatabaseWrapper(
-            self.adapter, self.namespace
-        )
+        # The macro namespace is used in creating the DatabaseWrapper
+        self.db_wrapper = self.provider.DatabaseWrapper(self.adapter, self.namespace)

+    # This overrides the method in ManifestContext, and provides
+    # a model, which the ManifestContext builder does not
     def _get_namespace_builder(self):
-        internal_packages = get_adapter_package_names(
-            self.config.credentials.type
-        )
+        internal_packages = get_adapter_package_names(self.config.credentials.type)
         return MacroNamespaceBuilder(
             self.config.project_name,
             self.search_package,
@@ -664,19 +648,19 @@ class ProviderContext(ManifestContext):

     @contextmember
     def store_result(
-        self, name: str,
-        response: Any,
-        agate_table: Optional[agate.Table] = None
+        self, name: str, response: Any, agate_table: Optional[agate.Table] = None
     ) -> str:
         if agate_table is None:
             agate_table = agate_helper.empty_table()

-        self.sql_results[name] = AttrDict({
-            'response': response,
-            'data': agate_helper.as_matrix(agate_table),
-            'table': agate_table
-        })
-        return ''
+        self.sql_results[name] = AttrDict(
+            {
+                "response": response,
+                "data": agate_helper.as_matrix(agate_table),
+                "table": agate_table,
+            }
+        )
+        return ""

     @contextmember
     def store_raw_result(
@@ -685,10 +669,11 @@ class ProviderContext(ManifestContext):
         message=Optional[str],
         code=Optional[str],
         rows_affected=Optional[str],
-        agate_table: Optional[agate.Table] = None
+        agate_table: Optional[agate.Table] = None,
     ) -> str:
         response = AdapterResponse(
-            _message=message, code=code, rows_affected=rows_affected)
+            _message=message, code=code, rows_affected=rows_affected
+        )
         return self.store_result(name, response, agate_table)

     @contextproperty
@@ -701,25 +686,28 @@ class ProviderContext(ManifestContext):
                 elif value == arg:
                     return
                 raise ValidationException(
-                    'Expected value "{}" to be one of {}'
-                    .format(value, ','.join(map(str, args))))
+                    'Expected value "{}" to be one of {}'.format(
+                        value, ",".join(map(str, args))
+                    )
+                )

             return inner

-        return AttrDict({
-            'any': validate_any,
-        })
+        return AttrDict(
+            {
+                "any": validate_any,
+            }
+        )

     @contextmember
     def write(self, payload: str) -> str:
         # macros/source defs aren't 'writeable'.
         if isinstance(self.model, (ParsedMacro, ParsedSourceDefinition)):
-            raise_compiler_error(
-                'cannot "write" macros or sources'
-            )
+            raise_compiler_error('cannot "write" macros or sources')
         self.model.build_path = self.model.write_node(
-            self.config.target_path, 'run', payload
+            self.config.target_path, "run", payload
         )
-        return ''
+        return ""

     @contextmember
     def render(self, string: str) -> str:
@@ -732,20 +720,17 @@ class ProviderContext(ManifestContext):
             try:
                 return func(*args, **kwargs)
             except Exception:
-                raise_compiler_error(
-                    message_if_exception, self.model
-                )
+                raise_compiler_error(message_if_exception, self.model)

     @contextmember
     def load_agate_table(self) -> agate.Table:
         if not isinstance(self.model, (ParsedSeedNode, CompiledSeedNode)):
             raise_compiler_error(
-                'can only load_agate_table for seeds (got a {})'
-                .format(self.model.resource_type)
+                "can only load_agate_table for seeds (got a {})".format(
+                    self.model.resource_type
+                )
             )
-        path = os.path.join(
-            self.model.root_path, self.model.original_file_path
-        )
+        path = os.path.join(self.model.root_path, self.model.original_file_path)
         column_types = self.model.config.column_types
         try:
             table = agate_helper.from_csv(path, text_columns=column_types)
@@ -803,7 +788,7 @@ class ProviderContext(ManifestContext):
             self.db_wrapper, self.model, self.config, self.manifest
         )

-    @contextproperty('config')
+    @contextproperty("config")
     def ctx_config(self) -> Config:
         """The `config` variable exists to handle end-user configuration for
         custom materializations. Configs like `unique_key` can be implemented
@@ -975,7 +960,7 @@ class ProviderContext(ManifestContext):
             node=self.model,
         )

-    @contextproperty('adapter')
+    @contextproperty("adapter")
     def ctx_adapter(self) -> BaseDatabaseWrapper:
         """`adapter` is a wrapper around the internal database adapter used by
         dbt. It allows users to make calls to the database in their dbt models.
@@ -987,8 +972,8 @@ class ProviderContext(ManifestContext):
     @contextproperty
     def api(self) -> Dict[str, Any]:
         return {
-            'Relation': self.db_wrapper.Relation,
-            'Column': self.adapter.Column,
+            "Relation": self.db_wrapper.Relation,
+            "Column": self.adapter.Column,
         }

     @contextproperty
@@ -1106,9 +1091,9 @@ class ProviderContext(ManifestContext):
         """  # noqa
         return self.manifest.flat_graph

-    @contextproperty('model')
+    @contextproperty("model")
     def ctx_model(self) -> Dict[str, Any]:
-        return self.model.to_dict()
+        return self.model.to_dict(omit_none=True)

     @contextproperty
     def pre_hooks(self) -> Optional[List[Dict[str, Any]]]:
@@ -1170,22 +1155,20 @@ class ProviderContext(ManifestContext):
         ...
         {%- endmacro %}
         """
-        deprecations.warn('adapter-macro', macro_name=name)
+        deprecations.warn("adapter-macro", macro_name=name)
         original_name = name
         package_names: Optional[List[str]] = None
-        if '.' in name:
-            package_name, name = name.split('.', 1)
+        if "." in name:
+            package_name, name = name.split(".", 1)
             package_names = [package_name]

         try:
-            macro = self.db_wrapper.dispatch(
-                macro_name=name, packages=package_names
-            )
+            macro = self.db_wrapper.dispatch(macro_name=name, packages=package_names)
         except CompilationException as exc:
             raise CompilationException(
-                f'In adapter_macro: {exc.msg}\n'
+                f"In adapter_macro: {exc.msg}\n"
                 f"  Original name: '{original_name}'",
-                node=self.model
+                node=self.model,
             ) from exc
         return macro(*args, **kwargs)

@@ -1193,17 +1176,17 @@ class ProviderContext(ManifestContext):
 class MacroContext(ProviderContext):
     """Internally, macros can be executed like nodes, with some restrictions:

     - they don't have have all values available that nodes do:
        - 'this', 'pre_hooks', 'post_hooks', and 'sql' are missing
        - 'schema' does not use any 'model' information
     - they can't be configured with config() directives
     """

     def __init__(
         self,
         model: ParsedMacro,
         config: RuntimeConfig,
-        manifest: Manifest,
+        manifest: AnyManifest,
         provider: Provider,
         search_package: Optional[str],
     ) -> None:
@@ -1223,35 +1206,27 @@ class ModelContext(ProviderContext):
     def pre_hooks(self) -> List[Dict[str, Any]]:
         if isinstance(self.model, ParsedSourceDefinition):
             return []
-        return [
-            h.to_dict() for h in self.model.config.pre_hook
-        ]
+        return [h.to_dict(omit_none=True) for h in self.model.config.pre_hook]

     @contextproperty
     def post_hooks(self) -> List[Dict[str, Any]]:
         if isinstance(self.model, ParsedSourceDefinition):
             return []
-        return [
-            h.to_dict() for h in self.model.config.post_hook
-        ]
+        return [h.to_dict(omit_none=True) for h in self.model.config.post_hook]

     @contextproperty
     def sql(self) -> Optional[str]:
-        if getattr(self.model, 'extra_ctes_injected', None):
+        if getattr(self.model, "extra_ctes_injected", None):
             return self.model.compiled_sql
         return None

     @contextproperty
     def database(self) -> str:
-        return getattr(
-            self.model, 'database', self.config.credentials.database
-        )
+        return getattr(self.model, "database", self.config.credentials.database)

     @contextproperty
     def schema(self) -> str:
-        return getattr(
-            self.model, 'schema', self.config.credentials.schema
-        )
+        return getattr(self.model, "schema", self.config.credentials.schema)

     @contextproperty
     def this(self) -> Optional[RelationProxy]:
@@ -1289,38 +1264,28 @@ class ModelContext(ProviderContext):
         return self.db_wrapper.Relation.create_from(self.config, self.model)


+# This is called by '_context_for', used in 'render_with_context'
 def generate_parser_model(
     model: ManifestNode,
     config: RuntimeConfig,
-    manifest: Manifest,
+    manifest: MacroManifest,
     context_config: ContextConfig,
 ) -> Dict[str, Any]:
-    ctx = ModelContext(
-        model, config, manifest, ParseProvider(), context_config
-    )
-    return ctx.to_dict()
-
-
-def generate_parser_macro(
-    macro: ParsedMacro,
-    config: RuntimeConfig,
-    manifest: Manifest,
-    package_name: Optional[str],
-) -> Dict[str, Any]:
-    ctx = MacroContext(
-        macro, config, manifest, ParseProvider(), package_name
-    )
+    # The __init__ method of ModelContext also initializes
+    # a ManifestContext object which creates a MacroNamespaceBuilder
+    # which adds every macro in the Manifest.
+    ctx = ModelContext(model, config, manifest, ParseProvider(), context_config)
+    # The 'to_dict' method in ManifestContext moves all of the macro names
+    # in the macro 'namespace' up to top level keys
     return ctx.to_dict()


 def generate_generate_component_name_macro(
     macro: ParsedMacro,
     config: RuntimeConfig,
-    manifest: Manifest,
+    manifest: MacroManifest,
 ) -> Dict[str, Any]:
-    ctx = MacroContext(
-        macro, config, manifest, GenerateNameProvider(), None
-    )
+    ctx = MacroContext(macro, config, manifest, GenerateNameProvider(), None)
     return ctx.to_dict()


@@ -1329,9 +1294,7 @@ def generate_runtime_model(
     config: RuntimeConfig,
     manifest: Manifest,
 ) -> Dict[str, Any]:
-    ctx = ModelContext(
-        model, config, manifest, RuntimeProvider(), None
-    )
+    ctx = ModelContext(model, config, manifest, RuntimeProvider(), None)
     return ctx.to_dict()


@@ -1341,9 +1304,7 @@ def generate_runtime_macro(
     manifest: Manifest,
     package_name: Optional[str],
 ) -> Dict[str, Any]:
-    ctx = MacroContext(
-        macro, config, manifest, OperationProvider(), package_name
-    )
+    ctx = MacroContext(macro, config, manifest, OperationProvider(), package_name)
     return ctx.to_dict()


@@ -1352,38 +1313,89 @@ class ExposureRefResolver(BaseResolver):
         if len(args) not in (1, 2):
             ref_invalid_args(self.model, args)
         self.model.refs.append(list(args))
-        return ''
+        return ""


 class ExposureSourceResolver(BaseResolver):
     def __call__(self, *args) -> str:
         if len(args) != 2:
             raise_compiler_error(
-                f"source() takes exactly two arguments ({len(args)} given)",
-                self.model
+                f"source() takes exactly two arguments ({len(args)} given)", self.model
             )
         self.model.sources.append(list(args))
-        return ''
+        return ""


 def generate_parse_exposure(
     exposure: ParsedExposure,
     config: RuntimeConfig,
-    manifest: Manifest,
+    manifest: MacroManifest,
     package_name: str,
 ) -> Dict[str, Any]:
     project = config.load_dependencies()[package_name]
     return {
-        'ref': ExposureRefResolver(
+        "ref": ExposureRefResolver(
             None,
             exposure,
             project,
             manifest,
         ),
-        'source': ExposureSourceResolver(
+        "source": ExposureSourceResolver(
             None,
             exposure,
             project,
             manifest,
-        )
+        ),
     }


+# This class is currently used by the schema parser in order
+# to limit the number of macros in the context by using
+# the TestMacroNamespace
+class TestContext(ProviderContext):
+    def __init__(
+        self,
+        model,
+        config: RuntimeConfig,
+        manifest: Manifest,
+        provider: Provider,
+        context_config: Optional[ContextConfig],
+        macro_resolver: MacroResolver,
+    ) -> None:
+        # this must be before super init so that macro_resolver exists for
+        # build_namespace
+        self.macro_resolver = macro_resolver
+        self.thread_ctx = MacroStack()
+        super().__init__(model, config, manifest, provider, context_config)
+        self._build_test_namespace
+
+    def _build_namespace(self):
+        return {}
+
+    # this overrides _build_namespace in ManifestContext which provides a
+    # complete namespace of all macros to only specify macros in the depends_on
+    # This only provides a namespace with macros in the test node
+    # 'depends_on.macros' by using the TestMacroNamespace
+    def _build_test_namespace(self):
+        depends_on_macros = []
+        if self.model.depends_on and self.model.depends_on.macros:
+            depends_on_macros = self.model.depends_on.macros
+        macro_namespace = TestMacroNamespace(
+            self.macro_resolver, self.ctx, self.node, self.thread_ctx, depends_on_macros
+        )
+        self._namespace = macro_namespace
+
+
+def generate_test_context(
+    model: ManifestNode,
+    config: RuntimeConfig,
+    manifest: Manifest,
+    context_config: ContextConfig,
+    macro_resolver: MacroResolver,
+) -> Dict[str, Any]:
+    ctx = TestContext(
+        model, config, manifest, ParseProvider(), context_config, macro_resolver
+    )
+    # The 'to_dict' method in ManifestContext moves all of the macro names
+    # in the macro 'namespace' up to top level keys
+    return ctx.to_dict()
@@ -2,9 +2,7 @@ from typing import Any, Dict

 from dbt.contracts.connection import HasCredentials

-from dbt.context.base import (
-    BaseContext, contextproperty
-)
+from dbt.context import BaseContext, contextproperty


 class TargetContext(BaseContext):
@@ -2,28 +2,39 @@ import abc
 import itertools
 from dataclasses import dataclass, field
 from typing import (
-    Any, ClassVar, Dict, Tuple, Iterable, Optional, NewType, List, Callable,
+    Any,
+    ClassVar,
+    Dict,
+    Tuple,
+    Iterable,
+    Optional,
+    List,
+    Callable,
 )
-from typing_extensions import Protocol

-from hologram import JsonSchemaMixin
-from hologram.helpers import (
-    StrEnum, register_pattern, ExtensibleJsonSchemaMixin
-)

-from dbt.contracts.util import Replaceable
 from dbt.exceptions import InternalException
 from dbt.utils import translate_aliases

 from dbt.logger import GLOBAL_LOGGER as logger
+from typing_extensions import Protocol
+from dbt.dataclass_schema import (
+    dbtClassMixin,
+    StrEnum,
+    ExtensibleDbtClassMixin,
+    ValidatedStringMixin,
+    register_pattern,
+)
+from dbt.contracts.util import Replaceable


-Identifier = NewType('Identifier', str)
-register_pattern(Identifier, r'^[A-Za-z_][A-Za-z0-9_]+$')
+class Identifier(ValidatedStringMixin):
+    ValidationRegex = r"^[A-Za-z_][A-Za-z0-9_]+$"


+# we need register_pattern for jsonschema validation
+register_pattern(Identifier, r"^[A-Za-z_][A-Za-z0-9_]+$")


 @dataclass
-class AdapterResponse(JsonSchemaMixin):
+class AdapterResponse(dbtClassMixin):
     _message: str
     code: Optional[str] = None
     rows_affected: Optional[int] = None
@@ -33,27 +44,26 @@ class AdapterResponse(JsonSchemaMixin):


 class ConnectionState(StrEnum):
-    INIT = 'init'
-    OPEN = 'open'
-    CLOSED = 'closed'
-    FAIL = 'fail'
+    INIT = "init"
+    OPEN = "open"
+    CLOSED = "closed"
+    FAIL = "fail"


 @dataclass(init=False)
-class Connection(ExtensibleJsonSchemaMixin, Replaceable):
+class Connection(ExtensibleDbtClassMixin, Replaceable):
     type: Identifier
-    name: Optional[str]
+    name: Optional[str] = None
     state: ConnectionState = ConnectionState.INIT
     transaction_open: bool = False
-    # prevent serialization
     _handle: Optional[Any] = None
-    _credentials: JsonSchemaMixin = field(init=False)
+    _credentials: Optional[Any] = None

     def __init__(
         self,
         type: Identifier,
         name: Optional[str],
-        credentials: JsonSchemaMixin,
+        credentials: dbtClassMixin,
         state: ConnectionState = ConnectionState.INIT,
         transaction_open: bool = False,
         handle: Optional[Any] = None,
@@ -81,8 +91,7 @@ class Connection(ExtensibleJsonSchemaMixin, Replaceable):
             self._handle.resolve(self)
         except RecursionError as exc:
             raise InternalException(
-                "A connection's open() method attempted to read the "
-                "handle value"
+                "A connection's open() method attempted to read the " "handle value"
             ) from exc
         return self._handle

@@ -101,8 +110,7 @@ class LazyHandle:

     def resolve(self, connection: Connection) -> Connection:
         logger.debug(
-            'Opening a new connection, currently in state {}'
-            .format(connection.state)
+            "Opening a new connection, currently in state {}".format(connection.state)
         )
         return self.opener(connection)

@@ -112,33 +120,24 @@ class LazyHandle:
 # for why we have type: ignore. Maybe someday dataclasses + abstract classes
 # will work.
 @dataclass  # type: ignore
-class Credentials(
-    ExtensibleJsonSchemaMixin,
-    Replaceable,
-    metaclass=abc.ABCMeta
-):
+class Credentials(ExtensibleDbtClassMixin, Replaceable, metaclass=abc.ABCMeta):
     database: str
     schema: str
     _ALIASES: ClassVar[Dict[str, str]] = field(default={}, init=False)

     @abc.abstractproperty
     def type(self) -> str:
-        raise NotImplementedError(
-            'type not implemented for base credentials class'
-        )
+        raise NotImplementedError("type not implemented for base credentials class")

     def connection_info(
         self, *, with_aliases: bool = False
     ) -> Iterable[Tuple[str, Any]]:
-        """Return an ordered iterator of key/value pairs for pretty-printing.
-        """
-        as_dict = self.to_dict(omit_none=False, with_aliases=with_aliases)
+        """Return an ordered iterator of key/value pairs for pretty-printing."""
+        as_dict = self.to_dict(omit_none=False)
         connection_keys = set(self._connection_keys())
         aliases: List[str] = []
         if with_aliases:
-            aliases = [
-                k for k, v in self._ALIASES.items() if v in connection_keys
-            ]
+            aliases = [k for k, v in self._ALIASES.items() if v in connection_keys]
         for key in itertools.chain(self._connection_keys(), aliases):
             if key in as_dict:
                 yield key, as_dict[key]
@@ -148,9 +147,10 @@ class Credentials(
         raise NotImplementedError

     @classmethod
-    def from_dict(cls, data):
+    def __pre_deserialize__(cls, data):
+        data = super().__pre_deserialize__(data)
         data = cls.translate_aliases(data)
-        return super().from_dict(data)
+        return data

     @classmethod
     def translate_aliases(
@@ -158,31 +158,28 @@ class Credentials(
     ) -> Dict[str, Any]:
         return translate_aliases(kwargs, cls._ALIASES, recurse)

-    def to_dict(self, omit_none=True, validate=False, *, with_aliases=False):
-        serialized = super().to_dict(omit_none=omit_none, validate=validate)
-        if with_aliases:
-            serialized.update({
-                new_name: serialized[canonical_name]
-                for new_name, canonical_name in self._ALIASES.items()
-                if canonical_name in serialized
-            })
-        return serialized
+    def __post_serialize__(self, dct):
+        # no super() -- do we need it?
+        if self._ALIASES:
+            dct.update(
+                {
+                    new_name: dct[canonical_name]
+                    for new_name, canonical_name in self._ALIASES.items()
+                    if canonical_name in dct
+                }
+            )
+        return dct


 class UserConfigContract(Protocol):
     send_anonymous_usage_stats: bool
-    use_colors: Optional[bool]
-    partial_parse: Optional[bool]
-    printer_width: Optional[int]
+    use_colors: Optional[bool] = None
+    partial_parse: Optional[bool] = None
+    printer_width: Optional[int] = None

     def set_values(self, cookie_dir: str) -> None:
         ...

-    def to_dict(
-        self, omit_none: bool = True, validate: bool = False
-    ) -> Dict[str, Any]:
-        ...


 class HasCredentials(Protocol):
     credentials: Credentials
@@ -192,10 +189,10 @@ class HasCredentials(Protocol):
     threads: int

     def to_target_dict(self):
-        raise NotImplementedError('to_target_dict not implemented')
+        raise NotImplementedError("to_target_dict not implemented")


-DEFAULT_QUERY_COMMENT = '''
+DEFAULT_QUERY_COMMENT = """
 {%- set comment_dict = {} -%}
 {%- do comment_dict.update(
     app='dbt',
@@ -212,11 +209,11 @@ DEFAULT_QUERY_COMMENT = '''
 {%- do comment_dict.update(connection_name=connection_name) -%}
 {%- endif -%}
 {{ return(tojson(comment_dict)) }}
-'''
+"""


 @dataclass
-class QueryComment(JsonSchemaMixin):
+class QueryComment(dbtClassMixin):
     comment: str = DEFAULT_QUERY_COMMENT
     append: bool = False

@@ -3,7 +3,7 @@ import os
 from dataclasses import dataclass, field
 from typing import List, Optional, Union

-from hologram import JsonSchemaMixin
+from dbt.dataclass_schema import dbtClassMixin

 from dbt.exceptions import InternalException

@@ -11,11 +11,11 @@ from .util import MacroKey, SourceKey


 MAXIMUM_SEED_SIZE = 1 * 1024 * 1024
-MAXIMUM_SEED_SIZE_NAME = '1MB'
+MAXIMUM_SEED_SIZE_NAME = "1MB"


 @dataclass
-class FilePath(JsonSchemaMixin):
+class FilePath(dbtClassMixin):
     searched_path: str
     relative_path: str
     project_root: str
@@ -28,9 +28,7 @@ class FilePath(JsonSchemaMixin):
     @property
     def full_path(self) -> str:
         # useful for symlink preservation
-        return os.path.join(
-            self.project_root, self.searched_path, self.relative_path
-        )
+        return os.path.join(self.project_root, self.searched_path, self.relative_path)

     @property
     def absolute_path(self) -> str:
@@ -40,78 +38,76 @@ class FilePath(JsonSchemaMixin):
     def original_file_path(self) -> str:
         # this is mostly used for reporting errors. It doesn't show the project
         # name, should it?
-        return os.path.join(
-            self.searched_path, self.relative_path
-        )
+        return os.path.join(self.searched_path, self.relative_path)

     def seed_too_large(self) -> bool:
-        """Return whether the file this represents is over the seed size limit
-        """
+        """Return whether the file this represents is over the seed size limit"""
         return os.stat(self.full_path).st_size > MAXIMUM_SEED_SIZE


 @dataclass
-class FileHash(JsonSchemaMixin):
+class FileHash(dbtClassMixin):
     name: str  # the hash type name
     checksum: str  # the hashlib.hash_type().hexdigest() of the file contents

     @classmethod
     def empty(cls):
-        return FileHash(name='none', checksum='')
+        return FileHash(name="none", checksum="")

     @classmethod
     def path(cls, path: str):
-        return FileHash(name='path', checksum=path)
+        return FileHash(name="path", checksum=path)

     def __eq__(self, other):
         if not isinstance(other, FileHash):
             return NotImplemented

-        if self.name == 'none' or self.name != other.name:
+        if self.name == "none" or self.name != other.name:
             return False

         return self.checksum == other.checksum

     def compare(self, contents: str) -> bool:
         """Compare the file contents with the given hash"""
-        if self.name == 'none':
+        if self.name == "none":
             return False

         return self.from_contents(contents, name=self.name) == self.checksum

     @classmethod
-    def from_contents(cls, contents: str, name='sha256') -> 'FileHash':
+    def from_contents(cls, contents: str, name="sha256") -> "FileHash":
         """Create a file hash from the given file contents. The hash is always
         the utf-8 encoding of the contents given, because dbt only reads files
         as utf-8.
         """
-        data = contents.encode('utf-8')
+        data = contents.encode("utf-8")
         checksum = hashlib.new(name, data).hexdigest()
         return cls(name=name, checksum=checksum)


 @dataclass
-class RemoteFile(JsonSchemaMixin):
+class RemoteFile(dbtClassMixin):
     @property
     def searched_path(self) -> str:
-        return 'from remote system'
+        return "from remote system"

     @property
     def relative_path(self) -> str:
-        return 'from remote system'
+        return "from remote system"

     @property
     def absolute_path(self) -> str:
-        return 'from remote system'
+        return "from remote system"

     @property
     def original_file_path(self):
-        return 'from remote system'
+        return "from remote system"


 @dataclass
-class SourceFile(JsonSchemaMixin):
+class SourceFile(dbtClassMixin):
     """Define a source file in dbt"""

     path: Union[FilePath, RemoteFile]  # the path information
     checksum: FileHash
     # we don't want to serialize this
@@ -133,14 +129,14 @@ class SourceFile(JsonSchemaMixin):
     def search_key(self) -> Optional[str]:
|
||||||
if isinstance(self.path, RemoteFile):
|
if isinstance(self.path, RemoteFile):
|
||||||
return None
|
return None
|
||||||
if self.checksum.name == 'none':
|
if self.checksum.name == "none":
|
||||||
return None
|
return None
|
||||||
return self.path.search_key
|
return self.path.search_key
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def contents(self) -> str:
|
def contents(self) -> str:
|
||||||
if self._contents is None:
|
if self._contents is None:
|
||||||
raise InternalException('SourceFile has no contents!')
|
raise InternalException("SourceFile has no contents!")
|
||||||
return self._contents
|
return self._contents
|
||||||
|
|
||||||
@contents.setter
|
@contents.setter
|
||||||
@@ -148,20 +144,20 @@ class SourceFile(JsonSchemaMixin):
|
|||||||
self._contents = value
|
self._contents = value
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def empty(cls, path: FilePath) -> 'SourceFile':
|
def empty(cls, path: FilePath) -> "SourceFile":
|
||||||
self = cls(path=path, checksum=FileHash.empty())
|
self = cls(path=path, checksum=FileHash.empty())
|
||||||
self.contents = ''
|
self.contents = ""
|
||||||
return self
|
return self
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def big_seed(cls, path: FilePath) -> 'SourceFile':
|
def big_seed(cls, path: FilePath) -> "SourceFile":
|
||||||
"""Parse seeds over the size limit with just the path"""
|
"""Parse seeds over the size limit with just the path"""
|
||||||
self = cls(path=path, checksum=FileHash.path(path.original_file_path))
|
self = cls(path=path, checksum=FileHash.path(path.original_file_path))
|
||||||
self.contents = ''
|
self.contents = ""
|
||||||
return self
|
return self
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def remote(cls, contents: str) -> 'SourceFile':
|
def remote(cls, contents: str) -> "SourceFile":
|
||||||
self = cls(path=RemoteFile(), checksum=FileHash.empty())
|
self = cls(path=RemoteFile(), checksum=FileHash.empty())
|
||||||
self.contents = contents
|
self.contents = contents
|
||||||
return self
|
return self
|
||||||
|
|||||||
@@ -19,19 +19,19 @@ from dbt.contracts.graph.parsed import (
|
|||||||
from dbt.node_types import NodeType
|
from dbt.node_types import NodeType
|
||||||
from dbt.contracts.util import Replaceable
|
from dbt.contracts.util import Replaceable
|
||||||
|
|
||||||
from hologram import JsonSchemaMixin
|
from dbt.dataclass_schema import dbtClassMixin
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Optional, List, Union, Dict, Type
|
from typing import Optional, List, Union, Dict, Type
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class InjectedCTE(JsonSchemaMixin, Replaceable):
|
class InjectedCTE(dbtClassMixin, Replaceable):
|
||||||
id: str
|
id: str
|
||||||
sql: str
|
sql: str
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class CompiledNodeMixin(JsonSchemaMixin):
|
class CompiledNodeMixin(dbtClassMixin):
|
||||||
# this is a special mixin class to provide a required argument. If a node
|
# this is a special mixin class to provide a required argument. If a node
|
||||||
# is missing a `compiled` flag entirely, it must not be a CompiledNode.
|
# is missing a `compiled` flag entirely, it must not be a CompiledNode.
|
||||||
compiled: bool
|
compiled: bool
|
||||||
@@ -58,31 +58,29 @@ class CompiledNode(ParsedNode, CompiledNodeMixin):
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class CompiledAnalysisNode(CompiledNode):
|
class CompiledAnalysisNode(CompiledNode):
|
||||||
resource_type: NodeType = field(metadata={'restrict': [NodeType.Analysis]})
|
resource_type: NodeType = field(metadata={"restrict": [NodeType.Analysis]})
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class CompiledHookNode(CompiledNode):
|
class CompiledHookNode(CompiledNode):
|
||||||
resource_type: NodeType = field(
|
resource_type: NodeType = field(metadata={"restrict": [NodeType.Operation]})
|
||||||
metadata={'restrict': [NodeType.Operation]}
|
|
||||||
)
|
|
||||||
index: Optional[int] = None
|
index: Optional[int] = None
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class CompiledModelNode(CompiledNode):
|
class CompiledModelNode(CompiledNode):
|
||||||
resource_type: NodeType = field(metadata={'restrict': [NodeType.Model]})
|
resource_type: NodeType = field(metadata={"restrict": [NodeType.Model]})
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class CompiledRPCNode(CompiledNode):
|
class CompiledRPCNode(CompiledNode):
|
||||||
resource_type: NodeType = field(metadata={'restrict': [NodeType.RPCCall]})
|
resource_type: NodeType = field(metadata={"restrict": [NodeType.RPCCall]})
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class CompiledSeedNode(CompiledNode):
|
class CompiledSeedNode(CompiledNode):
|
||||||
# keep this in sync with ParsedSeedNode!
|
# keep this in sync with ParsedSeedNode!
|
||||||
resource_type: NodeType = field(metadata={'restrict': [NodeType.Seed]})
|
resource_type: NodeType = field(metadata={"restrict": [NodeType.Seed]})
|
||||||
config: SeedConfig = field(default_factory=SeedConfig)
|
config: SeedConfig = field(default_factory=SeedConfig)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@@ -96,26 +94,25 @@ class CompiledSeedNode(CompiledNode):
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class CompiledSnapshotNode(CompiledNode):
|
class CompiledSnapshotNode(CompiledNode):
|
||||||
resource_type: NodeType = field(metadata={'restrict': [NodeType.Snapshot]})
|
resource_type: NodeType = field(metadata={"restrict": [NodeType.Snapshot]})
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class CompiledDataTestNode(CompiledNode):
|
class CompiledDataTestNode(CompiledNode):
|
||||||
resource_type: NodeType = field(metadata={'restrict': [NodeType.Test]})
|
resource_type: NodeType = field(metadata={"restrict": [NodeType.Test]})
|
||||||
config: TestConfig = field(default_factory=TestConfig)
|
config: TestConfig = field(default_factory=TestConfig)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class CompiledSchemaTestNode(CompiledNode, HasTestMetadata):
|
class CompiledSchemaTestNode(CompiledNode, HasTestMetadata):
|
||||||
# keep this in sync with ParsedSchemaTestNode!
|
# keep this in sync with ParsedSchemaTestNode!
|
||||||
resource_type: NodeType = field(metadata={'restrict': [NodeType.Test]})
|
resource_type: NodeType = field(metadata={"restrict": [NodeType.Test]})
|
||||||
column_name: Optional[str] = None
|
column_name: Optional[str] = None
|
||||||
config: TestConfig = field(default_factory=TestConfig)
|
config: TestConfig = field(default_factory=TestConfig)
|
||||||
|
|
||||||
def same_config(self, other) -> bool:
|
def same_config(self, other) -> bool:
|
||||||
return (
|
return self.unrendered_config.get("severity") == other.unrendered_config.get(
|
||||||
self.unrendered_config.get('severity') ==
|
"severity"
|
||||||
other.unrendered_config.get('severity')
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def same_column_name(self, other) -> bool:
|
def same_column_name(self, other) -> bool:
|
||||||
@@ -125,11 +122,7 @@ class CompiledSchemaTestNode(CompiledNode, HasTestMetadata):
|
|||||||
if other is None:
|
if other is None:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
return (
|
return self.same_config(other) and self.same_fqn(other) and True
|
||||||
self.same_config(other) and
|
|
||||||
self.same_fqn(other) and
|
|
||||||
True
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
CompiledTestNode = Union[CompiledDataTestNode, CompiledSchemaTestNode]
|
CompiledTestNode = Union[CompiledDataTestNode, CompiledSchemaTestNode]
|
||||||
@@ -175,11 +168,9 @@ def parsed_instance_for(compiled: CompiledNode) -> ParsedResource:
|
|||||||
cls = PARSED_TYPES.get(type(compiled))
|
cls = PARSED_TYPES.get(type(compiled))
|
||||||
if cls is None:
|
if cls is None:
|
||||||
# how???
|
# how???
|
||||||
raise ValueError('invalid resource_type: {}'
|
raise ValueError("invalid resource_type: {}".format(compiled.resource_type))
|
||||||
.format(compiled.resource_type))
|
|
||||||
|
|
||||||
# validate=False to allow extra keys from compiling
|
return cls.from_dict(compiled.to_dict(omit_none=True))
|
||||||
return cls.from_dict(compiled.to_dict(), validate=False)
|
|
||||||
|
|
||||||
|
|
||||||
NonSourceCompiledNode = Union[
|
NonSourceCompiledNode = Union[
|
||||||
|
|||||||
@@ -4,25 +4,51 @@ from dataclasses import dataclass, field
|
|||||||
from itertools import chain, islice
|
from itertools import chain, islice
|
||||||
from multiprocessing.synchronize import Lock
|
from multiprocessing.synchronize import Lock
|
||||||
from typing import (
|
from typing import (
|
||||||
Dict, List, Optional, Union, Mapping, MutableMapping, Any, Set, Tuple,
|
Dict,
|
||||||
TypeVar, Callable, Iterable, Generic, cast, AbstractSet
|
List,
|
||||||
|
Optional,
|
||||||
|
Union,
|
||||||
|
Mapping,
|
||||||
|
MutableMapping,
|
||||||
|
Any,
|
||||||
|
Set,
|
||||||
|
Tuple,
|
||||||
|
TypeVar,
|
||||||
|
Callable,
|
||||||
|
Iterable,
|
||||||
|
Generic,
|
||||||
|
cast,
|
||||||
|
AbstractSet,
|
||||||
)
|
)
|
||||||
from typing_extensions import Protocol
|
from typing_extensions import Protocol
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
from dbt.contracts.graph.compiled import (
|
from dbt.contracts.graph.compiled import (
|
||||||
CompileResultNode, ManifestNode, NonSourceCompiledNode, GraphMemberNode
|
CompileResultNode,
|
||||||
|
ManifestNode,
|
||||||
|
NonSourceCompiledNode,
|
||||||
|
GraphMemberNode,
|
||||||
)
|
)
|
||||||
from dbt.contracts.graph.parsed import (
|
from dbt.contracts.graph.parsed import (
|
||||||
ParsedMacro, ParsedDocumentation, ParsedNodePatch, ParsedMacroPatch,
|
ParsedMacro,
|
||||||
ParsedSourceDefinition, ParsedExposure
|
ParsedDocumentation,
|
||||||
|
ParsedNodePatch,
|
||||||
|
ParsedMacroPatch,
|
||||||
|
ParsedSourceDefinition,
|
||||||
|
ParsedExposure,
|
||||||
)
|
)
|
||||||
from dbt.contracts.files import SourceFile
|
from dbt.contracts.files import SourceFile
|
||||||
from dbt.contracts.util import (
|
from dbt.contracts.util import (
|
||||||
BaseArtifactMetadata, MacroKey, SourceKey, ArtifactMixin, schema_version
|
BaseArtifactMetadata,
|
||||||
|
MacroKey,
|
||||||
|
SourceKey,
|
||||||
|
ArtifactMixin,
|
||||||
|
schema_version,
|
||||||
)
|
)
|
||||||
from dbt.exceptions import (
|
from dbt.exceptions import (
|
||||||
raise_duplicate_resource_name, raise_compiler_error, warn_or_error,
|
raise_duplicate_resource_name,
|
||||||
|
raise_compiler_error,
|
||||||
|
warn_or_error,
|
||||||
raise_invalid_patch,
|
raise_invalid_patch,
|
||||||
)
|
)
|
||||||
from dbt.helper_types import PathSet
|
from dbt.helper_types import PathSet
|
||||||
@@ -40,12 +66,12 @@ RefName = str
|
|||||||
UniqueID = str
|
UniqueID = str
|
||||||
|
|
||||||
|
|
||||||
K_T = TypeVar('K_T')
|
K_T = TypeVar("K_T")
|
||||||
V_T = TypeVar('V_T')
|
V_T = TypeVar("V_T")
|
||||||
|
|
||||||
|
|
||||||
class PackageAwareCache(Generic[K_T, V_T]):
|
class PackageAwareCache(Generic[K_T, V_T]):
|
||||||
def __init__(self, manifest: 'Manifest'):
|
def __init__(self, manifest: "Manifest"):
|
||||||
self.storage: Dict[K_T, Dict[PackageName, UniqueID]] = {}
|
self.storage: Dict[K_T, Dict[PackageName, UniqueID]] = {}
|
||||||
self._manifest = manifest
|
self._manifest = manifest
|
||||||
self.populate()
|
self.populate()
|
||||||
@@ -95,12 +121,10 @@ class DocCache(PackageAwareCache[DocName, ParsedDocumentation]):
|
|||||||
for doc in self._manifest.docs.values():
|
for doc in self._manifest.docs.values():
|
||||||
self.add_doc(doc)
|
self.add_doc(doc)
|
||||||
|
|
||||||
def perform_lookup(
|
def perform_lookup(self, unique_id: UniqueID) -> ParsedDocumentation:
|
||||||
self, unique_id: UniqueID
|
|
||||||
) -> ParsedDocumentation:
|
|
||||||
if unique_id not in self._manifest.docs:
|
if unique_id not in self._manifest.docs:
|
||||||
raise dbt.exceptions.InternalException(
|
raise dbt.exceptions.InternalException(
|
||||||
f'Doc {unique_id} found in cache but not found in manifest'
|
f"Doc {unique_id} found in cache but not found in manifest"
|
||||||
)
|
)
|
||||||
return self._manifest.docs[unique_id]
|
return self._manifest.docs[unique_id]
|
||||||
|
|
||||||
@@ -117,12 +141,10 @@ class SourceCache(PackageAwareCache[SourceKey, ParsedSourceDefinition]):
|
|||||||
for source in self._manifest.sources.values():
|
for source in self._manifest.sources.values():
|
||||||
self.add_source(source)
|
self.add_source(source)
|
||||||
|
|
||||||
def perform_lookup(
|
def perform_lookup(self, unique_id: UniqueID) -> ParsedSourceDefinition:
|
||||||
self, unique_id: UniqueID
|
|
||||||
) -> ParsedSourceDefinition:
|
|
||||||
if unique_id not in self._manifest.sources:
|
if unique_id not in self._manifest.sources:
|
||||||
raise dbt.exceptions.InternalException(
|
raise dbt.exceptions.InternalException(
|
||||||
f'Source {unique_id} found in cache but not found in manifest'
|
f"Source {unique_id} found in cache but not found in manifest"
|
||||||
)
|
)
|
||||||
return self._manifest.sources[unique_id]
|
return self._manifest.sources[unique_id]
|
||||||
|
|
||||||
@@ -131,7 +153,7 @@ class RefableCache(PackageAwareCache[RefName, ManifestNode]):
|
|||||||
# refables are actually unique, so the Dict[PackageName, UniqueID] will
|
# refables are actually unique, so the Dict[PackageName, UniqueID] will
|
||||||
# only ever have exactly one value, but doing 3 dict lookups instead of 1
|
# only ever have exactly one value, but doing 3 dict lookups instead of 1
|
||||||
# is not a big deal at all and retains consistency
|
# is not a big deal at all and retains consistency
|
||||||
def __init__(self, manifest: 'Manifest'):
|
def __init__(self, manifest: "Manifest"):
|
||||||
self._cached_types = set(NodeType.refable())
|
self._cached_types = set(NodeType.refable())
|
||||||
super().__init__(manifest)
|
super().__init__(manifest)
|
||||||
|
|
||||||
@@ -145,12 +167,10 @@ class RefableCache(PackageAwareCache[RefName, ManifestNode]):
|
|||||||
for node in self._manifest.nodes.values():
|
for node in self._manifest.nodes.values():
|
||||||
self.add_node(node)
|
self.add_node(node)
|
||||||
|
|
||||||
def perform_lookup(
|
def perform_lookup(self, unique_id: UniqueID) -> ManifestNode:
|
||||||
self, unique_id: UniqueID
|
|
||||||
) -> ManifestNode:
|
|
||||||
if unique_id not in self._manifest.nodes:
|
if unique_id not in self._manifest.nodes:
|
||||||
raise dbt.exceptions.InternalException(
|
raise dbt.exceptions.InternalException(
|
||||||
f'Node {unique_id} found in cache but not found in manifest'
|
f"Node {unique_id} found in cache but not found in manifest"
|
||||||
)
|
)
|
||||||
return self._manifest.nodes[unique_id]
|
return self._manifest.nodes[unique_id]
|
||||||
|
|
||||||
@@ -171,30 +191,31 @@ def _search_packages(
|
|||||||
@dataclass
|
@dataclass
|
||||||
class ManifestMetadata(BaseArtifactMetadata):
|
class ManifestMetadata(BaseArtifactMetadata):
|
||||||
"""Metadata for the manifest."""
|
"""Metadata for the manifest."""
|
||||||
|
|
||||||
dbt_schema_version: str = field(
|
dbt_schema_version: str = field(
|
||||||
default_factory=lambda: str(WritableManifest.dbt_schema_version)
|
default_factory=lambda: str(WritableManifest.dbt_schema_version)
|
||||||
)
|
)
|
||||||
project_id: Optional[str] = field(
|
project_id: Optional[str] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
'description': 'A unique identifier for the project',
|
"description": "A unique identifier for the project",
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
user_id: Optional[UUID] = field(
|
user_id: Optional[UUID] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
'description': 'A unique identifier for the user',
|
"description": "A unique identifier for the user",
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
send_anonymous_usage_stats: Optional[bool] = field(
|
send_anonymous_usage_stats: Optional[bool] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata=dict(description=(
|
metadata=dict(
|
||||||
'Whether dbt is configured to send anonymous usage statistics'
|
description=("Whether dbt is configured to send anonymous usage statistics")
|
||||||
)),
|
),
|
||||||
)
|
)
|
||||||
adapter_type: Optional[str] = field(
|
adapter_type: Optional[str] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata=dict(description='The type name of the adapter'),
|
metadata=dict(description="The type name of the adapter"),
|
||||||
)
|
)
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
@@ -205,9 +226,7 @@ class ManifestMetadata(BaseArtifactMetadata):
|
|||||||
self.user_id = tracking.active_user.id
|
self.user_id = tracking.active_user.id
|
||||||
|
|
||||||
if self.send_anonymous_usage_stats is None:
|
if self.send_anonymous_usage_stats is None:
|
||||||
self.send_anonymous_usage_stats = (
|
self.send_anonymous_usage_stats = not tracking.active_user.do_not_track
|
||||||
not tracking.active_user.do_not_track
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def default(cls):
|
def default(cls):
|
||||||
@@ -240,7 +259,7 @@ def build_edges(nodes: List[ManifestNode]):
|
|||||||
|
|
||||||
|
|
||||||
def _deepcopy(value):
|
def _deepcopy(value):
|
||||||
return value.from_dict(value.to_dict())
|
return value.from_dict(value.to_dict(omit_none=True))
|
||||||
|
|
||||||
|
|
||||||
class Locality(enum.IntEnum):
|
class Locality(enum.IntEnum):
|
||||||
@@ -281,7 +300,7 @@ class MaterializationCandidate(MacroCandidate):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def from_macro(
|
def from_macro(
|
||||||
cls, candidate: MacroCandidate, specificity: Specificity
|
cls, candidate: MacroCandidate, specificity: Specificity
|
||||||
) -> 'MaterializationCandidate':
|
) -> "MaterializationCandidate":
|
||||||
return cls(
|
return cls(
|
||||||
locality=candidate.locality,
|
locality=candidate.locality,
|
||||||
macro=candidate.macro,
|
macro=candidate.macro,
|
||||||
@@ -292,15 +311,14 @@ class MaterializationCandidate(MacroCandidate):
|
|||||||
if not isinstance(other, MaterializationCandidate):
|
if not isinstance(other, MaterializationCandidate):
|
||||||
return NotImplemented
|
return NotImplemented
|
||||||
equal = (
|
equal = (
|
||||||
self.specificity == other.specificity and
|
self.specificity == other.specificity and self.locality == other.locality
|
||||||
self.locality == other.locality
|
|
||||||
)
|
)
|
||||||
if equal:
|
if equal:
|
||||||
raise_compiler_error(
|
raise_compiler_error(
|
||||||
'Found two materializations with the name {} (packages {} and '
|
"Found two materializations with the name {} (packages {} and "
|
||||||
'{}). dbt cannot resolve this ambiguity'
|
"{}). dbt cannot resolve this ambiguity".format(
|
||||||
.format(self.macro.name, self.macro.package_name,
|
self.macro.name, self.macro.package_name, other.macro.package_name
|
||||||
other.macro.package_name)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
return equal
|
return equal
|
||||||
@@ -319,7 +337,7 @@ class MaterializationCandidate(MacroCandidate):
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
M = TypeVar('M', bound=MacroCandidate)
|
M = TypeVar("M", bound=MacroCandidate)
|
||||||
|
|
||||||
|
|
||||||
class CandidateList(List[M]):
|
class CandidateList(List[M]):
|
||||||
@@ -347,10 +365,10 @@ class Searchable(Protocol):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def search_name(self) -> str:
|
def search_name(self) -> str:
|
||||||
raise NotImplementedError('search_name not implemented')
|
raise NotImplementedError("search_name not implemented")
|
||||||
|
|
||||||
|
|
||||||
N = TypeVar('N', bound=Searchable)
|
N = TypeVar("N", bound=Searchable)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -382,7 +400,7 @@ class NameSearcher(Generic[N]):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
D = TypeVar('D')
|
D = TypeVar("D")
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -393,19 +411,18 @@ class Disabled(Generic[D]):
|
|||||||
MaybeDocumentation = Optional[ParsedDocumentation]
|
MaybeDocumentation = Optional[ParsedDocumentation]
|
||||||
|
|
||||||
|
|
||||||
MaybeParsedSource = Optional[Union[
|
MaybeParsedSource = Optional[
|
||||||
ParsedSourceDefinition,
|
Union[
|
||||||
Disabled[ParsedSourceDefinition],
|
ParsedSourceDefinition,
|
||||||
]]
|
Disabled[ParsedSourceDefinition],
|
||||||
|
]
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
MaybeNonSource = Optional[Union[
|
MaybeNonSource = Optional[Union[ManifestNode, Disabled[ManifestNode]]]
|
||||||
ManifestNode,
|
|
||||||
Disabled[ManifestNode]
|
|
||||||
]]
|
|
||||||
|
|
||||||
|
|
||||||
T = TypeVar('T', bound=GraphMemberNode)
|
T = TypeVar("T", bound=GraphMemberNode)
|
||||||
|
|
||||||
|
|
||||||
def _update_into(dest: MutableMapping[str, T], new_item: T):
|
def _update_into(dest: MutableMapping[str, T], new_item: T):
|
||||||
@@ -416,170 +433,24 @@ def _update_into(dest: MutableMapping[str, T], new_item: T):
|
|||||||
unique_id = new_item.unique_id
|
unique_id = new_item.unique_id
|
||||||
if unique_id not in dest:
|
if unique_id not in dest:
|
||||||
raise dbt.exceptions.RuntimeException(
|
raise dbt.exceptions.RuntimeException(
|
||||||
f'got an update_{new_item.resource_type} call with an '
|
f"got an update_{new_item.resource_type} call with an "
|
||||||
f'unrecognized {new_item.resource_type}: {new_item.unique_id}'
|
f"unrecognized {new_item.resource_type}: {new_item.unique_id}"
|
||||||
)
|
)
|
||||||
existing = dest[unique_id]
|
existing = dest[unique_id]
|
||||||
if new_item.original_file_path != existing.original_file_path:
|
if new_item.original_file_path != existing.original_file_path:
|
||||||
raise dbt.exceptions.RuntimeException(
|
raise dbt.exceptions.RuntimeException(
|
||||||
f'cannot update a {new_item.resource_type} to have a new file '
|
f"cannot update a {new_item.resource_type} to have a new file " f"path!"
|
||||||
f'path!'
|
|
||||||
)
|
)
|
||||||
dest[unique_id] = new_item
|
dest[unique_id] = new_item
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
# This contains macro methods that are in both the Manifest
|
||||||
class Manifest:
|
# and the MacroManifest
|
||||||
"""The manifest for the full graph, after parsing and during compilation.
|
class MacroMethods:
|
||||||
"""
|
# Just to make mypy happy. There must be a better way.
|
||||||
# These attributes are both positional and by keyword. If an attribute
|
def __init__(self):
|
||||||
# is added it must all be added in the __reduce_ex__ method in the
|
self.macros = []
|
||||||
# args tuple in the right position.
|
self.metadata = {}
|
||||||
nodes: MutableMapping[str, ManifestNode]
|
|
||||||
sources: MutableMapping[str, ParsedSourceDefinition]
|
|
||||||
macros: MutableMapping[str, ParsedMacro]
|
|
||||||
docs: MutableMapping[str, ParsedDocumentation]
|
|
||||||
exposures: MutableMapping[str, ParsedExposure]
|
|
||||||
selectors: MutableMapping[str, Any]
|
|
||||||
disabled: List[CompileResultNode]
|
|
||||||
files: MutableMapping[str, SourceFile]
|
|
||||||
metadata: ManifestMetadata = field(default_factory=ManifestMetadata)
|
|
||||||
flat_graph: Dict[str, Any] = field(default_factory=dict)
|
|
||||||
_docs_cache: Optional[DocCache] = None
|
|
||||||
_sources_cache: Optional[SourceCache] = None
|
|
||||||
_refs_cache: Optional[RefableCache] = None
|
|
||||||
_lock: Lock = field(default_factory=flags.MP_CONTEXT.Lock)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_macros(
|
|
||||||
cls,
|
|
||||||
macros: Optional[MutableMapping[str, ParsedMacro]] = None,
|
|
||||||
files: Optional[MutableMapping[str, SourceFile]] = None,
|
|
||||||
) -> 'Manifest':
|
|
||||||
if macros is None:
|
|
||||||
macros = {}
|
|
||||||
if files is None:
|
|
||||||
files = {}
|
|
||||||
return cls(
|
|
||||||
nodes={},
|
|
||||||
sources={},
|
|
||||||
macros=macros,
|
|
||||||
docs={},
|
|
||||||
exposures={},
|
|
||||||
selectors={},
|
|
||||||
disabled=[],
|
|
||||||
files=files,
|
|
||||||
)
|
|
||||||
|
|
||||||
def sync_update_node(
|
|
||||||
self, new_node: NonSourceCompiledNode
|
|
||||||
) -> NonSourceCompiledNode:
|
|
||||||
"""update the node with a lock. The only time we should want to lock is
|
|
||||||
when compiling an ephemeral ancestor of a node at runtime, because
|
|
||||||
multiple threads could be just-in-time compiling the same ephemeral
|
|
||||||
dependency, and we want them to have a consistent view of the manifest.
|
|
||||||
|
|
||||||
If the existing node is not compiled, update it with the new node and
|
|
||||||
return that. If the existing node is compiled, do not update the
|
|
||||||
manifest and return the existing node.
|
|
||||||
"""
|
|
||||||
with self._lock:
|
|
||||||
existing = self.nodes[new_node.unique_id]
|
|
||||||
if getattr(existing, 'compiled', False):
|
|
||||||
# already compiled -> must be a NonSourceCompiledNode
|
|
||||||
return cast(NonSourceCompiledNode, existing)
|
|
||||||
_update_into(self.nodes, new_node)
|
|
||||||
return new_node
|
|
||||||
|
|
||||||
def update_exposure(self, new_exposure: ParsedExposure):
|
|
||||||
_update_into(self.exposures, new_exposure)
|
|
||||||
|
|
||||||
def update_node(self, new_node: ManifestNode):
|
|
||||||
_update_into(self.nodes, new_node)
|
|
||||||
|
|
||||||
def update_source(self, new_source: ParsedSourceDefinition):
|
|
||||||
_update_into(self.sources, new_source)
|
|
||||||
|
|
||||||
def build_flat_graph(self):
|
|
||||||
"""This attribute is used in context.common by each node, so we want to
|
|
||||||
only build it once and avoid any concurrency issues around it.
|
|
||||||
Make sure you don't call this until you're done with building your
|
|
||||||
manifest!
|
|
||||||
"""
|
|
||||||
self.flat_graph = {
|
|
||||||
'nodes': {
|
|
||||||
k: v.to_dict(omit_none=False) for k, v in self.nodes.items()
|
|
||||||
},
|
|
||||||
'sources': {
|
|
||||||
k: v.to_dict(omit_none=False) for k, v in self.sources.items()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
def find_disabled_by_name(
|
|
||||||
self, name: str, package: Optional[str] = None
|
|
||||||
) -> Optional[ManifestNode]:
|
|
||||||
searcher: NameSearcher = NameSearcher(
|
|
||||||
name, package, NodeType.refable()
|
|
||||||
)
|
|
||||||
result = searcher.search(self.disabled)
|
|
||||||
return result
|
|
||||||
|
|
||||||
def find_disabled_source_by_name(
|
|
||||||
self, source_name: str, table_name: str, package: Optional[str] = None
|
|
||||||
) -> Optional[ParsedSourceDefinition]:
|
|
||||||
search_name = f'{source_name}.{table_name}'
|
|
||||||
searcher: NameSearcher = NameSearcher(
|
|
||||||
search_name, package, [NodeType.Source]
|
|
||||||
)
|
|
||||||
result = searcher.search(self.disabled)
|
|
||||||
if result is not None:
|
|
||||||
assert isinstance(result, ParsedSourceDefinition)
|
|
||||||
return result
|
|
||||||
|
|
||||||
def _find_macros_by_name(
|
|
||||||
self,
|
|
||||||
name: str,
|
|
||||||
root_project_name: str,
|
|
||||||
filter: Optional[Callable[[MacroCandidate], bool]] = None
|
|
||||||
) -> CandidateList:
|
|
||||||
"""Find macros by their name.
|
|
||||||
"""
|
|
||||||
# avoid an import cycle
|
|
||||||
from dbt.adapters.factory import get_adapter_package_names
|
|
||||||
candidates: CandidateList = CandidateList()
|
|
||||||
packages = set(get_adapter_package_names(self.metadata.adapter_type))
|
|
||||||
for unique_id, macro in self.macros.items():
|
|
||||||
if macro.name != name:
|
|
||||||
continue
|
|
||||||
candidate = MacroCandidate(
|
|
||||||
locality=_get_locality(macro, root_project_name, packages),
|
|
||||||
macro=macro,
|
|
||||||
)
|
|
||||||
if filter is None or filter(candidate):
|
|
||||||
candidates.append(candidate)
|
|
||||||
|
|
||||||
return candidates
|
|
||||||
|
|
||||||
def _materialization_candidates_for(
|
|
||||||
self, project_name: str,
|
|
||||||
materialization_name: str,
|
|
||||||
adapter_type: Optional[str],
|
|
||||||
) -> CandidateList:
|
|
||||||
|
|
||||||
if adapter_type is None:
|
|
||||||
specificity = Specificity.Default
|
|
||||||
else:
|
|
||||||
specificity = Specificity.Adapter
|
|
||||||
|
|
||||||
full_name = dbt.utils.get_materialization_macro_name(
|
|
||||||
materialization_name=materialization_name,
|
|
||||||
adapter_type=adapter_type,
|
|
||||||
with_prefix=False,
|
|
||||||
)
|
|
||||||
return CandidateList(
|
|
||||||
MaterializationCandidate.from_macro(m, specificity)
|
|
||||||
for m in self._find_macros_by_name(full_name, project_name)
|
|
||||||
)
|
|
||||||
|
|
||||||
def find_macro_by_name(
|
def find_macro_by_name(
|
||||||
self, name: str, root_project_name: str, package: Optional[str]
|
self, name: str, root_project_name: str, package: Optional[str]
|
||||||
@@ -592,6 +463,7 @@ class Manifest:
|
|||||||
"""
|
"""
|
||||||
filter: Optional[Callable[[MacroCandidate], bool]] = None
|
filter: Optional[Callable[[MacroCandidate], bool]] = None
|
||||||
if package is not None:
|
if package is not None:
|
||||||
|
|
||||||
def filter(candidate: MacroCandidate) -> bool:
|
def filter(candidate: MacroCandidate) -> bool:
|
||||||
return package == candidate.macro.package_name
|
return package == candidate.macro.package_name
|
||||||
|
|
||||||
@@ -614,27 +486,157 @@ class Manifest:
|
|||||||
- return the `generate_{component}_name` macro from the 'dbt'
|
- return the `generate_{component}_name` macro from the 'dbt'
|
||||||
internal project
|
internal project
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def filter(candidate: MacroCandidate) -> bool:
|
def filter(candidate: MacroCandidate) -> bool:
|
||||||
return candidate.locality != Locality.Imported
|
return candidate.locality != Locality.Imported
|
||||||
|
|
||||||
candidates: CandidateList = self._find_macros_by_name(
|
candidates: CandidateList = self._find_macros_by_name(
|
||||||
name=f'generate_{component}_name',
|
name=f"generate_{component}_name",
|
||||||
root_project_name=root_project_name,
|
root_project_name=root_project_name,
|
||||||
# filter out imported packages
|
# filter out imported packages
|
||||||
filter=filter,
|
filter=filter,
|
||||||
)
|
)
|
||||||
return candidates.last()
|
return candidates.last()
|
||||||
|
|
||||||
|
def _find_macros_by_name(
|
||||||
|
self,
|
||||||
|
name: str,
|
||||||
|
root_project_name: str,
|
||||||
|
filter: Optional[Callable[[MacroCandidate], bool]] = None,
|
||||||
|
) -> CandidateList:
|
||||||
|
"""Find macros by their name."""
|
||||||
|
# avoid an import cycle
|
||||||
|
from dbt.adapters.factory import get_adapter_package_names
|
||||||
|
|
||||||
|
candidates: CandidateList = CandidateList()
|
||||||
|
packages = set(get_adapter_package_names(self.metadata.adapter_type))
|
||||||
|
for unique_id, macro in self.macros.items():
|
||||||
|
if macro.name != name:
|
||||||
|
continue
|
||||||
|
candidate = MacroCandidate(
|
||||||
|
locality=_get_locality(macro, root_project_name, packages),
|
||||||
|
macro=macro,
|
||||||
|
)
|
||||||
|
if filter is None or filter(candidate):
|
||||||
|
candidates.append(candidate)
|
||||||
|
|
||||||
|
return candidates
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Manifest(MacroMethods):
|
||||||
|
"""The manifest for the full graph, after parsing and during compilation."""
|
||||||
|
|
||||||
|
# These attributes are both positional and by keyword. If an attribute
|
||||||
|
# is added it must all be added in the __reduce_ex__ method in the
|
||||||
|
# args tuple in the right position.
|
||||||
|
nodes: MutableMapping[str, ManifestNode]
|
||||||
|
sources: MutableMapping[str, ParsedSourceDefinition]
|
||||||
|
macros: MutableMapping[str, ParsedMacro]
|
||||||
|
docs: MutableMapping[str, ParsedDocumentation]
|
||||||
|
exposures: MutableMapping[str, ParsedExposure]
|
||||||
|
selectors: MutableMapping[str, Any]
|
||||||
|
disabled: List[CompileResultNode]
|
||||||
|
files: MutableMapping[str, SourceFile]
|
||||||
|
metadata: ManifestMetadata = field(default_factory=ManifestMetadata)
|
||||||
|
flat_graph: Dict[str, Any] = field(default_factory=dict)
|
||||||
|
_docs_cache: Optional[DocCache] = None
|
||||||
|
_sources_cache: Optional[SourceCache] = None
|
||||||
|
_refs_cache: Optional[RefableCache] = None
|
||||||
|
_lock: Lock = field(default_factory=flags.MP_CONTEXT.Lock)
|
||||||
|
|
||||||
|
def sync_update_node(
|
||||||
|
self, new_node: NonSourceCompiledNode
|
||||||
|
) -> NonSourceCompiledNode:
|
||||||
|
"""update the node with a lock. The only time we should want to lock is
|
||||||
|
when compiling an ephemeral ancestor of a node at runtime, because
|
||||||
|
multiple threads could be just-in-time compiling the same ephemeral
|
||||||
|
dependency, and we want them to have a consistent view of the manifest.
|
||||||
|
|
||||||
|
If the existing node is not compiled, update it with the new node and
|
||||||
|
return that. If the existing node is compiled, do not update the
|
||||||
|
manifest and return the existing node.
|
||||||
|
"""
|
||||||
|
with self._lock:
|
||||||
|
existing = self.nodes[new_node.unique_id]
|
||||||
|
if getattr(existing, "compiled", False):
|
||||||
|
# already compiled -> must be a NonSourceCompiledNode
|
||||||
|
return cast(NonSourceCompiledNode, existing)
|
||||||
|
_update_into(self.nodes, new_node)
|
||||||
|
return new_node
|
||||||
|
|
||||||
|
def update_exposure(self, new_exposure: ParsedExposure):
|
||||||
|
_update_into(self.exposures, new_exposure)
|
||||||
|
|
||||||
|
def update_node(self, new_node: ManifestNode):
|
||||||
|
_update_into(self.nodes, new_node)
|
||||||
|
|
||||||
|
def update_source(self, new_source: ParsedSourceDefinition):
|
||||||
|
_update_into(self.sources, new_source)
|
||||||
|
|
||||||
|
def build_flat_graph(self):
|
||||||
|
"""This attribute is used in context.common by each node, so we want to
|
||||||
|
only build it once and avoid any concurrency issues around it.
|
||||||
|
Make sure you don't call this until you're done with building your
|
||||||
|
manifest!
|
||||||
|
"""
|
||||||
|
self.flat_graph = {
|
||||||
|
"nodes": {k: v.to_dict(omit_none=False) for k, v in self.nodes.items()},
|
||||||
|
"sources": {k: v.to_dict(omit_none=False) for k, v in self.sources.items()},
|
||||||
|
}
|
||||||
|
|
||||||
|
def find_disabled_by_name(
|
||||||
|
self, name: str, package: Optional[str] = None
|
||||||
|
) -> Optional[ManifestNode]:
|
||||||
|
searcher: NameSearcher = NameSearcher(name, package, NodeType.refable())
|
||||||
|
result = searcher.search(self.disabled)
|
||||||
|
return result
|
||||||
|
|
||||||
|
def find_disabled_source_by_name(
|
||||||
|
self, source_name: str, table_name: str, package: Optional[str] = None
|
||||||
|
) -> Optional[ParsedSourceDefinition]:
|
||||||
|
search_name = f"{source_name}.{table_name}"
|
||||||
|
searcher: NameSearcher = NameSearcher(search_name, package, [NodeType.Source])
|
||||||
|
result = searcher.search(self.disabled)
|
||||||
|
if result is not None:
|
||||||
|
assert isinstance(result, ParsedSourceDefinition)
|
||||||
|
return result
|
||||||
|
|
||||||
|
def _materialization_candidates_for(
|
||||||
|
self,
|
||||||
|
project_name: str,
|
||||||
|
materialization_name: str,
|
||||||
|
adapter_type: Optional[str],
|
||||||
|
) -> CandidateList:
|
||||||
|
|
||||||
|
if adapter_type is None:
|
||||||
|
specificity = Specificity.Default
|
||||||
|
else:
|
||||||
|
specificity = Specificity.Adapter
|
||||||
|
|
||||||
|
full_name = dbt.utils.get_materialization_macro_name(
|
||||||
|
materialization_name=materialization_name,
|
||||||
|
adapter_type=adapter_type,
|
||||||
|
with_prefix=False,
|
||||||
|
)
|
||||||
|
return CandidateList(
|
||||||
|
MaterializationCandidate.from_macro(m, specificity)
|
||||||
|
for m in self._find_macros_by_name(full_name, project_name)
|
||||||
|
)
|
||||||
|
|
||||||
def find_materialization_macro_by_name(
|
def find_materialization_macro_by_name(
|
||||||
self, project_name: str, materialization_name: str, adapter_type: str
|
self, project_name: str, materialization_name: str, adapter_type: str
|
||||||
) -> Optional[ParsedMacro]:
|
) -> Optional[ParsedMacro]:
|
||||||
candidates: CandidateList = CandidateList(chain.from_iterable(
|
candidates: CandidateList = CandidateList(
|
||||||
self._materialization_candidates_for(
|
chain.from_iterable(
|
||||||
project_name=project_name,
|
self._materialization_candidates_for(
|
||||||
materialization_name=materialization_name,
|
project_name=project_name,
|
||||||
adapter_type=atype,
|
materialization_name=materialization_name,
|
||||||
) for atype in (adapter_type, None)
|
adapter_type=atype,
|
||||||
))
|
)
|
||||||
|
for atype in (adapter_type, None)
|
||||||
|
)
|
||||||
|
)
|
||||||
return candidates.last()
|
return candidates.last()
|
||||||
|
|
||||||
def get_resource_fqns(self) -> Mapping[str, PathSet]:
|
def get_resource_fqns(self) -> Mapping[str, PathSet]:
|
||||||
@@ -658,9 +660,7 @@ class Manifest:
|
|||||||
if node.resource_type in NodeType.refable():
|
if node.resource_type in NodeType.refable():
|
||||||
self._refs_cache.add_node(node)
|
self._refs_cache.add_node(node)
|
||||||
|
|
||||||
def patch_macros(
|
def patch_macros(self, patches: MutableMapping[MacroKey, ParsedMacroPatch]) -> None:
|
||||||
self, patches: MutableMapping[MacroKey, ParsedMacroPatch]
|
|
||||||
) -> None:
|
|
||||||
for macro in self.macros.values():
|
for macro in self.macros.values():
|
||||||
key = (macro.package_name, macro.name)
|
key = (macro.package_name, macro.name)
|
||||||
patch = patches.pop(key, None)
|
patch = patches.pop(key, None)
|
||||||
@@ -672,12 +672,10 @@ class Manifest:
|
|||||||
for patch in patches.values():
|
for patch in patches.values():
|
||||||
warn_or_error(
|
warn_or_error(
|
||||||
f'WARNING: Found documentation for macro "{patch.name}" '
|
f'WARNING: Found documentation for macro "{patch.name}" '
|
||||||
f'which was not found'
|
f"which was not found"
|
||||||
)
|
)
|
||||||
|
|
||||||
def patch_nodes(
|
def patch_nodes(self, patches: MutableMapping[str, ParsedNodePatch]) -> None:
|
||||||
self, patches: MutableMapping[str, ParsedNodePatch]
|
|
||||||
) -> None:
|
|
||||||
"""Patch nodes with the given dict of patches. Note that this consumes
|
"""Patch nodes with the given dict of patches. Note that this consumes
|
||||||
the input!
|
the input!
|
||||||
This relies on the fact that all nodes have unique _name_ fields, not
|
This relies on the fact that all nodes have unique _name_ fields, not
|
||||||
@@ -694,15 +692,15 @@ class Manifest:
|
|||||||
|
|
||||||
expected_key = node.resource_type.pluralize()
|
expected_key = node.resource_type.pluralize()
|
||||||
if expected_key != patch.yaml_key:
|
if expected_key != patch.yaml_key:
|
||||||
if patch.yaml_key == 'models':
|
if patch.yaml_key == "models":
|
||||||
deprecations.warn(
|
deprecations.warn(
|
||||||
'models-key-mismatch',
|
"models-key-mismatch",
|
||||||
patch=patch, node=node, expected_key=expected_key
|
patch=patch,
|
||||||
|
node=node,
|
||||||
|
expected_key=expected_key,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise_invalid_patch(
|
raise_invalid_patch(node, patch.yaml_key, patch.original_file_path)
|
||||||
node, patch.yaml_key, patch.original_file_path
|
|
||||||
)
|
|
||||||
|
|
||||||
node.patch(patch)
|
node.patch(patch)
|
||||||
|
|
||||||
@@ -711,22 +709,25 @@ class Manifest:
|
|||||||
for patch in patches.values():
|
for patch in patches.values():
|
||||||
# since patches aren't nodes, we can't use the existing
|
# since patches aren't nodes, we can't use the existing
|
||||||
# target_not_found warning
|
# target_not_found warning
|
||||||
logger.debug((
|
logger.debug(
|
||||||
'WARNING: Found documentation for resource "{}" which was '
|
(
|
||||||
'not found or is disabled').format(patch.name)
|
'WARNING: Found documentation for resource "{}" which was '
|
||||||
|
"not found or is disabled"
|
||||||
|
).format(patch.name)
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_used_schemas(self, resource_types=None):
|
def get_used_schemas(self, resource_types=None):
|
||||||
return frozenset({
|
return frozenset(
|
||||||
(node.database, node.schema) for node in
|
{
|
||||||
chain(self.nodes.values(), self.sources.values())
|
(node.database, node.schema)
|
||||||
if not resource_types or node.resource_type in resource_types
|
for node in chain(self.nodes.values(), self.sources.values())
|
||||||
})
|
if not resource_types or node.resource_type in resource_types
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
def get_used_databases(self):
|
def get_used_databases(self):
|
||||||
return frozenset(
|
return frozenset(
|
||||||
x.database for x in
|
x.database for x in chain(self.nodes.values(), self.sources.values())
|
||||||
chain(self.nodes.values(), self.sources.values())
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def deepcopy(self):
|
def deepcopy(self):
|
||||||
@@ -743,11 +744,13 @@ class Manifest:
|
|||||||
)
|
)
|
||||||
|
|
||||||
def writable_manifest(self):
|
def writable_manifest(self):
|
||||||
edge_members = list(chain(
|
edge_members = list(
|
||||||
self.nodes.values(),
|
chain(
|
||||||
self.sources.values(),
|
self.nodes.values(),
|
||||||
self.exposures.values(),
|
self.sources.values(),
|
||||||
))
|
self.exposures.values(),
|
||||||
|
)
|
||||||
|
)
|
||||||
forward_edges, backward_edges = build_edges(edge_members)
|
forward_edges, backward_edges = build_edges(edge_members)
|
||||||
|
|
||||||
return WritableManifest(
|
return WritableManifest(
|
||||||
@@ -763,10 +766,10 @@ class Manifest:
|
|||||||
parent_map=backward_edges,
|
parent_map=backward_edges,
|
||||||
)
|
)
|
||||||
|
|
||||||
def to_dict(self, omit_none=True, validate=False):
|
# When 'to_dict' is called on the Manifest, it substitues a
|
||||||
return self.writable_manifest().to_dict(
|
# WritableManifest
|
||||||
omit_none=omit_none, validate=validate
|
def __pre_serialize__(self):
|
||||||
)
|
return self.writable_manifest()
|
||||||
|
|
||||||
def write(self, path):
|
def write(self, path):
|
||||||
self.writable_manifest().write(path)
|
self.writable_manifest().write(path)
|
||||||
@@ -781,7 +784,7 @@ class Manifest:
|
|||||||
else:
|
else:
|
||||||
# something terrible has happened
|
# something terrible has happened
|
||||||
raise dbt.exceptions.InternalException(
|
raise dbt.exceptions.InternalException(
|
||||||
'Expected node {} not found in manifest'.format(unique_id)
|
"Expected node {} not found in manifest".format(unique_id)
|
||||||
)
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@@ -830,9 +833,7 @@ class Manifest:
|
|||||||
|
|
||||||
# it's possible that the node is disabled
|
# it's possible that the node is disabled
|
||||||
if disabled is None:
|
if disabled is None:
|
||||||
disabled = self.find_disabled_by_name(
|
disabled = self.find_disabled_by_name(target_model_name, pkg)
|
||||||
target_model_name, pkg
|
|
||||||
)
|
|
||||||
|
|
||||||
if disabled is not None:
|
if disabled is not None:
|
||||||
return Disabled(disabled)
|
return Disabled(disabled)
|
||||||
@@ -843,7 +844,7 @@ class Manifest:
|
|||||||
target_source_name: str,
|
target_source_name: str,
|
||||||
target_table_name: str,
|
target_table_name: str,
|
||||||
current_project: str,
|
current_project: str,
|
||||||
node_package: str
|
node_package: str,
|
||||||
) -> MaybeParsedSource:
|
) -> MaybeParsedSource:
|
||||||
key = (target_source_name, target_table_name)
|
key = (target_source_name, target_table_name)
|
||||||
candidates = _search_packages(current_project, node_package)
|
candidates = _search_packages(current_project, node_package)
|
||||||
@@ -876,9 +877,7 @@ class Manifest:
|
|||||||
resolve_ref except the is_enabled checks are unnecessary as docs are
|
resolve_ref except the is_enabled checks are unnecessary as docs are
|
||||||
always enabled.
|
always enabled.
|
||||||
"""
|
"""
|
||||||
candidates = _search_packages(
|
candidates = _search_packages(current_project, node_package, package)
|
||||||
current_project, node_package, package
|
|
||||||
)
|
|
||||||
|
|
||||||
for pkg in candidates:
|
for pkg in candidates:
|
||||||
result = self.docs_cache.find_cached_value(name, pkg)
|
result = self.docs_cache.find_cached_value(name, pkg)
|
||||||
@@ -889,7 +888,7 @@ class Manifest:
|
|||||||
def merge_from_artifact(
|
def merge_from_artifact(
|
||||||
self,
|
self,
|
||||||
adapter,
|
adapter,
|
||||||
other: 'WritableManifest',
|
other: "WritableManifest",
|
||||||
selected: AbstractSet[UniqueID],
|
selected: AbstractSet[UniqueID],
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Given the selected unique IDs and a writable manifest, update this
|
"""Given the selected unique IDs and a writable manifest, update this
|
||||||
@@ -902,10 +901,10 @@ class Manifest:
|
|||||||
for unique_id, node in other.nodes.items():
|
for unique_id, node in other.nodes.items():
|
||||||
current = self.nodes.get(unique_id)
|
current = self.nodes.get(unique_id)
|
||||||
if current and (
|
if current and (
|
||||||
node.resource_type in refables and
|
node.resource_type in refables
|
||||||
not node.is_ephemeral and
|
and not node.is_ephemeral
|
||||||
unique_id not in selected and
|
and unique_id not in selected
|
||||||
not adapter.get_relation(
|
and not adapter.get_relation(
|
||||||
current.database, current.schema, current.identifier
|
current.database, current.schema, current.identifier
|
||||||
)
|
)
|
||||||
):
|
):
|
||||||
@@ -914,9 +913,7 @@ class Manifest:
|
|||||||
|
|
||||||
# log up to 5 items
|
# log up to 5 items
|
||||||
sample = list(islice(merged, 5))
|
sample = list(islice(merged, 5))
|
||||||
logger.debug(
|
logger.debug(f"Merged {len(merged)} items from state (sample: {sample})")
|
||||||
f'Merged {len(merged)} items from state (sample: {sample})'
|
|
||||||
)
|
|
||||||
|
|
||||||
# Provide support for copy.deepcopy() - we just need to avoid the lock!
|
# Provide support for copy.deepcopy() - we just need to avoid the lock!
|
||||||
# pickle and deepcopy use this. It returns a callable object used to
|
# pickle and deepcopy use this. It returns a callable object used to
|
||||||
@@ -944,48 +941,67 @@ class Manifest:
|
|||||||
return self.__class__, args
|
return self.__class__, args
|
||||||
|
|
||||||
|
|
||||||
|
class MacroManifest(MacroMethods):
|
||||||
|
def __init__(self, macros, files):
|
||||||
|
self.macros = macros
|
||||||
|
self.files = files
|
||||||
|
self.metadata = ManifestMetadata()
|
||||||
|
# This is returned by the 'graph' context property
|
||||||
|
# in the ProviderContext class.
|
||||||
|
self.flat_graph = {}
|
||||||
|
|
||||||
|
|
||||||
|
AnyManifest = Union[Manifest, MacroManifest]
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@schema_version('manifest', 1)
|
@schema_version("manifest", 1)
|
||||||
class WritableManifest(ArtifactMixin):
|
class WritableManifest(ArtifactMixin):
|
||||||
nodes: Mapping[UniqueID, ManifestNode] = field(
|
nodes: Mapping[UniqueID, ManifestNode] = field(
|
||||||
metadata=dict(description=(
|
metadata=dict(
|
||||||
'The nodes defined in the dbt project and its dependencies'
|
description=("The nodes defined in the dbt project and its dependencies")
|
||||||
))
|
)
|
||||||
)
|
)
|
||||||
sources: Mapping[UniqueID, ParsedSourceDefinition] = field(
|
sources: Mapping[UniqueID, ParsedSourceDefinition] = field(
|
||||||
metadata=dict(description=(
|
metadata=dict(
|
||||||
'The sources defined in the dbt project and its dependencies'
|
description=("The sources defined in the dbt project and its dependencies")
|
||||||
))
|
)
|
||||||
)
|
)
|
||||||
macros: Mapping[UniqueID, ParsedMacro] = field(
|
macros: Mapping[UniqueID, ParsedMacro] = field(
|
||||||
metadata=dict(description=(
|
metadata=dict(
|
||||||
'The macros defined in the dbt project and its dependencies'
|
description=("The macros defined in the dbt project and its dependencies")
|
||||||
))
|
)
|
||||||
)
|
)
|
||||||
docs: Mapping[UniqueID, ParsedDocumentation] = field(
|
docs: Mapping[UniqueID, ParsedDocumentation] = field(
|
||||||
metadata=dict(description=(
|
metadata=dict(
|
||||||
'The docs defined in the dbt project and its dependencies'
|
description=("The docs defined in the dbt project and its dependencies")
|
||||||
))
|
)
|
||||||
)
|
)
|
||||||
exposures: Mapping[UniqueID, ParsedExposure] = field(
|
exposures: Mapping[UniqueID, ParsedExposure] = field(
|
||||||
metadata=dict(description=(
|
metadata=dict(
|
||||||
'The exposures defined in the dbt project and its dependencies'
|
description=(
|
||||||
))
|
"The exposures defined in the dbt project and its dependencies"
|
||||||
|
)
|
||||||
|
)
|
||||||
)
|
)
|
||||||
selectors: Mapping[UniqueID, Any] = field(
|
selectors: Mapping[UniqueID, Any] = field(
|
||||||
metadata=dict(description=(
|
metadata=dict(description=("The selectors defined in selectors.yml"))
|
||||||
'The selectors defined in selectors.yml'
|
)
|
||||||
))
|
disabled: Optional[List[CompileResultNode]] = field(
|
||||||
|
metadata=dict(description="A list of the disabled nodes in the target")
|
||||||
|
)
|
||||||
|
parent_map: Optional[NodeEdgeMap] = field(
|
||||||
|
metadata=dict(
|
||||||
|
description="A mapping from child nodes to their dependencies",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
child_map: Optional[NodeEdgeMap] = field(
|
||||||
|
metadata=dict(
|
||||||
|
description="A mapping from parent nodes to their dependents",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
metadata: ManifestMetadata = field(
|
||||||
|
metadata=dict(
|
||||||
|
description="Metadata about the manifest",
|
||||||
|
)
|
||||||
)
|
)
|
||||||
disabled: Optional[List[CompileResultNode]] = field(metadata=dict(
|
|
||||||
description='A list of the disabled nodes in the target'
|
|
||||||
))
|
|
||||||
parent_map: Optional[NodeEdgeMap] = field(metadata=dict(
|
|
||||||
description='A mapping from child nodes to their dependencies',
|
|
||||||
))
|
|
||||||
child_map: Optional[NodeEdgeMap] = field(metadata=dict(
|
|
||||||
description='A mapping from parent nodes to their dependents',
|
|
||||||
))
|
|
||||||
metadata: ManifestMetadata = field(metadata=dict(
|
|
||||||
description='Metadata about the manifest',
|
|
||||||
))
|
|
||||||
@@ -2,19 +2,21 @@ from dataclasses import field, Field, dataclass
 from enum import Enum
 from itertools import chain
 from typing import (
-Any, List, Optional, Dict, MutableMapping, Union, Type, NewType, Tuple,
+Any,
-TypeVar, Callable, cast, Hashable
+List,
+Optional,
+Dict,
+MutableMapping,
+Union,
+Type,
+TypeVar,
+Callable,
+)
+from dbt.dataclass_schema import (
+dbtClassMixin,
+ValidationError,
+register_pattern,
 )

-# TODO: patch+upgrade hologram to avoid this jsonschema import
-import jsonschema  # type: ignore

-# This is protected, but we really do want to reuse this logic, and the cache!
-# It would be nice to move the custom error picking stuff into hologram!
-from hologram import _validate_schema
-from hologram import JsonSchemaMixin, ValidationError
-from hologram.helpers import StrEnum, register_pattern

 from dbt.contracts.graph.unparsed import AdditionalPropertiesAllowed
 from dbt.exceptions import CompilationException, InternalException
 from dbt.contracts.util import Replaceable, list_str
@@ -22,7 +24,7 @@ from dbt import hooks
 from dbt.node_types import NodeType


-M = TypeVar('M', bound='Metadata')
+M = TypeVar("M", bound="Metadata")


 def _get_meta_value(cls: Type[M], fld: Field, key: str, default: Any) -> M:
@@ -37,9 +39,7 @@ def _get_meta_value(cls: Type[M], fld: Field, key: str, default: Any) -> M:
 try:
 return cls(value)
 except ValueError as exc:
-raise InternalException(
+raise InternalException(f"Invalid {cls} value: {value}") from exc
-f'Invalid {cls} value: {value}'
-) from exc


 def _set_meta_value(
@@ -61,19 +61,17 @@ class Metadata(Enum):

 return _get_meta_value(cls, fld, key, default)

-def meta(
+def meta(self, existing: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
-self, existing: Optional[Dict[str, Any]] = None
-) -> Dict[str, Any]:
 key = self.metadata_key()
 return _set_meta_value(self, key, existing)

 @classmethod
-def default_field(cls) -> 'Metadata':
+def default_field(cls) -> "Metadata":
-raise NotImplementedError('Not implemented')
+raise NotImplementedError("Not implemented")

 @classmethod
 def metadata_key(cls) -> str:
-raise NotImplementedError('Not implemented')
+raise NotImplementedError("Not implemented")


 class MergeBehavior(Metadata):
@@ -82,12 +80,12 @@ class MergeBehavior(Metadata):
 Clobber = 3

 @classmethod
-def default_field(cls) -> 'MergeBehavior':
+def default_field(cls) -> "MergeBehavior":
 return cls.Clobber

 @classmethod
 def metadata_key(cls) -> str:
-return 'merge'
+return "merge"


 class ShowBehavior(Metadata):
@@ -95,12 +93,12 @@ class ShowBehavior(Metadata):
 Hide = 2

 @classmethod
-def default_field(cls) -> 'ShowBehavior':
+def default_field(cls) -> "ShowBehavior":
 return cls.Show

 @classmethod
 def metadata_key(cls) -> str:
-return 'show_hide'
+return "show_hide"

 @classmethod
 def should_show(cls, fld: Field) -> bool:
@@ -112,12 +110,12 @@ class CompareBehavior(Metadata):
 Exclude = 2

 @classmethod
-def default_field(cls) -> 'CompareBehavior':
+def default_field(cls) -> "CompareBehavior":
 return cls.Include

 @classmethod
 def metadata_key(cls) -> str:
-return 'compare'
+return "compare"

 @classmethod
 def should_include(cls, fld: Field) -> bool:
@@ -149,55 +147,44 @@ def _merge_field_value(
 return _listify(self_value) + _listify(other_value)
 elif merge_behavior == MergeBehavior.Update:
 if not isinstance(self_value, dict):
-raise InternalException(f'expected dict, got {self_value}')
+raise InternalException(f"expected dict, got {self_value}")
 if not isinstance(other_value, dict):
-raise InternalException(f'expected dict, got {other_value}')
+raise InternalException(f"expected dict, got {other_value}")
 value = self_value.copy()
 value.update(other_value)
 return value
 else:
-raise InternalException(
+raise InternalException(f"Got an invalid merge_behavior: {merge_behavior}")
-f'Got an invalid merge_behavior: {merge_behavior}'
-)


 def insensitive_patterns(*patterns: str):
 lowercased = []
 for pattern in patterns:
 lowercased.append(
-''.join('[{}{}]'.format(s.upper(), s.lower()) for s in pattern)
+"".join("[{}{}]".format(s.upper(), s.lower()) for s in pattern)
 )
-return '^({})$'.format('|'.join(lowercased))
+return "^({})$".format("|".join(lowercased))


-Severity = NewType('Severity', str)
+class Severity(str):
+pass
-register_pattern(Severity, insensitive_patterns('warn', 'error'))


-class SnapshotStrategy(StrEnum):
+register_pattern(Severity, insensitive_patterns("warn", "error"))
-Timestamp = 'timestamp'
-Check = 'check'


-class All(StrEnum):
-All = 'all'


 @dataclass
-class Hook(JsonSchemaMixin, Replaceable):
+class Hook(dbtClassMixin, Replaceable):
 sql: str
 transaction: bool = True
 index: Optional[int] = None


-T = TypeVar('T', bound='BaseConfig')
+T = TypeVar("T", bound="BaseConfig")


 @dataclass
-class BaseConfig(
+class BaseConfig(AdditionalPropertiesAllowed, Replaceable, MutableMapping[str, Any]):
-AdditionalPropertiesAllowed, Replaceable, MutableMapping[str, Any]
-):
 # Implement MutableMapping so this config will behave as some macros expect
 # during parsing (notably, syntax like `{{ node.config['schema'] }}`)
 def __getitem__(self, key):
@@ -218,8 +205,7 @@ class BaseConfig(
 def __delitem__(self, key):
 if hasattr(self, key):
 msg = (
-'Error, tried to delete config key "{}": Cannot delete '
+'Error, tried to delete config key "{}": Cannot delete ' "built-in keys"
-'built-in keys'
 ).format(key)
 raise CompilationException(msg)
 else:
@@ -259,9 +245,7 @@ class BaseConfig(
 return unrendered[key] == other[key]

 @classmethod
-def same_contents(
+def same_contents(cls, unrendered: Dict[str, Any], other: Dict[str, Any]) -> bool:
-cls, unrendered: Dict[str, Any], other: Dict[str, Any]
-) -> bool:
 """This is like __eq__, except it ignores some fields."""
 seen = set()
 for fld, target_name in cls._get_fields():
@@ -279,9 +263,7 @@ class BaseConfig(
 return True

 @classmethod
-def _extract_dict(
+def _extract_dict(cls, src: Dict[str, Any], data: Dict[str, Any]) -> Dict[str, Any]:
-cls, src: Dict[str, Any], data: Dict[str, Any]
-) -> Dict[str, Any]:
 """Find all the items in data that match a target_field on this class,
 and merge them with the data found in `src` for target_field, using the
 field's specified merge behavior. Matching items will be removed from
@@ -313,29 +295,6 @@ class BaseConfig(
 )
 return result

-def to_dict(
-self,
-omit_none: bool = True,
-validate: bool = False,
-*,
-omit_hidden: bool = True,
-) -> Dict[str, Any]:
-result = super().to_dict(omit_none=omit_none, validate=validate)
-if omit_hidden and not omit_none:
-for fld, target_field in self._get_fields():
-if target_field not in result:
-continue
-
-# if the field is not None, preserve it regardless of the
-# setting. This is in line with existing behavior, but isn't
-# an endorsement of it!
-if result[target_field] is not None:
-continue
-
-if not ShowBehavior.should_show(fld):
-del result[target_field]
-return result
-
 def update_from(
 self: T, data: Dict[str, Any], adapter_type: str, validate: bool = True
 ) -> T:
@@ -344,7 +303,8 @@ class BaseConfig(
 """
 # sadly, this is a circular import
 from dbt.adapters.factory import get_config_class_by_name
-dct = self.to_dict(omit_none=False, validate=False, omit_hidden=False)
+dct = self.to_dict(omit_none=False)

 adapter_config_cls = get_config_class_by_name(adapter_type)

@@ -358,21 +318,23 @@ class BaseConfig(
 dct.update(data)

 # any validation failures must have come from the update
-return self.from_dict(dct, validate=validate)
+if validate:
+self.validate(dct)
+return self.from_dict(dct)

 def finalize_and_validate(self: T) -> T:
-# from_dict will validate for us
+dct = self.to_dict(omit_none=False)
-dct = self.to_dict(omit_none=False, validate=False)
+self.validate(dct)
 return self.from_dict(dct)

 def replace(self, **kwargs):
-dct = self.to_dict(validate=False)
+dct = self.to_dict(omit_none=True)

 mapping = self.field_mapping()
 for key, value in kwargs.items():
 new_key = mapping.get(key, key)
 dct[new_key] = value
-return self.from_dict(dct, validate=False)
+return self.from_dict(dct)


 @dataclass
@@ -383,7 +345,7 @@ class SourceConfig(BaseConfig):
 @dataclass
 class NodeConfig(BaseConfig):
 enabled: bool = True
-materialized: str = 'view'
+materialized: str = "view"
 persist_docs: Dict[str, Any] = field(default_factory=dict)
 post_hook: List[Hook] = field(
 default_factory=list,
@@ -424,212 +386,103 @@ class NodeConfig(BaseConfig):
 )
 tags: Union[List[str], str] = field(
 default_factory=list_str,
-metadata=metas(ShowBehavior.Hide,
+metadata=metas(
-MergeBehavior.Append,
+ShowBehavior.Hide, MergeBehavior.Append, CompareBehavior.Exclude
-CompareBehavior.Exclude),
+),
 )
 full_refresh: Optional[bool] = None

 @classmethod
-def from_dict(cls, data, validate=True):
+def __pre_deserialize__(cls, data):
+data = super().__pre_deserialize__(data)
+field_map = {"post-hook": "post_hook", "pre-hook": "pre_hook"}
+# create a new dict because otherwise it gets overwritten in
+# tests
+new_dict = {}
+for key in data:
+new_dict[key] = data[key]
+data = new_dict
 for key in hooks.ModelHookType:
 if key in data:
 data[key] = [hooks.get_hook_dict(h) for h in data[key]]
-return super().from_dict(data, validate=validate)
+for field_name in field_map:
+if field_name in data:
+new_name = field_map[field_name]
+data[new_name] = data.pop(field_name)
+return data

+def __post_serialize__(self, dct):
+dct = super().__post_serialize__(dct)
+field_map = {"post_hook": "post-hook", "pre_hook": "pre-hook"}
+for field_name in field_map:
+if field_name in dct:
+dct[field_map[field_name]] = dct.pop(field_name)
+return dct

+# this is still used by jsonschema validation
 @classmethod
 def field_mapping(cls):
-return {'post_hook': 'post-hook', 'pre_hook': 'pre-hook'}
+return {"post_hook": "post-hook", "pre_hook": "pre-hook"}


 @dataclass
 class SeedConfig(NodeConfig):
-materialized: str = 'seed'
+materialized: str = "seed"
 quote_columns: Optional[bool] = None


 @dataclass
 class TestConfig(NodeConfig):
-materialized: str = 'test'
+materialized: str = "test"
-severity: Severity = Severity('ERROR')
+severity: Severity = Severity("ERROR")


-SnapshotVariants = Union[
-'TimestampSnapshotConfig',
-'CheckSnapshotConfig',
-'GenericSnapshotConfig',
-]


-def _relevance_without_strategy(error: jsonschema.ValidationError):
-# calculate the 'relevance' of an error the normal jsonschema way, except
-# if the validator is in the 'strategy' field and its conflicting with the
-# 'enum'. This suppresses `"'timestamp' is not one of ['check']` and such
-if 'strategy' in error.path and error.validator in {'enum', 'not'}:
-length = 1
-else:
-length = -len(error.path)
-validator = error.validator
-return length, validator not in {'anyOf', 'oneOf'}


-@dataclass
-class SnapshotWrapper(JsonSchemaMixin):
-"""This is a little wrapper to let us serialize/deserialize the
-SnapshotVariants union.
-"""
-config: SnapshotVariants  # mypy: ignore

-@classmethod
-def validate(cls, data: Any):
-config = data.get('config', {})

-if config.get('strategy') == 'check':
-schema = _validate_schema(CheckSnapshotConfig)
-to_validate = config

-elif config.get('strategy') == 'timestamp':
-schema = _validate_schema(TimestampSnapshotConfig)
-to_validate = config

-else:
-h_cls = cast(Hashable, cls)
-schema = _validate_schema(h_cls)
-to_validate = data

-validator = jsonschema.Draft7Validator(schema)

-error = jsonschema.exceptions.best_match(
-validator.iter_errors(to_validate),
-key=_relevance_without_strategy,
-)

-if error is not None:
-raise ValidationError.create_from(error) from error


 @dataclass
 class EmptySnapshotConfig(NodeConfig):
-materialized: str = 'snapshot'
+materialized: str = "snapshot"


-@dataclass(init=False)
+@dataclass
 class SnapshotConfig(EmptySnapshotConfig):
-unique_key: str = field(init=False, metadata=dict(init_required=True))
+strategy: Optional[str] = None
-target_schema: str = field(init=False, metadata=dict(init_required=True))
+unique_key: Optional[str] = None
+target_schema: Optional[str] = None
 target_database: Optional[str] = None
+updated_at: Optional[str] = None
-def __init__(
+check_cols: Optional[Union[str, List[str]]] = None
-self,
-unique_key: str,
-target_schema: str,
-target_database: Optional[str] = None,
-**kwargs
-) -> None:
-self.unique_key = unique_key
-self.target_schema = target_schema
-self.target_database = target_database
-# kwargs['materialized'] = materialized
-super().__init__(**kwargs)

-# type hacks...
-@classmethod
-def _get_fields(cls) -> List[Tuple[Field, str]]:  # type: ignore
-fields: List[Tuple[Field, str]] = []
-for old_field, name in super()._get_fields():
-new_field = old_field
-# tell hologram we're really an initvar
-if old_field.metadata and old_field.metadata.get('init_required'):
-new_field = field(init=True, metadata=old_field.metadata)
-new_field.name = old_field.name
-new_field.type = old_field.type
-new_field._field_type = old_field._field_type  # type: ignore
-fields.append((new_field, name))
-return fields

-def finalize_and_validate(self: 'SnapshotConfig') -> SnapshotVariants:
-data = self.to_dict()
-return SnapshotWrapper.from_dict({'config': data}).config


-@dataclass(init=False)
-class GenericSnapshotConfig(SnapshotConfig):
-strategy: str = field(init=False, metadata=dict(init_required=True))

-def __init__(self, strategy: str, **kwargs) -> None:
-self.strategy = strategy
-super().__init__(**kwargs)

 @classmethod
-def _collect_json_schema(
+def validate(cls, data):
-cls, definitions: Dict[str, Any]
+super().validate(data)
-) -> Dict[str, Any]:
+if data.get("strategy") == "check":
-# this is the method you want to override in hologram if you want
+if not data.get("check_cols"):
-# to do clever things about the json schema and have classes that
+raise ValidationError(
-# contain instances of your JsonSchemaMixin respect the change.
+"A snapshot configured with the check strategy must "
-schema = super()._collect_json_schema(definitions)
+"specify a check_cols configuration."
+)
+if isinstance(data["check_cols"], str) and data["check_cols"] != "all":
+raise ValidationError(
+f"Invalid value for 'check_cols': {data['check_cols']}. "
+"Expected 'all' or a list of strings."
+)

-# Instead of just the strategy we'd calculate normally, say
+elif data.get("strategy") == "timestamp":
-# "this strategy except none of our specialization strategies".
+if not data.get("updated_at"):
-strategies = [schema['properties']['strategy']]
+raise ValidationError(
-for specialization in (TimestampSnapshotConfig, CheckSnapshotConfig):
+"A snapshot configured with the timestamp strategy "
-strategies.append(
+"must specify an updated_at configuration."
-{'not': specialization.json_schema()['properties']['strategy']}
+)
-)
+if data.get("check_cols"):
+raise ValidationError(
+"A 'timestamp' snapshot should not have 'check_cols'"
+)
+# If the strategy is not 'check' or 'timestamp' it's a custom strategy,
+# formerly supported with GenericSnapshotConfig

-schema['properties']['strategy'] = {
+def finalize_and_validate(self):
-'allOf': strategies
+data = self.to_dict(omit_none=True)
-}
+self.validate(data)
-return schema
+return self.from_dict(data)


-@dataclass(init=False)
-class TimestampSnapshotConfig(SnapshotConfig):
-strategy: str = field(
-init=False,
-metadata=dict(
-restrict=[str(SnapshotStrategy.Timestamp)],
-init_required=True,
-),
-)
-updated_at: str = field(init=False, metadata=dict(init_required=True))

-def __init__(
-self, strategy: str, updated_at: str, **kwargs
-) -> None:
-self.strategy = strategy
-self.updated_at = updated_at
-super().__init__(**kwargs)


-@dataclass(init=False)
-class CheckSnapshotConfig(SnapshotConfig):
-strategy: str = field(
-init=False,
-metadata=dict(
-restrict=[str(SnapshotStrategy.Check)],
-init_required=True,
-),
-)
-# TODO: is there a way to get this to accept tuples of strings? Adding
-# `Tuple[str, ...]` to the list of types results in this:
-# ['email'] is valid under each of {'type': 'array', 'items':
-# {'type': 'string'}}, {'type': 'array', 'items': {'type': 'string'}}
-# but without it, parsing gets upset about values like `('email',)`
-# maybe hologram itself should support this behavior? It's not like tuples
-# are meaningful in json
-check_cols: Union[All, List[str]] = field(
-init=False,
-metadata=dict(init_required=True),
-)

-def __init__(
-self, strategy: str, check_cols: Union[All, List[str]],
-**kwargs
-) -> None:
-self.strategy = strategy
-self.check_cols = check_cols
-super().__init__(**kwargs)


 RESOURCE_TYPES: Dict[NodeType, Type[BaseConfig]] = {
@@ -644,9 +497,7 @@ RESOURCE_TYPES: Dict[NodeType, Type[BaseConfig]] = {
 # base resource types are like resource types, except nothing has mandatory
 # configs.
 BASE_RESOURCE_TYPES: Dict[NodeType, Type[BaseConfig]] = RESOURCE_TYPES.copy()
-BASE_RESOURCE_TYPES.update({
+BASE_RESOURCE_TYPES.update({NodeType.Snapshot: EmptySnapshotConfig})
-NodeType.Snapshot: EmptySnapshotConfig
-})


 def get_config_for(resource_type: NodeType, base=False) -> Type[BaseConfig]:
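
The NodeConfig change above moves the pre-hook/post-hook key remapping out of from_dict/to_dict overrides and into __pre_deserialize__/__post_serialize__ hooks that run around (de)serialization. What follows is a minimal standalone sketch of that pattern with hypothetical names; it does not use dbt's actual classes, and it wires the hooks by hand where dbtClassMixin would call them internally.

from dataclasses import dataclass, fields
from typing import Any, Dict

@dataclass
class HookedConfig:
    # hypothetical config with dashed keys accepted on input
    pre_hook: list = None
    post_hook: list = None

    _FIELD_MAP = {"pre-hook": "pre_hook", "post-hook": "post_hook"}

    @classmethod
    def __pre_deserialize__(cls, data: Dict[str, Any]) -> Dict[str, Any]:
        # copy so the caller's dict is not mutated, then remap dashed keys
        data = dict(data)
        for src, dst in cls._FIELD_MAP.items():
            if src in data:
                data[dst] = data.pop(src)
        return data

    def __post_serialize__(self, dct: Dict[str, Any]) -> Dict[str, Any]:
        # restore the dashed spelling on output
        for src, dst in self._FIELD_MAP.items():
            if dst in dct:
                dct[src] = dct.pop(dst)
        return dct

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "HookedConfig":
        data = cls.__pre_deserialize__(data)
        names = {f.name for f in fields(cls)}
        return cls(**{k: v for k, v in data.items() if k in names})

    def to_dict(self) -> Dict[str, Any]:
        dct = {f.name: getattr(self, f.name) for f in fields(self)}
        return self.__post_serialize__(dct)

if __name__ == "__main__":
    cfg = HookedConfig.from_dict({"pre-hook": ["x"], "post-hook": []})
    assert cfg.pre_hook == ["x"]
    assert "pre-hook" in cfg.to_dict()
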
@@ -13,17 +13,27 @@ from typing import (
 TypeVar,
 )

-from hologram import JsonSchemaMixin
+from dbt.dataclass_schema import dbtClassMixin, ExtensibleDbtClassMixin
-from hologram.helpers import ExtensibleJsonSchemaMixin

 from dbt.clients.system import write_file
 from dbt.contracts.files import FileHash, MAXIMUM_SEED_SIZE_NAME
 from dbt.contracts.graph.unparsed import (
-UnparsedNode, UnparsedDocumentation, Quoting, Docs,
+UnparsedNode,
-UnparsedBaseNode, FreshnessThreshold, ExternalTable,
+UnparsedDocumentation,
-HasYamlMetadata, MacroArgument, UnparsedSourceDefinition,
+Quoting,
-UnparsedSourceTableDefinition, UnparsedColumn, TestDef,
+Docs,
-ExposureOwner, ExposureType, MaturityType
+UnparsedBaseNode,
+FreshnessThreshold,
+ExternalTable,
+HasYamlMetadata,
+MacroArgument,
+UnparsedSourceDefinition,
+UnparsedSourceTableDefinition,
+UnparsedColumn,
+TestDef,
+ExposureOwner,
+ExposureType,
+MaturityType,
 )
 from dbt.contracts.util import Replaceable, AdditionalPropertiesMixin
 from dbt.exceptions import warn_or_error
@@ -38,24 +48,14 @@ from .model_config import (
 TestConfig,
 SourceConfig,
 EmptySnapshotConfig,
-SnapshotVariants,
+SnapshotConfig,
-)
-# import these 3 so the SnapshotVariants forward ref works.
-from .model_config import (  # noqa
-TimestampSnapshotConfig,
-CheckSnapshotConfig,
-GenericSnapshotConfig,
 )


 @dataclass
-class ColumnInfo(
+class ColumnInfo(AdditionalPropertiesMixin, ExtensibleDbtClassMixin, Replaceable):
-AdditionalPropertiesMixin,
-ExtensibleJsonSchemaMixin,
-Replaceable
-):
 name: str
-description: str = ''
+description: str = ""
 meta: Dict[str, Any] = field(default_factory=dict)
 data_type: Optional[str] = None
 quote: Optional[bool] = None
@@ -64,20 +64,20 @@ class ColumnInfo(


 @dataclass
-class HasFqn(JsonSchemaMixin, Replaceable):
+class HasFqn(dbtClassMixin, Replaceable):
 fqn: List[str]

-def same_fqn(self, other: 'HasFqn') -> bool:
+def same_fqn(self, other: "HasFqn") -> bool:
 return self.fqn == other.fqn


 @dataclass
-class HasUniqueID(JsonSchemaMixin, Replaceable):
+class HasUniqueID(dbtClassMixin, Replaceable):
 unique_id: str


 @dataclass
-class MacroDependsOn(JsonSchemaMixin, Replaceable):
+class MacroDependsOn(dbtClassMixin, Replaceable):
 macros: List[str] = field(default_factory=list)

 # 'in' on lists is O(n) so this is O(n^2) for # of macros
@@ -96,12 +96,22 @@ class DependsOn(MacroDependsOn):


 @dataclass
-class HasRelationMetadata(JsonSchemaMixin, Replaceable):
+class HasRelationMetadata(dbtClassMixin, Replaceable):
 database: Optional[str]
 schema: str

+# Can't set database to None like it ought to be
+# because it messes up the subclasses and default parameters
+# so hack it here
+@classmethod
+def __pre_deserialize__(cls, data):
+data = super().__pre_deserialize__(data)
+if "database" not in data:
+data["database"] = None
+return data

-class ParsedNodeMixins(JsonSchemaMixin):
+class ParsedNodeMixins(dbtClassMixin):
 resource_type: NodeType
 depends_on: DependsOn
 config: NodeConfig
@@ -112,7 +122,7 @@ class ParsedNodeMixins(JsonSchemaMixin):

 @property
 def is_ephemeral(self):
-return self.config.materialized == 'ephemeral'
+return self.config.materialized == "ephemeral"

 @property
 def is_ephemeral_model(self):
@@ -122,7 +132,7 @@ class ParsedNodeMixins(JsonSchemaMixin):
 def depends_on_nodes(self):
 return self.depends_on.nodes

-def patch(self, patch: 'ParsedNodePatch'):
+def patch(self, patch: "ParsedNodePatch"):
 """Given a ParsedNodePatch, add the new information to the node."""
 # explicitly pick out the parts to update so we don't inadvertently
 # step on the model name or anything
@@ -132,8 +142,12 @@ class ParsedNodeMixins(JsonSchemaMixin):
 self.meta = patch.meta
 self.docs = patch.docs
 if flags.STRICT_MODE:
-assert isinstance(self, JsonSchemaMixin)
+# It seems odd that an instance can be invalid
-self.to_dict(validate=True, omit_none=False)
+# Maybe there should be validation or restrictions
+# elsewhere?
+assert isinstance(self, dbtClassMixin)
+dct = self.to_dict(omit_none=False)
+self.validate(dct)

 def get_materialization(self):
 return self.config.materialized
@@ -144,11 +158,7 @@ class ParsedNodeMixins(JsonSchemaMixin):

 @dataclass
 class ParsedNodeMandatory(
-UnparsedNode,
+UnparsedNode, HasUniqueID, HasFqn, HasRelationMetadata, Replaceable
-HasUniqueID,
-HasFqn,
-HasRelationMetadata,
-Replaceable
 ):
 alias: str
 checksum: FileHash
@@ -165,7 +175,7 @@ class ParsedNodeDefaults(ParsedNodeMandatory):
 refs: List[List[str]] = field(default_factory=list)
 sources: List[List[Any]] = field(default_factory=list)
 depends_on: DependsOn = field(default_factory=DependsOn)
-description: str = field(default='')
+description: str = field(default="")
 columns: Dict[str, ColumnInfo] = field(default_factory=dict)
 meta: Dict[str, Any] = field(default_factory=dict)
 docs: Docs = field(default_factory=Docs)
@@ -175,31 +185,28 @@ class ParsedNodeDefaults(ParsedNodeMandatory):
 unrendered_config: Dict[str, Any] = field(default_factory=dict)

 def write_node(self, target_path: str, subdirectory: str, payload: str):
-if (os.path.basename(self.path) ==
+if os.path.basename(self.path) == os.path.basename(self.original_file_path):
-os.path.basename(self.original_file_path)):
 # One-to-one relationship of nodes to files.
 path = self.original_file_path
 else:
 # Many-to-one relationship of nodes to files.
 path = os.path.join(self.original_file_path, self.path)
-full_path = os.path.join(
+full_path = os.path.join(target_path, subdirectory, self.package_name, path)
-target_path, subdirectory, self.package_name, path
-)

 write_file(full_path, payload)
 return full_path


-T = TypeVar('T', bound='ParsedNode')
+T = TypeVar("T", bound="ParsedNode")


 @dataclass
 class ParsedNode(ParsedNodeDefaults, ParsedNodeMixins):
 def _persist_column_docs(self) -> bool:
-return bool(self.config.persist_docs.get('columns'))
+return bool(self.config.persist_docs.get("columns"))

 def _persist_relation_docs(self) -> bool:
-return bool(self.config.persist_docs.get('relation'))
+return bool(self.config.persist_docs.get("relation"))

 def same_body(self: T, other: T) -> bool:
 return self.raw_sql == other.raw_sql
@@ -214,9 +221,7 @@ class ParsedNode(ParsedNodeDefaults, ParsedNodeMixins):

 if self._persist_column_docs():
 # assert other._persist_column_docs()
-column_descriptions = {
+column_descriptions = {k: v.description for k, v in self.columns.items()}
-k: v.description for k, v in self.columns.items()
-}
 other_column_descriptions = {
 k: v.description for k, v in other.columns.items()
 }
@@ -230,7 +235,7 @@ class ParsedNode(ParsedNodeDefaults, ParsedNodeMixins):
 # compares the configured value, rather than the ultimate value (so
 # generate_*_name and unset values derived from the target are
 # ignored)
-keys = ('database', 'schema', 'alias')
+keys = ("database", "schema", "alias")
 for key in keys:
 mine = self.unrendered_config.get(key)
 others = other.unrendered_config.get(key)
@@ -249,36 +254,34 @@ class ParsedNode(ParsedNodeDefaults, ParsedNodeMixins):
 return False

 return (
-self.same_body(old) and
+self.same_body(old)
-self.same_config(old) and
+and self.same_config(old)
-self.same_persisted_description(old) and
+and self.same_persisted_description(old)
-self.same_fqn(old) and
+and self.same_fqn(old)
-self.same_database_representation(old) and
+and self.same_database_representation(old)
-True
+and True
 )


 @dataclass
 class ParsedAnalysisNode(ParsedNode):
-resource_type: NodeType = field(metadata={'restrict': [NodeType.Analysis]})
+resource_type: NodeType = field(metadata={"restrict": [NodeType.Analysis]})


 @dataclass
 class ParsedHookNode(ParsedNode):
-resource_type: NodeType = field(
+resource_type: NodeType = field(metadata={"restrict": [NodeType.Operation]})
-metadata={'restrict': [NodeType.Operation]}
-)
 index: Optional[int] = None


 @dataclass
 class ParsedModelNode(ParsedNode):
-resource_type: NodeType = field(metadata={'restrict': [NodeType.Model]})
+resource_type: NodeType = field(metadata={"restrict": [NodeType.Model]})


 @dataclass
 class ParsedRPCNode(ParsedNode):
-resource_type: NodeType = field(metadata={'restrict': [NodeType.RPCCall]})
+resource_type: NodeType = field(metadata={"restrict": [NodeType.RPCCall]})


 def same_seeds(first: ParsedNode, second: ParsedNode) -> bool:
@@ -288,31 +291,31 @@ def same_seeds(first: ParsedNode, second: ParsedNode) -> bool:
 # if the current checksum is a path, we want to log a warning.
 result = first.checksum == second.checksum

-if first.checksum.name == 'path':
+if first.checksum.name == "path":
 msg: str
-if second.checksum.name != 'path':
+if second.checksum.name != "path":
 msg = (
-f'Found a seed ({first.package_name}.{first.name}) '
+f"Found a seed ({first.package_name}.{first.name}) "
-f'>{MAXIMUM_SEED_SIZE_NAME} in size. The previous file was '
+f">{MAXIMUM_SEED_SIZE_NAME} in size. The previous file was "
-f'<={MAXIMUM_SEED_SIZE_NAME}, so it has changed'
+f"<={MAXIMUM_SEED_SIZE_NAME}, so it has changed"
 )
 elif result:
 msg = (
-f'Found a seed ({first.package_name}.{first.name}) '
+f"Found a seed ({first.package_name}.{first.name}) "
-f'>{MAXIMUM_SEED_SIZE_NAME} in size at the same path, dbt '
+f">{MAXIMUM_SEED_SIZE_NAME} in size at the same path, dbt "
-f'cannot tell if it has changed: assuming they are the same'
+f"cannot tell if it has changed: assuming they are the same"
 )
 elif not result:
 msg = (
-f'Found a seed ({first.package_name}.{first.name}) '
+f"Found a seed ({first.package_name}.{first.name}) "
-f'>{MAXIMUM_SEED_SIZE_NAME} in size. The previous file was in '
+f">{MAXIMUM_SEED_SIZE_NAME} in size. The previous file was in "
-f'a different location, assuming it has changed'
+f"a different location, assuming it has changed"
 )
 else:
 msg = (
-f'Found a seed ({first.package_name}.{first.name}) '
+f"Found a seed ({first.package_name}.{first.name}) "
-f'>{MAXIMUM_SEED_SIZE_NAME} in size. The previous file had a '
+f">{MAXIMUM_SEED_SIZE_NAME} in size. The previous file had a "
-f'checksum type of {second.checksum.name}, so it has changed'
+f"checksum type of {second.checksum.name}, so it has changed"
 )
 warn_or_error(msg, node=first)

@@ -322,7 +325,7 @@ def same_seeds(first: ParsedNode, second: ParsedNode) -> bool:
 @dataclass
 class ParsedSeedNode(ParsedNode):
 # keep this in sync with CompiledSeedNode!
-resource_type: NodeType = field(metadata={'restrict': [NodeType.Seed]})
+resource_type: NodeType = field(metadata={"restrict": [NodeType.Seed]})
 config: SeedConfig = field(default_factory=SeedConfig)

 @property
@@ -335,34 +338,33 @@ class ParsedSeedNode(ParsedNode):


 @dataclass
-class TestMetadata(JsonSchemaMixin, Replaceable):
+class TestMetadata(dbtClassMixin, Replaceable):
-namespace: Optional[str]
 name: str
-kwargs: Dict[str, Any]
+kwargs: Dict[str, Any] = field(default_factory=dict)
+namespace: Optional[str] = None


 @dataclass
-class HasTestMetadata(JsonSchemaMixin):
+class HasTestMetadata(dbtClassMixin):
 test_metadata: TestMetadata


 @dataclass
 class ParsedDataTestNode(ParsedNode):
-resource_type: NodeType = field(metadata={'restrict': [NodeType.Test]})
+resource_type: NodeType = field(metadata={"restrict": [NodeType.Test]})
 config: TestConfig = field(default_factory=TestConfig)


 @dataclass
 class ParsedSchemaTestNode(ParsedNode, HasTestMetadata):
 # keep this in sync with CompiledSchemaTestNode!
-resource_type: NodeType = field(metadata={'restrict': [NodeType.Test]})
+resource_type: NodeType = field(metadata={"restrict": [NodeType.Test]})
 column_name: Optional[str] = None
 config: TestConfig = field(default_factory=TestConfig)

 def same_config(self, other) -> bool:
-return (
+return self.unrendered_config.get("severity") == other.unrendered_config.get(
-self.unrendered_config.get('severity') ==
+"severity"
-other.unrendered_config.get('severity')
 )

 def same_column_name(self, other) -> bool:
@@ -372,11 +374,7 @@ class ParsedSchemaTestNode(ParsedNode, HasTestMetadata):
 if other is None:
 return False

-return (
+return self.same_config(other) and self.same_fqn(other) and True
-self.same_config(other) and
-self.same_fqn(other) and
-True
-)


 @dataclass
@@ -387,14 +385,14 @@ class IntermediateSnapshotNode(ParsedNode):
 # defined in config blocks. To fix that, we have an intermediate type that
 # uses a regular node config, which the snapshot parser will then convert
 # into a full ParsedSnapshotNode after rendering.
-resource_type: NodeType = field(metadata={'restrict': [NodeType.Snapshot]})
+resource_type: NodeType = field(metadata={"restrict": [NodeType.Snapshot]})
 config: EmptySnapshotConfig = field(default_factory=EmptySnapshotConfig)


 @dataclass
 class ParsedSnapshotNode(ParsedNode):
-resource_type: NodeType = field(metadata={'restrict': [NodeType.Snapshot]})
+resource_type: NodeType = field(metadata={"restrict": [NodeType.Snapshot]})
-config: SnapshotVariants
+config: SnapshotConfig


 @dataclass
@@ -422,12 +420,12 @@ class ParsedMacroPatch(ParsedPatch):
 class ParsedMacro(UnparsedBaseNode, HasUniqueID):
 name: str
 macro_sql: str
-resource_type: NodeType = field(metadata={'restrict': [NodeType.Macro]})
+resource_type: NodeType = field(metadata={"restrict": [NodeType.Macro]})
 # TODO: can macros even have tags?
 tags: List[str] = field(default_factory=list)
 # TODO: is this ever populated?
 depends_on: MacroDependsOn = field(default_factory=MacroDependsOn)
-description: str = ''
+description: str = ""
 meta: Dict[str, Any] = field(default_factory=dict)
 docs: Docs = field(default_factory=Docs)
 patch_path: Optional[str] = None
@@ -443,10 +441,12 @@ class ParsedMacro(UnparsedBaseNode, HasUniqueID):
 self.docs = patch.docs
 self.arguments = patch.arguments
 if flags.STRICT_MODE:
-assert isinstance(self, JsonSchemaMixin)
+# What does this actually validate?
-self.to_dict(validate=True, omit_none=False)
+assert isinstance(self, dbtClassMixin)
+dct = self.to_dict(omit_none=False)
+self.validate(dct)

-def same_contents(self, other: Optional['ParsedMacro']) -> bool:
+def same_contents(self, other: Optional["ParsedMacro"]) -> bool:
 if other is None:
 return False
 # the only thing that makes one macro different from another with the
@@ -463,7 +463,7 @@ class ParsedDocumentation(UnparsedDocumentation, HasUniqueID):
 def search_name(self):
 return self.name

-def same_contents(self, other: Optional['ParsedDocumentation']) -> bool:
+def same_contents(self, other: Optional["ParsedDocumentation"]) -> bool:
 if other is None:
 return False
 # the only thing that makes one doc different from another with the
@@ -482,11 +482,11 @@ def normalize_test(testdef: TestDef) -> Dict[str, Any]:
 class UnpatchedSourceDefinition(UnparsedBaseNode, HasUniqueID, HasFqn):
 source: UnparsedSourceDefinition
 table: UnparsedSourceTableDefinition
-resource_type: NodeType = field(metadata={'restrict': [NodeType.Source]})
+resource_type: NodeType = field(metadata={"restrict": [NodeType.Source]})
 patch_path: Optional[Path] = None

 def get_full_source_name(self):
-return f'{self.source.name}_{self.table.name}'
+return f"{self.source.name}_{self.table.name}"

 def get_source_representation(self):
 return f'source("{self.source.name}", "{self.table.name}")'
@@ -511,9 +511,7 @@ class UnpatchedSourceDefinition(UnparsedBaseNode, HasUniqueID, HasFqn):
 else:
 return self.table.columns

-def get_tests(
+def get_tests(self) -> Iterator[Tuple[Dict[str, Any], Optional[UnparsedColumn]]]:
-self
-) -> Iterator[Tuple[Dict[str, Any], Optional[UnparsedColumn]]]:
 for test in self.tests:
 yield normalize_test(test), None

@@ -532,22 +530,19 @@ class UnpatchedSourceDefinition(UnparsedBaseNode, HasUniqueID, HasFqn):

 @dataclass
 class ParsedSourceDefinition(
-UnparsedBaseNode,
+UnparsedBaseNode, HasUniqueID, HasRelationMetadata, HasFqn
-HasUniqueID,
-HasRelationMetadata,
-HasFqn
 ):
 name: str
 source_name: str
 source_description: str
 loader: str
 identifier: str
-resource_type: NodeType = field(metadata={'restrict': [NodeType.Source]})
+resource_type: NodeType = field(metadata={"restrict": [NodeType.Source]})
 quoting: Quoting = field(default_factory=Quoting)
 loaded_at_field: Optional[str] = None
 freshness: Optional[FreshnessThreshold] = None
 external: Optional[ExternalTable] = None
-description: str = ''
+description: str = ""
 columns: Dict[str, ColumnInfo] = field(default_factory=dict)
 meta: Dict[str, Any] = field(default_factory=dict)
 source_meta: Dict[str, Any] = field(default_factory=dict)
@@ -557,36 +552,34 @@ class ParsedSourceDefinition(
 unrendered_config: Dict[str, Any] = field(default_factory=dict)
 relation_name: Optional[str] = None

-def same_database_representation(
+def same_database_representation(self, other: "ParsedSourceDefinition") -> bool:
-self, other: 'ParsedSourceDefinition'
-) -> bool:
 return (
-self.database == other.database and
+self.database == other.database
-self.schema == other.schema and
+and self.schema == other.schema
-self.identifier == other.identifier and
+and self.identifier == other.identifier
-True
+and True
 )

-def same_quoting(self, other: 'ParsedSourceDefinition') -> bool:
+def same_quoting(self, other: "ParsedSourceDefinition") -> bool:
 return self.quoting == other.quoting

-def same_freshness(self, other: 'ParsedSourceDefinition') -> bool:
+def same_freshness(self, other: "ParsedSourceDefinition") -> bool:
 return (
-self.freshness == other.freshness and
+self.freshness == other.freshness
-self.loaded_at_field == other.loaded_at_field and
+and self.loaded_at_field == other.loaded_at_field
-True
+and True
 )

-def same_external(self, other: 'ParsedSourceDefinition') -> bool:
+def same_external(self, other: "ParsedSourceDefinition") -> bool:
 return self.external == other.external

-def same_config(self, old: 'ParsedSourceDefinition') -> bool:
+def same_config(self, old: "ParsedSourceDefinition") -> bool:
 return self.config.same_contents(
 self.unrendered_config,
 old.unrendered_config,
 )

-def same_contents(self, old: Optional['ParsedSourceDefinition']) -> bool:
+def same_contents(self, old: Optional["ParsedSourceDefinition"]) -> bool:
 # existing when it didn't before is a change!
 if old is None:
 return True
@@ -600,17 +593,17 @@ class ParsedSourceDefinition(
 # metadata/tags changes are not "changes"
 # patching/description changes are not "changes"
 return (
-self.same_database_representation(old) and
+self.same_database_representation(old)
-self.same_fqn(old) and
+and self.same_fqn(old)
-self.same_config(old) and
+and self.same_config(old)
-self.same_quoting(old) and
+and self.same_quoting(old)
-self.same_freshness(old) and
+and self.same_freshness(old)
-self.same_external(old) and
+and self.same_external(old)
-True
+and True
 )

 def get_full_source_name(self):
-return f'{self.source_name}_{self.name}'
+return f"{self.source_name}_{self.name}"

 def get_source_representation(self):
 return f'source("{self.source.name}", "{self.table.name}")'
@@ -645,7 +638,7 @@ class ParsedSourceDefinition(

 @property
 def search_name(self):
-return f'{self.source_name}.{self.name}'
+return f"{self.source_name}.{self.name}"


 @dataclass
@@ -654,7 +647,7 @@ class ParsedExposure(UnparsedBaseNode, HasUniqueID, HasFqn):
 type: ExposureType
 owner: ExposureOwner
 resource_type: NodeType = NodeType.Exposure
-description: str = ''
+description: str = ""
 maturity: Optional[MaturityType] = None
 url: Optional[str] = None
 depends_on: DependsOn = field(default_factory=DependsOn)
@@ -674,38 +667,38 @@ class ParsedExposure(UnparsedBaseNode, HasUniqueID, HasFqn):
 def tags(self):
 return []

-def same_depends_on(self, old: 'ParsedExposure') -> bool:
+def same_depends_on(self, old: "ParsedExposure") -> bool:
 return set(self.depends_on.nodes) == set(old.depends_on.nodes)

-def same_description(self, old: 'ParsedExposure') -> bool:
+def same_description(self, old: "ParsedExposure") -> bool:
 return self.description == old.description

-def same_maturity(self, old: 'ParsedExposure') -> bool:
+def same_maturity(self, old: "ParsedExposure") -> bool:
 return self.maturity == old.maturity

-def same_owner(self, old: 'ParsedExposure') -> bool:
+def same_owner(self, old: "ParsedExposure") -> bool:
 return self.owner == old.owner

-def same_exposure_type(self, old: 'ParsedExposure') -> bool:
+def same_exposure_type(self, old: "ParsedExposure") -> bool:
 return self.type == old.type

-def same_url(self, old: 'ParsedExposure') -> bool:
+def same_url(self, old: "ParsedExposure") -> bool:
 return self.url == old.url

-def same_contents(self, old: Optional['ParsedExposure']) -> bool:
+def same_contents(self, old: Optional["ParsedExposure"]) -> bool:
 # existing when it didn't before is a change!
 if old is None:
 return True

 return (
-self.same_fqn(old) and
+self.same_fqn(old)
-self.same_exposure_type(old) and
+and self.same_exposure_type(old)
-self.same_owner(old) and
+and self.same_owner(old)
-self.same_maturity(old) and
+and self.same_maturity(old)
-self.same_url(old) and
+and self.same_url(old)
-self.same_description(old) and
+and self.same_description(old)
-self.same_depends_on(old) and
+and self.same_depends_on(old)
-True
+and True
 )

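
The new SnapshotConfig.validate above replaces the jsonschema-based SnapshotWrapper with plain dictionary checks on the config. A minimal standalone sketch of those rules follows; the helper name is hypothetical and this is not dbt's API, only an illustration of the same checks.

from typing import Any, Dict

def validate_snapshot_config(data: Dict[str, Any]) -> None:
    # "check" snapshots need check_cols ("all" or a list); "timestamp"
    # snapshots need updated_at and must not set check_cols; any other
    # strategy is treated as a custom strategy and passes through.
    strategy = data.get("strategy")
    if strategy == "check":
        check_cols = data.get("check_cols")
        if not check_cols:
            raise ValueError("check strategy requires check_cols")
        if isinstance(check_cols, str) and check_cols != "all":
            raise ValueError("check_cols must be 'all' or a list of strings")
    elif strategy == "timestamp":
        if not data.get("updated_at"):
            raise ValueError("timestamp strategy requires updated_at")
        if data.get("check_cols"):
            raise ValueError("timestamp strategy should not set check_cols")

validate_snapshot_config({"strategy": "timestamp", "updated_at": "updated_at"})
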
@@ -4,12 +4,12 @@ from dbt.contracts.util import (
-from hologram import JsonSchemaMixin
-from hologram.helpers import StrEnum, ExtensibleJsonSchemaMixin
+from dbt.dataclass_schema import dbtClassMixin, StrEnum, ExtensibleDbtClassMixin

@@ -18,7 +18,7 @@ from typing import Optional, List, Union, Dict, Any, Sequence
-class UnparsedBaseNode(JsonSchemaMixin, Replaceable):
+class UnparsedBaseNode(dbtClassMixin, Replaceable):

@@ -36,21 +36,25 @@ class HasSQL
 class UnparsedMacro(UnparsedBaseNode, HasSQL):
-    resource_type: NodeType = field(metadata={'restrict': [NodeType.Macro]})
+    resource_type: NodeType = field(metadata={"restrict": [NodeType.Macro]})

 class UnparsedNode(UnparsedBaseNode, HasSQL):
     name: str
-    resource_type: NodeType = field(metadata={'restrict': [
-        NodeType.Model,
-        NodeType.Analysis,
-        NodeType.Test,
-        NodeType.Snapshot,
-        NodeType.Operation,
-        NodeType.Seed,
-        NodeType.RPCCall,
-    ]})
+    resource_type: NodeType = field(
+        metadata={
+            "restrict": [
+                NodeType.Model,
+                NodeType.Analysis,
+                NodeType.Test,
+                NodeType.Snapshot,
+                NodeType.Operation,
+                NodeType.Seed,
+                NodeType.RPCCall,
+            ]
+        }
+    )

@@ -59,22 +63,19 @@ class UnparsedNode(UnparsedBaseNode, HasSQL):
 class UnparsedRunHook(UnparsedNode):
-    resource_type: NodeType = field(
-        metadata={'restrict': [NodeType.Operation]}
-    )
+    resource_type: NodeType = field(metadata={"restrict": [NodeType.Operation]})
     index: Optional[int] = None

-class Docs(JsonSchemaMixin, Replaceable):
+class Docs(dbtClassMixin, Replaceable):
     show: bool = True

-class HasDocs(AdditionalPropertiesMixin, ExtensibleJsonSchemaMixin,
-              Replaceable):
+class HasDocs(AdditionalPropertiesMixin, ExtensibleDbtClassMixin, Replaceable):
     name: str
-    description: str = ''
+    description: str = ""

@@ -100,7 +101,7 @@ class UnparsedColumn(HasTests):
-class HasColumnDocs(JsonSchemaMixin, Replaceable):
+class HasColumnDocs(dbtClassMixin, Replaceable):

@@ -110,7 +111,7 @@ class HasColumnTests(HasColumnDocs):
-class HasYamlMetadata(JsonSchemaMixin):
+class HasYamlMetadata(dbtClassMixin):

@@ -127,10 +128,10 @@ class UnparsedNodeUpdate(HasColumnTests, HasTests, HasYamlMetadata):
-class MacroArgument(JsonSchemaMixin):
+class MacroArgument(dbtClassMixin):
     name: str
     type: Optional[str] = None
-    description: str = ''
+    description: str = ""

@@ -139,16 +140,16 @@ class UnparsedMacroUpdate(HasDocs, HasYamlMetadata):
 class TimePeriod(StrEnum):
-    minute = 'minute'
-    hour = 'hour'
-    day = 'day'
+    minute = "minute"
+    hour = "hour"
+    day = "day"

     def plural(self) -> str:
-        return str(self) + 's'
+        return str(self) + "s"

-class Time(JsonSchemaMixin, Replaceable):
+class Time(dbtClassMixin, Replaceable):

@@ -159,13 +160,14 @@ class Time(JsonSchemaMixin, Replaceable):
-class FreshnessThreshold(JsonSchemaMixin, Mergeable):
+class FreshnessThreshold(dbtClassMixin, Mergeable):

@@ -178,24 +180,21 @@ class FreshnessThreshold(JsonSchemaMixin, Mergeable):
-class AdditionalPropertiesAllowed(
-    AdditionalPropertiesMixin,
-    ExtensibleJsonSchemaMixin
-):
+class AdditionalPropertiesAllowed(AdditionalPropertiesMixin, ExtensibleDbtClassMixin):
     _extra: Dict[str, Any] = field(default_factory=dict)

 class ExternalPartition(AdditionalPropertiesAllowed, Replaceable):
-    name: str = ''
-    description: str = ''
-    data_type: str = ''
+    name: str = ""
+    description: str = ""
+    data_type: str = ""
     meta: Dict[str, Any] = field(default_factory=dict)

     def __post_init__(self):
-        if self.name == '' or self.data_type == '':
+        if self.name == "" or self.data_type == "":
             raise CompilationException(
-                'External partition columns must have names and data types'
+                "External partition columns must have names and data types"
             )

@@ -212,7 +211,7 @@ class ExternalTable(AdditionalPropertiesAllowed, Mergeable):
-class Quoting(JsonSchemaMixin, Mergeable):
+class Quoting(dbtClassMixin, Mergeable):

@@ -224,48 +223,44 @@ class UnparsedSourceTableDefinition(HasColumnTests, HasTests):
     quoting: Quoting = field(default_factory=Quoting)
-    freshness: Optional[FreshnessThreshold] = field(
-        default_factory=FreshnessThreshold
-    )
+    freshness: Optional[FreshnessThreshold] = field(default_factory=FreshnessThreshold)
     external: Optional[ExternalTable] = None
     tags: List[str] = field(default_factory=list)

-    def to_dict(self, omit_none=True, validate=False):
-        result = super().to_dict(omit_none=omit_none, validate=validate)
-        if omit_none and self.freshness is None:
-            result['freshness'] = None
-        return result
+    def __post_serialize__(self, dct):
+        dct = super().__post_serialize__(dct)
+        if "freshness" not in dct and self.freshness is None:
+            dct["freshness"] = None
+        return dct

-class UnparsedSourceDefinition(JsonSchemaMixin, Replaceable):
+class UnparsedSourceDefinition(dbtClassMixin, Replaceable):
     name: str
-    description: str = ''
+    description: str = ""
     meta: Dict[str, Any] = field(default_factory=dict)
     database: Optional[str] = None
     schema: Optional[str] = None
-    loader: str = ''
+    loader: str = ""
     quoting: Quoting = field(default_factory=Quoting)
-    freshness: Optional[FreshnessThreshold] = field(
-        default_factory=FreshnessThreshold
-    )
+    freshness: Optional[FreshnessThreshold] = field(default_factory=FreshnessThreshold)

     @property
-    def yaml_key(self) -> 'str':
-        return 'sources'
+    def yaml_key(self) -> "str":
+        return "sources"

-    def to_dict(self, omit_none=True, validate=False):
-        result = super().to_dict(omit_none=omit_none, validate=validate)
-        if omit_none and self.freshness is None:
-            result['freshness'] = None
-        return result
+    def __post_serialize__(self, dct):
+        dct = super().__post_serialize__(dct)
+        if "freshness" not in dct and self.freshness is None:
+            dct["freshness"] = None
+        return dct

-class SourceTablePatch(JsonSchemaMixin):
+class SourceTablePatch(dbtClassMixin):

@@ -274,9 +269,7 @@ class SourceTablePatch(JsonSchemaMixin):
     quoting: Quoting = field(default_factory=Quoting)
-    freshness: Optional[FreshnessThreshold] = field(
-        default_factory=FreshnessThreshold
-    )
+    freshness: Optional[FreshnessThreshold] = field(default_factory=FreshnessThreshold)

@@ -284,27 +277,27 @@ class SourceTablePatch(JsonSchemaMixin):
     def to_patch_dict(self) -> Dict[str, Any]:
         dct = self.to_dict(omit_none=True)
-        remove_keys = ('name')
+        remove_keys = "name"
         if self.freshness is None:
-            dct['freshness'] = None
+            dct["freshness"] = None

-class SourcePatch(JsonSchemaMixin, Replaceable):
+class SourcePatch(dbtClassMixin, Replaceable):
     name: str = field(
-        metadata=dict(description='The name of the source to override'),
+        metadata=dict(description="The name of the source to override"),
     )
     overrides: str = field(
-        metadata=dict(description='The package of the source to override'),
+        metadata=dict(description="The package of the source to override"),
     )
     path: Path = field(
-        metadata=dict(description='The path to the patch-defining yml file'),
+        metadata=dict(description="The path to the patch-defining yml file"),
     )

@@ -321,13 +314,13 @@ class SourcePatch(JsonSchemaMixin, Replaceable):
         dct = self.to_dict(omit_none=True)
-        remove_keys = ('name', 'overrides', 'tables', 'path')
+        remove_keys = ("name", "overrides", "tables", "path")
         if self.freshness is None:
-            dct['freshness'] = None
+            dct["freshness"] = None

@@ -340,7 +333,7 @@ class SourcePatch(JsonSchemaMixin, Replaceable):
-class UnparsedDocumentation(JsonSchemaMixin, Replaceable):
+class UnparsedDocumentation(dbtClassMixin, Replaceable):

@@ -359,9 +352,9 @@ class UnparsedDocumentationFile(UnparsedDocumentation):
 class Maturity(StrEnum):
-    low = 'low'
-    medium = 'medium'
-    high = 'high'
+    low = "low"
+    medium = "medium"
+    high = "high"

@@ -386,31 +379,31 @@ class Maturity(StrEnum):
 class ExposureType(StrEnum):
-    Dashboard = 'dashboard'
-    Notebook = 'notebook'
-    Analysis = 'analysis'
-    ML = 'ml'
-    Application = 'application'
+    Dashboard = "dashboard"
+    Notebook = "notebook"
+    Analysis = "analysis"
+    ML = "ml"
+    Application = "application"

 class MaturityType(StrEnum):
-    Low = 'low'
-    Medium = 'medium'
-    High = 'high'
+    Low = "low"
+    Medium = "medium"
+    High = "high"

-class ExposureOwner(JsonSchemaMixin, Replaceable):
+class ExposureOwner(dbtClassMixin, Replaceable):

-class UnparsedExposure(JsonSchemaMixin, Replaceable):
+class UnparsedExposure(dbtClassMixin, Replaceable):
     name: str
     type: ExposureType
     owner: ExposureOwner
-    description: str = ''
+    description: str = ""
     maturity: Optional[MaturityType] = None
     url: Optional[str] = None
     depends_on: List[str] = field(default_factory=list)
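Note on the serialization change above: the old hologram-era to_dict overrides become __post_serialize__ hooks. The following is a minimal, self-contained sketch of that hook pattern; the base-class wiring shown here (a to_dict that calls __post_serialize__) is an illustrative assumption, not the actual dbt.dataclass_schema implementation.

from dataclasses import dataclass
from typing import Any, Dict, Optional


@dataclass
class ExampleSourceTable:
    name: str
    freshness: Optional[Dict[str, Any]] = None

    def to_dict(self, omit_none: bool = True) -> Dict[str, Any]:
        # toy serializer: drop None fields, then run the post-serialize hook
        dct = {k: v for k, v in self.__dict__.items() if not (omit_none and v is None)}
        return self.__post_serialize__(dct)

    def __post_serialize__(self, dct: Dict[str, Any]) -> Dict[str, Any]:
        # keep an explicit "freshness": None so consumers can tell unset from missing
        if "freshness" not in dct and self.freshness is None:
            dct["freshness"] = None
        return dct


print(ExampleSourceTable(name="events").to_dict())  # {'name': 'events', 'freshness': None}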
@@ -4,41 +4,57 @@ from dbt.helper_types import NoValue
-from hologram import JsonSchemaMixin, ValidationError
-from hologram.helpers import HyphenatedJsonSchemaMixin, register_pattern, \
-    ExtensibleJsonSchemaMixin
+from dbt.dataclass_schema import (
+    dbtClassMixin,
+    ValidationError,
+    HyphenatedDbtClassMixin,
+    ExtensibleDbtClassMixin,
+    register_pattern,
+    ValidatedStringMixin,
+)
 from dataclasses import dataclass, field
-from typing import Optional, List, Dict, Union, Any, NewType
+from typing import Optional, List, Dict, Union, Any
+from mashumaro.types import SerializableType

-PIN_PACKAGE_URL = 'https://docs.getdbt.com/docs/package-management#section-specifying-package-versions'  # noqa
+PIN_PACKAGE_URL = "https://docs.getdbt.com/docs/package-management#section-specifying-package-versions"  # noqa
 DEFAULT_SEND_ANONYMOUS_USAGE_STATS = True

-Name = NewType('Name', str)
-register_pattern(Name, r'^[^\d\W]\w*$')
+class Name(ValidatedStringMixin):
+    ValidationRegex = r"^[^\d\W]\w*$"
+
+
+register_pattern(Name, r"^[^\d\W]\w*$")
+
+
+class SemverString(str, SerializableType):
+    def _serialize(self) -> str:
+        return self
+
+    @classmethod
+    def _deserialize(cls, value: str) -> "SemverString":
+        return SemverString(value)

 # this does not support the full semver (does not allow a trailing -fooXYZ) and
 # is not restrictive enough for full semver, (allows '1.0'). But it's like
 # 'semver lite'.
-SemverString = NewType('SemverString', str)
 register_pattern(
     SemverString,
-    r'^(?:0|[1-9]\d*)\.(?:0|[1-9]\d*)(\.(?:0|[1-9]\d*))?$',
+    r"^(?:0|[1-9]\d*)\.(?:0|[1-9]\d*)(\.(?:0|[1-9]\d*))?$",
 )

-class Quoting(JsonSchemaMixin, Mergeable):
-    identifier: Optional[bool]
-    schema: Optional[bool]
-    database: Optional[bool]
-    project: Optional[bool]
+class Quoting(dbtClassMixin, Mergeable):
+    schema: Optional[bool] = None
+    database: Optional[bool] = None
+    project: Optional[bool] = None
+    identifier: Optional[bool] = None

-class Package(Replaceable, HyphenatedJsonSchemaMixin):
+class Package(Replaceable, HyphenatedDbtClassMixin):
     pass

@@ -54,7 +70,7 @@ RawVersion = Union[str, float]
 class GitPackage(Package):
     git: str
-    revision: Optional[RawVersion]
+    revision: Optional[RawVersion] = None
     warn_unpinned: Optional[bool] = None

@@ -80,7 +96,7 @@ PackageSpec = Union[LocalPackage, GitPackage, RegistryPackage]
-class PackageConfig(JsonSchemaMixin, Replaceable):
+class PackageConfig(dbtClassMixin, Replaceable):
     packages: List[PackageSpec]

@@ -91,18 +107,17 @@ class ProjectPackageMetadata:
     @classmethod
     def from_project(cls, project):
-        return cls(name=project.project_name,
-                   packages=project.packages.packages)
+        return cls(name=project.project_name, packages=project.packages.packages)

-class Downloads(ExtensibleJsonSchemaMixin, Replaceable):
+class Downloads(ExtensibleDbtClassMixin, Replaceable):
     tarball: str

 class RegistryPackageMetadata(
-    ExtensibleJsonSchemaMixin,
+    ExtensibleDbtClassMixin,
     ProjectPackageMetadata,
 ):
     downloads: Downloads

@@ -110,51 +125,51 @@ class RegistryPackageMetadata(
 # A list of all the reserved words that packages may not have as names.
 BANNED_PROJECT_NAMES = {
-    '_sql_results',
-    'adapter',
-    'api',
-    'column',
-    'config',
-    'context',
-    'database',
-    'env',
-    'env_var',
-    'exceptions',
-    'execute',
-    'flags',
-    'fromjson',
-    'fromyaml',
-    'graph',
-    'invocation_id',
-    'load_agate_table',
-    'load_result',
-    'log',
-    'model',
-    'modules',
-    'post_hooks',
-    'pre_hooks',
-    'ref',
-    'render',
-    'return',
-    'run_started_at',
-    'schema',
-    'source',
-    'sql',
-    'sql_now',
-    'store_result',
-    'store_raw_result',
-    'target',
-    'this',
-    'tojson',
-    'toyaml',
-    'try_or_compiler_error',
-    'var',
-    'write',
+    "_sql_results",
+    "adapter",
+    "api",
+    "column",
+    "config",
+    "context",
+    "database",
+    "env",
+    "env_var",
+    "exceptions",
+    "execute",
+    "flags",
+    "fromjson",
+    "fromyaml",
+    "graph",
+    "invocation_id",
+    "load_agate_table",
+    "load_result",
+    "log",
+    "model",
+    "modules",
+    "post_hooks",
+    "pre_hooks",
+    "ref",
+    "render",
+    "return",
+    "run_started_at",
+    "schema",
+    "source",
+    "sql",
+    "sql_now",
+    "store_result",
+    "store_raw_result",
+    "target",
+    "this",
+    "tojson",
+    "toyaml",
+    "try_or_compiler_error",
+    "var",
+    "write",
 }

-class Project(HyphenatedJsonSchemaMixin, Replaceable):
+class Project(HyphenatedDbtClassMixin, Replaceable):
     name: Name
     version: Union[SemverString, float]
     config_version: int

@@ -184,25 +199,23 @@ class Project(HyphenatedJsonSchemaMixin, Replaceable):
     vars: Optional[Dict[str, Any]] = field(
         default=None,
         metadata=dict(
-            description='map project names to their vars override dicts',
+            description="map project names to their vars override dicts",
         ),
     )
     packages: List[PackageSpec] = field(default_factory=list)
     query_comment: Optional[Union[QueryComment, NoValue, str]] = NoValue()

     @classmethod
-    def from_dict(cls, data, validate=True) -> 'Project':
-        result = super().from_dict(data, validate=validate)
-        if result.name in BANNED_PROJECT_NAMES:
+    def validate(cls, data):
+        super().validate(data)
+        if data["name"] in BANNED_PROJECT_NAMES:
             raise ValidationError(
-                f'Invalid project name: {result.name} is a reserved word'
+                f"Invalid project name: {data['name']} is a reserved word"
             )
-
-        return result

-class UserConfig(ExtensibleJsonSchemaMixin, Replaceable, UserConfigContract):
+class UserConfig(ExtensibleDbtClassMixin, Replaceable, UserConfigContract):

@@ -222,9 +235,9 @@ class UserConfig(ExtensibleJsonSchemaMixin, Replaceable, UserConfigContract):
-class ProfileConfig(HyphenatedJsonSchemaMixin, Replaceable):
-    profile_name: str = field(metadata={'preserve_underscore': True})
-    target_name: str = field(metadata={'preserve_underscore': True})
+class ProfileConfig(HyphenatedDbtClassMixin, Replaceable):
+    profile_name: str = field(metadata={"preserve_underscore": True})
+    target_name: str = field(metadata={"preserve_underscore": True})
     config: UserConfig
     threads: int

@@ -233,21 +246,21 @@ class ProfileConfig(HyphenatedJsonSchemaMixin, Replaceable):
 class ConfiguredQuoting(Quoting, Replaceable):
-    identifier: bool
-    schema: bool
-    database: Optional[bool]
-    project: Optional[bool]
+    identifier: bool = True
+    schema: bool = True
+    database: Optional[bool] = None
+    project: Optional[bool] = None

 class Configuration(Project, ProfileConfig):
     cli_vars: Dict[str, Any] = field(
         default_factory=dict,
-        metadata={'preserve_underscore': True},
+        metadata={"preserve_underscore": True},
     )
     quoting: Optional[ConfiguredQuoting] = None

-class ProjectList(JsonSchemaMixin):
+class ProjectList(dbtClassMixin):
     projects: Dict[str, Project]
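The Name change above replaces a NewType plus register_pattern with a ValidatedStringMixin subclass carrying a ValidationRegex. Below is a rough stand-in for how such a regex-validated string type can behave; this mixin is written from scratch for illustration and is not the dbt.dataclass_schema class.

import re


class ValidatedString(str):
    # subclasses override this with their own pattern
    ValidationRegex = r".*"

    @classmethod
    def validate(cls, value: str) -> "ValidatedString":
        if re.match(cls.ValidationRegex, value) is None:
            raise ValueError(f"{value!r} does not match {cls.ValidationRegex}")
        return cls(value)


class Name(ValidatedString):
    ValidationRegex = r"^[^\d\W]\w*$"


Name.validate("my_project")    # passes
# Name.validate("1_project")   # would raise ValueError: starts with a digit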
@@ -1,12 +1,12 @@
 from typing import (
-    Optional, TypeVar, Generic, Dict,
+    Optional,
+    Dict,
 )
 from typing_extensions import Protocol

-from hologram import JsonSchemaMixin
-from hologram.helpers import StrEnum
+from dbt.dataclass_schema import dbtClassMixin, StrEnum

@@ -15,24 +15,24 @@ from dbt.utils import deep_merge
 class RelationType(StrEnum):
-    Table = 'table'
-    View = 'view'
-    CTE = 'cte'
-    MaterializedView = 'materializedview'
-    External = 'external'
+    Table = "table"
+    View = "view"
+    CTE = "cte"
+    MaterializedView = "materializedview"
+    External = "external"

 class ComponentName(StrEnum):
-    Database = 'database'
-    Schema = 'schema'
-    Identifier = 'identifier'
+    Database = "database"
+    Schema = "schema"
+    Identifier = "identifier"

-class FakeAPIObject(JsonSchemaMixin, Replaceable, Mapping):
+class FakeAPIObject(dbtClassMixin, Replaceable, Mapping):

@@ -44,30 +44,27 @@ class FakeAPIObject(JsonSchemaMixin, Replaceable, Mapping):
     def __iter__(self):
-        deprecations.warn('not-a-dictionary', obj=self)
+        deprecations.warn("not-a-dictionary", obj=self)

     def __len__(self):
-        deprecations.warn('not-a-dictionary', obj=self)
+        deprecations.warn("not-a-dictionary", obj=self)

     def incorporate(self, **kwargs):
-        value = self.to_dict()
+        value = self.to_dict(omit_none=True)
         value = deep_merge(value, kwargs)
         return self.from_dict(value)

-T = TypeVar('T')
-
-
 @dataclass
-class _ComponentObject(FakeAPIObject, Generic[T]):
-    database: T
-    schema: T
-    identifier: T
+class Policy(FakeAPIObject):
+    database: bool = True
+    schema: bool = True
+    identifier: bool = True

-    def get_part(self, key: ComponentName) -> T:
+    def get_part(self, key: ComponentName) -> bool:
         if key == ComponentName.Database:
             return self.database

@@ -76,43 +73,35 @@ class _ComponentObject(FakeAPIObject, Generic[T]):
             raise ValueError(
-                'Got a key of {}, expected one of {}'
-                .format(key, list(ComponentName))
+                "Got a key of {}, expected one of {}".format(key, list(ComponentName))
             )

-    def replace_dict(self, dct: Dict[ComponentName, T]):
-        kwargs: Dict[str, T] = {}
+    def replace_dict(self, dct: Dict[ComponentName, bool]):
+        kwargs: Dict[str, bool] = {}
         for k, v in dct.items():
             kwargs[str(k)] = v
         return self.replace(**kwargs)

 @dataclass
-class Policy(_ComponentObject[bool]):
-    database: bool = True
-    schema: bool = True
-    identifier: bool = True
-
-
-@dataclass
-class Path(_ComponentObject[Optional[str]]):
-    database: Optional[str]
-    schema: Optional[str]
-    identifier: Optional[str]
+class Path(FakeAPIObject):
+    database: Optional[str] = None
+    schema: Optional[str] = None
+    identifier: Optional[str] = None

     def __post_init__(self):
         # handle pesky jinja2.Undefined sneaking in here and messing up rende
         if not isinstance(self.database, (type(None), str)):
             raise CompilationException(
-                'Got an invalid path database: {}'.format(self.database)
+                "Got an invalid path database: {}".format(self.database)
             )
         if not isinstance(self.schema, (type(None), str)):
             raise CompilationException(
-                'Got an invalid path schema: {}'.format(self.schema)
+                "Got an invalid path schema: {}".format(self.schema)
             )
         if not isinstance(self.identifier, (type(None), str)):
             raise CompilationException(
-                'Got an invalid path identifier: {}'.format(self.identifier)
+                "Got an invalid path identifier: {}".format(self.identifier)
             )

@@ -120,3 +109,21 @@ class Path(_ComponentObject[Optional[str]]):
         if part is not None:
             part = part.lower()
         return part
+
+    def get_part(self, key: ComponentName) -> Optional[str]:
+        if key == ComponentName.Database:
+            return self.database
+        elif key == ComponentName.Schema:
+            return self.schema
+        elif key == ComponentName.Identifier:
+            return self.identifier
+        else:
+            raise ValueError(
+                "Got a key of {}, expected one of {}".format(key, list(ComponentName))
+            )
+
+    def replace_dict(self, dct: Dict[ComponentName, str]):
+        kwargs: Dict[str, str] = {}
+        for k, v in dct.items():
+            kwargs[str(k)] = v
+        return self.replace(**kwargs)
@@ -1,7 +1,5 @@
 from dbt.contracts.graph.manifest import CompileResultNode
-from dbt.contracts.graph.unparsed import (
-    FreshnessThreshold
-)
+from dbt.contracts.graph.unparsed import FreshnessThreshold
 from dbt.contracts.graph.parsed import ParsedSourceDefinition

@@ -17,20 +15,27 @@ from dbt.logger import (
 from dbt.utils import lowercase
-from hologram.helpers import StrEnum
-from hologram import JsonSchemaMixin
+from dbt.dataclass_schema import dbtClassMixin, StrEnum

 from datetime import datetime
-from typing import Union, Dict, List, Optional, Any, NamedTuple, Sequence
+from typing import (
+    Union,
+    Dict,
+    List,
+    Optional,
+    Any,
+    NamedTuple,
+    Sequence,
+)

-class TimingInfo(JsonSchemaMixin):
+class TimingInfo(dbtClassMixin):

@@ -53,7 +58,7 @@ class collect_timing_info:
         with JsonOnly(), TimingProcessor(self.timing_info):
-            logger.debug('finished collecting timing info')
+            logger.debug("finished collecting timing info")

@@ -87,13 +92,20 @@ class FreshnessStatus(StrEnum):
-class BaseResult(JsonSchemaMixin):
+class BaseResult(dbtClassMixin):
     status: Union[RunStatus, TestStatus, FreshnessStatus]
     timing: List[TimingInfo]
     thread_id: str
     execution_time: float
-    message: Optional[Union[str, int]]
     adapter_response: Dict[str, Any]
+    message: Optional[Union[str, int]]
+
+    @classmethod
+    def __pre_deserialize__(cls, data):
+        data = super().__pre_deserialize__(data)
+        if "message" not in data:
+            data["message"] = None
+        return data

@@ -103,7 +115,10 @@ class NodeResult(BaseResult):
 class RunResult(NodeResult):
-    agate_table: Optional[agate.Table] = None
+    agate_table: Optional[agate.Table] = field(
+        default=None,
+        metadata={"serialize": lambda x: None, "deserialize": lambda x: None},
+    )

@@ -111,7 +126,7 @@ class RunResult(NodeResult):
-class ExecutionResult(JsonSchemaMixin):
+class ExecutionResult(dbtClassMixin):

@@ -145,7 +160,7 @@ def process_run_result(result: RunResult) -> RunResultOutput:
         message=result.message,
-        adapter_response=result.adapter_response
+        adapter_response=result.adapter_response,
     )

@@ -168,7 +183,7 @@ class RunExecutionResult(
-@schema_version('run-results', 1)
+@schema_version("run-results", 1)
 class RunResultsArtifact(ExecutionResult, ArtifactMixin):

@@ -190,11 +205,11 @@ class RunResultsArtifact(ExecutionResult, ArtifactMixin):
         elapsed_time=elapsed_time,
-        args=args
+        args=args,
     )

-    def write(self, path: str, omit_none=False):
-        write_json(path, self.to_dict(omit_none=omit_none))
+    def write(self, path: str):
+        write_json(path, self.to_dict(omit_none=False))

@@ -204,15 +219,14 @@ class RunOperationResult(ExecutionResult):
 class RunOperationResultMetadata(BaseArtifactMetadata):
-    dbt_schema_version: str = field(default_factory=lambda: str(
-        RunOperationResultsArtifact.dbt_schema_version
-    ))
+    dbt_schema_version: str = field(
+        default_factory=lambda: str(RunOperationResultsArtifact.dbt_schema_version)
+    )

-@schema_version('run-operation-result', 1)
+@schema_version("run-operation-result", 1)
 class RunOperationResultsArtifact(RunOperationResult, ArtifactMixin):

@@ -231,6 +245,7 @@ class RunOperationResultsArtifact(RunOperationResult, ArtifactMixin):
 # due to issues with typing.Union collapsing subclasses, this can't subclass
 # PartialResult

@@ -249,18 +264,18 @@ class SourceFreshnessResult(NodeResult):
 class FreshnessErrorEnum(StrEnum):
-    runtime_error = 'runtime error'
+    runtime_error = "runtime error"

-class SourceFreshnessRuntimeError(JsonSchemaMixin):
+class SourceFreshnessRuntimeError(dbtClassMixin):

-class SourceFreshnessOutput(JsonSchemaMixin):
+class SourceFreshnessOutput(dbtClassMixin):

@@ -279,14 +294,11 @@ class PartialSourceFreshnessResult(NodeResult):
-FreshnessNodeResult = Union[PartialSourceFreshnessResult,
-                            SourceFreshnessResult]
+FreshnessNodeResult = Union[PartialSourceFreshnessResult, SourceFreshnessResult]
 FreshnessNodeOutput = Union[SourceFreshnessRuntimeError, SourceFreshnessOutput]

-def process_freshness_result(
-    result: FreshnessNodeResult
-) -> FreshnessNodeOutput:
+def process_freshness_result(result: FreshnessNodeResult) -> FreshnessNodeOutput:
     unique_id = result.node.unique_id

@@ -298,16 +310,15 @@ def process_freshness_result(
         raise InternalException(
-            'Got {} instead of a SourceFreshnessResult for a '
-            'non-error result in freshness execution!'
-            .format(type(result))
+            "Got {} instead of a SourceFreshnessResult for a "
+            "non-error result in freshness execution!".format(type(result))
         )
     if criteria is None:
         raise InternalException(
-            'Somehow evaluated a freshness result for a source '
-            'that has no freshness criteria!'
+            "Somehow evaluated a freshness result for a source "
+            "that has no freshness criteria!"
         )

@@ -316,16 +327,14 @@ def process_freshness_result(
         criteria=criteria,
-        adapter_response=result.adapter_response
+        adapter_response=result.adapter_response,
     )

 class FreshnessMetadata(BaseArtifactMetadata):
     dbt_schema_version: str = field(
-        default_factory=lambda: str(
-            FreshnessExecutionResultArtifact.dbt_schema_version
-        )
+        default_factory=lambda: str(FreshnessExecutionResultArtifact.dbt_schema_version)
     )

@@ -346,7 +355,7 @@ class FreshnessResult(ExecutionResult):
-@schema_version('sources', 1)
+@schema_version("sources", 1)
 class FreshnessExecutionResultArtifact(

@@ -368,46 +377,45 @@ class FreshnessExecutionResultArtifact(
 CatalogKey = NamedTuple(
-    'CatalogKey',
-    [('database', Optional[str]), ('schema', str), ('name', str)]
+    "CatalogKey", [("database", Optional[str]), ("schema", str), ("name", str)]
 )

-class StatsItem(JsonSchemaMixin):
+class StatsItem(dbtClassMixin):
     id: str
     label: str
     value: Primitive
-    description: Optional[str]
     include: bool
+    description: Optional[str] = None

-class ColumnMetadata(JsonSchemaMixin):
+class ColumnMetadata(dbtClassMixin):
     type: str
-    comment: Optional[str]
     index: int
     name: str
+    comment: Optional[str] = None

-class TableMetadata(JsonSchemaMixin):
+class TableMetadata(dbtClassMixin):
     type: str
-    database: Optional[str]
     schema: str
     name: str
-    comment: Optional[str]
-    owner: Optional[str]
+    database: Optional[str] = None
+    comment: Optional[str] = None
+    owner: Optional[str] = None

-class CatalogTable(JsonSchemaMixin, Replaceable):
+class CatalogTable(dbtClassMixin, Replaceable):
     metadata: TableMetadata
     columns: ColumnMap
     stats: StatsDict

@@ -430,15 +438,21 @@ class CatalogMetadata(BaseArtifactMetadata):
-class CatalogResults(JsonSchemaMixin):
+class CatalogResults(dbtClassMixin):
     nodes: Dict[str, CatalogTable]
     sources: Dict[str, CatalogTable]
-    errors: Optional[List[str]]
+    errors: Optional[List[str]] = None
     _compile_results: Optional[Any] = None
+
+    def __post_serialize__(self, dct):
+        dct = super().__post_serialize__(dct)
+        if "_compile_results" in dct:
+            del dct["_compile_results"]
+        return dct

-@schema_version('catalog', 1)
+@schema_version("catalog", 1)
 class CatalogArtifact(CatalogResults, ArtifactMixin):
     metadata: CatalogMetadata

@@ -449,8 +463,8 @@ class CatalogArtifact(CatalogResults, ArtifactMixin):
         compile_results: Optional[Any],
-        errors: Optional[List[str]]
-    ) -> 'CatalogArtifact':
+        errors: Optional[List[str]],
+    ) -> "CatalogArtifact":
         meta = CatalogMetadata(generated_at=generated_at)
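BaseResult above gains a __pre_deserialize__ classmethod that backfills the now-optional message field before the dataclass is built. A small sketch of that defaulting pattern follows; the hand-written from_dict here stands in for the mixin's real deserializer and is only illustrative.

from dataclasses import dataclass
from typing import Optional, Union


@dataclass
class ExampleResult:
    thread_id: str
    execution_time: float
    message: Optional[Union[str, int]] = None

    @classmethod
    def __pre_deserialize__(cls, data: dict) -> dict:
        # older result payloads may omit "message"; default it to None
        if "message" not in data:
            data["message"] = None
        return data

    @classmethod
    def from_dict(cls, data: dict) -> "ExampleResult":
        data = cls.__pre_deserialize__(dict(data))
        return cls(**data)


print(ExampleResult.from_dict({"thread_id": "main", "execution_time": 0.42}))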
@@ -5,13 +5,14 @@ from dataclasses import dataclass, field
-from hologram import JsonSchemaMixin
-from hologram.helpers import StrEnum
+from dbt.dataclass_schema import dbtClassMixin, StrEnum

 from dbt.contracts.results import (
-    RunResult, RunResultsArtifact, TimingInfo,
+    RunResult,
+    RunResultsArtifact,
+    TimingInfo,
     CatalogArtifact,
     CatalogResults,
     ExecutionResult,

@@ -34,16 +35,25 @@ TaskID = uuid.UUID
-class RPCParameters(JsonSchemaMixin):
-    timeout: Optional[float]
+class RPCParameters(dbtClassMixin):
     task_tags: TaskTags
+    timeout: Optional[float]
+
+    @classmethod
+    def __pre_deserialize__(cls, data, omit_none=True):
+        data = super().__pre_deserialize__(data)
+        if "timeout" not in data:
+            data["timeout"] = None
+        if "task_tags" not in data:
+            data["task_tags"] = None
+        return data

 class RPCExecParameters(RPCParameters):
     name: str
     sql: str
-    macros: Optional[str]
+    macros: Optional[str] = None

@@ -132,7 +142,7 @@ class StatusParameters(RPCParameters):
-class GCSettings(JsonSchemaMixin):
+class GCSettings(dbtClassMixin):
     # start evicting the longest-ago-ended tasks here
     maxsize: int

@@ -153,6 +163,7 @@ class GCParameters(RPCParameters):
     """
     task_ids: Optional[List[TaskID]] = None
     before: Optional[datetime] = None
     settings: Optional[GCSettings] = None

@@ -174,6 +185,7 @@ class RPCSourceFreshnessParameters(RPCParameters):
 class GetManifestParameters(RPCParameters):
     pass

 # Outputs

@@ -183,13 +195,13 @@ class RemoteResult(VersionedSchema):
-@schema_version('remote-deps-result', 1)
+@schema_version("remote-deps-result", 1)
 class RemoteDepsResult(RemoteResult):

-@schema_version('remote-catalog-result', 1)
+@schema_version("remote-catalog-result", 1)
 class RemoteCatalogResults(CatalogResults, RemoteResult):

@@ -213,7 +225,7 @@ class RemoteCompileResultMixin(RemoteResult):
-@schema_version('remote-compile-result', 1)
+@schema_version("remote-compile-result", 1)
 class RemoteCompileResult(RemoteCompileResultMixin):

@@ -223,7 +235,7 @@ class RemoteCompileResult(RemoteCompileResultMixin):
-@schema_version('remote-execution-result', 1)
+@schema_version("remote-execution-result", 1)
 class RemoteExecutionResult(ExecutionResult, RemoteResult):

@@ -243,7 +255,7 @@ class RemoteExecutionResult(ExecutionResult, RemoteResult):
         logs: List[LogMessage],
-    ) -> 'RemoteExecutionResult':
+    ) -> "RemoteExecutionResult":

@@ -254,13 +266,13 @@ class RemoteExecutionResult(ExecutionResult, RemoteResult):
-class ResultTable(JsonSchemaMixin):
+class ResultTable(dbtClassMixin):
     column_names: List[str]
     rows: List[Any]

-@schema_version('remote-run-operation-result', 1)
+@schema_version("remote-run-operation-result", 1)
 class RemoteRunOperationResult(RunOperationResult, RemoteResult):

@@ -269,7 +281,7 @@ class RemoteRunOperationResult(RunOperationResult, RemoteResult):
         logs: List[LogMessage],
-    ) -> 'RemoteRunOperationResult':
+    ) -> "RemoteRunOperationResult":

@@ -288,15 +300,14 @@ class RemoteRunOperationResult(RunOperationResult, RemoteResult):
-@schema_version('remote-freshness-result', 1)
+@schema_version("remote-freshness-result", 1)
 class RemoteFreshnessResult(FreshnessResult, RemoteResult):
         logs: List[LogMessage],
-    ) -> 'RemoteFreshnessResult':
+    ) -> "RemoteFreshnessResult":

@@ -310,7 +321,7 @@ class RemoteFreshnessResult(FreshnessResult, RemoteResult):
-@schema_version('remote-run-result', 1)
+@schema_version("remote-run-result", 1)
 class RemoteRunResult(RemoteCompileResultMixin):
     table: ResultTable

@@ -328,14 +339,15 @@ RPCResult = Union[
 class GCResultState(StrEnum):
-    Deleted = 'deleted'  # successful GC
-    Missing = 'missing'  # nothing to GC
-    Running = 'running'  # can't GC
+    Deleted = "deleted"  # successful GC
+    Missing = "missing"  # nothing to GC
+    Running = "running"  # can't GC

-@schema_version('remote-gc-result', 1)
+@schema_version("remote-gc-result", 1)
 class GCResult(RemoteResult):
     logs: List[LogMessage] = field(default_factory=list)
     deleted: List[TaskID] = field(default_factory=list)

@@ -350,21 +362,20 @@ class GCResult(RemoteResult):
         else:
-            raise InternalException(
-                f'Got invalid state in add_result: {state}'
-            )
+            raise InternalException(f"Got invalid state in add_result: {state}")

 # Task management types

 class TaskHandlerState(StrEnum):
-    NotStarted = 'not started'
-    Initializing = 'initializing'
-    Running = 'running'
-    Success = 'success'
-    Error = 'error'
-    Killed = 'killed'
-    Failed = 'failed'
+    NotStarted = "not started"
+    Initializing = "initializing"
+    Running = "running"
+    Success = "success"
+    Error = "error"
+    Killed = "killed"
+    Failed = "failed"

@@ -372,7 +383,7 @@ class TaskHandlerState(StrEnum):
         if not isinstance(other, TaskHandlerState):
-            raise TypeError('cannot compare to non-TaskHandlerState')
+            raise TypeError("cannot compare to non-TaskHandlerState")
         order = (self.NotStarted, self.Initializing, self.Running)
         smaller = set()
|
for value in order:
|
||||||
@@ -384,13 +395,11 @@ class TaskHandlerState(StrEnum):
|
|||||||
|
|
||||||
def __le__(self, other) -> bool:
|
def __le__(self, other) -> bool:
|
||||||
# so that ((Success <= Error) is True)
|
# so that ((Success <= Error) is True)
|
||||||
return ((self < other) or
|
return (self < other) or (self == other) or (self.finished and other.finished)
|
||||||
(self == other) or
|
|
||||||
(self.finished and other.finished))
|
|
||||||
|
|
||||||
def __gt__(self, other) -> bool:
|
def __gt__(self, other) -> bool:
|
||||||
if not isinstance(other, TaskHandlerState):
|
if not isinstance(other, TaskHandlerState):
|
||||||
raise TypeError('cannot compare to non-TaskHandlerState')
|
raise TypeError("cannot compare to non-TaskHandlerState")
|
||||||
order = (self.NotStarted, self.Initializing, self.Running)
|
order = (self.NotStarted, self.Initializing, self.Running)
|
||||||
smaller = set()
|
smaller = set()
|
||||||
for value in order:
|
for value in order:
|
||||||
@@ -401,9 +410,7 @@ class TaskHandlerState(StrEnum):
|
|||||||
|
|
||||||
def __ge__(self, other) -> bool:
|
def __ge__(self, other) -> bool:
|
||||||
# so that ((Success <= Error) is True)
|
# so that ((Success <= Error) is True)
|
||||||
return ((self > other) or
|
return (self > other) or (self == other) or (self.finished and other.finished)
|
||||||
(self == other) or
|
|
||||||
(self.finished and other.finished))
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def finished(self) -> bool:
|
def finished(self) -> bool:
|
||||||
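Per the docstring above, TaskHandlerState orders as NotStarted < Initializing < Running < (Success, Error, Killed, Failed), and the `<=`/`>=` overrides treat any two finished states as equivalent, so `Success <= Error` holds. A minimal illustrative sketch (assuming these types live in dbt.contracts.rpc, as in dbt-core of this era):

    from dbt.contracts.rpc import TaskHandlerState

    assert TaskHandlerState.NotStarted < TaskHandlerState.Running
    assert TaskHandlerState.Success <= TaskHandlerState.Error    # both finished
    assert TaskHandlerState.Killed >= TaskHandlerState.Failed    # both finished
    assert not (TaskHandlerState.Running >= TaskHandlerState.Success)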
@@ -411,47 +418,57 @@ class TaskHandlerState(StrEnum):


 @dataclass
-class TaskTiming(JsonSchemaMixin):
+class TaskTiming(dbtClassMixin):
     state: TaskHandlerState
     start: Optional[datetime]
     end: Optional[datetime]
     elapsed: Optional[float]

+    # These ought to be defaults but superclass order doesn't
+    # allow that to work
+    @classmethod
+    def __pre_deserialize__(cls, data):
+        data = super().__pre_deserialize__(data)
+        for field_name in ("start", "end", "elapsed"):
+            if field_name not in data:
+                data[field_name] = None
+        return data


 @dataclass
 class TaskRow(TaskTiming):
     task_id: TaskID
-    request_id: Union[str, int]
     request_source: str
     method: str
-    timeout: Optional[float]
-    tags: TaskTags
+    request_id: Union[str, int]
+    tags: TaskTags = None
+    timeout: Optional[float] = None


 @dataclass
-@schema_version('remote-ps-result', 1)
+@schema_version("remote-ps-result", 1)
 class PSResult(RemoteResult):
     rows: List[TaskRow]


 class KillResultStatus(StrEnum):
-    Missing = 'missing'
-    NotStarted = 'not_started'
-    Killed = 'killed'
-    Finished = 'finished'
+    Missing = "missing"
+    NotStarted = "not_started"
+    Killed = "killed"
+    Finished = "finished"


 @dataclass
-@schema_version('remote-kill-result', 1)
+@schema_version("remote-kill-result", 1)
 class KillResult(RemoteResult):
     state: KillResultStatus = KillResultStatus.Missing
     logs: List[LogMessage] = field(default_factory=list)


 @dataclass
-@schema_version('remote-manifest-result', 1)
+@schema_version("remote-manifest-result", 1)
 class GetManifestResult(RemoteResult):
-    manifest: Optional[WritableManifest]
+    manifest: Optional[WritableManifest] = None


 # this is kind of carefuly structured: BlocksManifestTasks is implied by
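The new `__pre_deserialize__` hook exists because, as its comment says, `start`, `end`, and `elapsed` cannot be given dataclass defaults here (superclass field order forbids it), so missing keys are backfilled with None before mashumaro builds the instance. A rough sketch of the effect (payload values invented for illustration):

    payload = {"state": "running"}          # no start/end/elapsed keys present
    timing = TaskTiming.from_dict(payload)  # __pre_deserialize__ inserts None for them
    assert timing.start is None and timing.elapsed is None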
@@ -475,24 +492,33 @@ class PollResult(RemoteResult, TaskTiming):
     end: Optional[datetime]
     elapsed: Optional[float]

+    # These ought to be defaults but superclass order doesn't
+    # allow that to work
+    @classmethod
+    def __pre_deserialize__(cls, data):
+        data = super().__pre_deserialize__(data)
+        for field_name in ("start", "end", "elapsed"):
+            if field_name not in data:
+                data[field_name] = None
+        return data


 @dataclass
-@schema_version('poll-remote-deps-result', 1)
+@schema_version("poll-remote-deps-result", 1)
 class PollRemoteEmptyCompleteResult(PollResult, RemoteResult):
     state: TaskHandlerState = field(
-        metadata=restrict_to(TaskHandlerState.Success,
-                             TaskHandlerState.Failed),
+        metadata=restrict_to(TaskHandlerState.Success, TaskHandlerState.Failed),
     )
     generated_at: datetime = field(default_factory=datetime.utcnow)

     @classmethod
     def from_result(
-        cls: Type['PollRemoteEmptyCompleteResult'],
+        cls: Type["PollRemoteEmptyCompleteResult"],
         base: RemoteDepsResult,
         tags: TaskTags,
         timing: TaskTiming,
         logs: List[LogMessage],
-    ) -> 'PollRemoteEmptyCompleteResult':
+    ) -> "PollRemoteEmptyCompleteResult":
         return cls(
             logs=logs,
             tags=tags,

@@ -500,12 +526,12 @@ class PollRemoteEmptyCompleteResult(PollResult, RemoteResult):
             start=timing.start,
             end=timing.end,
             elapsed=timing.elapsed,
-            generated_at=base.generated_at
+            generated_at=base.generated_at,
         )


 @dataclass
-@schema_version('poll-remote-killed-result', 1)
+@schema_version("poll-remote-killed-result", 1)
 class PollKilledResult(PollResult):
     state: TaskHandlerState = field(
         metadata=restrict_to(TaskHandlerState.Killed),

@@ -513,24 +539,23 @@ class PollKilledResult(PollResult):


 @dataclass
-@schema_version('poll-remote-execution-result', 1)
+@schema_version("poll-remote-execution-result", 1)
 class PollExecuteCompleteResult(
     RemoteExecutionResult,
     PollResult,
 ):
     state: TaskHandlerState = field(
-        metadata=restrict_to(TaskHandlerState.Success,
-                             TaskHandlerState.Failed),
+        metadata=restrict_to(TaskHandlerState.Success, TaskHandlerState.Failed),
     )

     @classmethod
     def from_result(
-        cls: Type['PollExecuteCompleteResult'],
+        cls: Type["PollExecuteCompleteResult"],
         base: RemoteExecutionResult,
         tags: TaskTags,
         timing: TaskTiming,
         logs: List[LogMessage],
-    ) -> 'PollExecuteCompleteResult':
+    ) -> "PollExecuteCompleteResult":
         return cls(
             results=base.results,
             elapsed_time=base.elapsed_time,

@@ -545,24 +570,23 @@ class PollExecuteCompleteResult(


 @dataclass
-@schema_version('poll-remote-compile-result', 1)
+@schema_version("poll-remote-compile-result", 1)
 class PollCompileCompleteResult(
     RemoteCompileResult,
     PollResult,
 ):
     state: TaskHandlerState = field(
-        metadata=restrict_to(TaskHandlerState.Success,
-                             TaskHandlerState.Failed),
+        metadata=restrict_to(TaskHandlerState.Success, TaskHandlerState.Failed),
     )

     @classmethod
     def from_result(
-        cls: Type['PollCompileCompleteResult'],
+        cls: Type["PollCompileCompleteResult"],
         base: RemoteCompileResult,
         tags: TaskTags,
         timing: TaskTiming,
         logs: List[LogMessage],
-    ) -> 'PollCompileCompleteResult':
+    ) -> "PollCompileCompleteResult":
         return cls(
             raw_sql=base.raw_sql,
             compiled_sql=base.compiled_sql,

@@ -574,29 +598,28 @@ class PollCompileCompleteResult(
             start=timing.start,
             end=timing.end,
             elapsed=timing.elapsed,
-            generated_at=base.generated_at
+            generated_at=base.generated_at,
         )


 @dataclass
-@schema_version('poll-remote-run-result', 1)
+@schema_version("poll-remote-run-result", 1)
 class PollRunCompleteResult(
     RemoteRunResult,
     PollResult,
 ):
     state: TaskHandlerState = field(
-        metadata=restrict_to(TaskHandlerState.Success,
-                             TaskHandlerState.Failed),
+        metadata=restrict_to(TaskHandlerState.Success, TaskHandlerState.Failed),
     )

     @classmethod
     def from_result(
-        cls: Type['PollRunCompleteResult'],
+        cls: Type["PollRunCompleteResult"],
         base: RemoteRunResult,
         tags: TaskTags,
         timing: TaskTiming,
         logs: List[LogMessage],
-    ) -> 'PollRunCompleteResult':
+    ) -> "PollRunCompleteResult":
         return cls(
             raw_sql=base.raw_sql,
             compiled_sql=base.compiled_sql,

@@ -609,29 +632,28 @@ class PollRunCompleteResult(
             start=timing.start,
             end=timing.end,
             elapsed=timing.elapsed,
-            generated_at=base.generated_at
+            generated_at=base.generated_at,
         )


 @dataclass
-@schema_version('poll-remote-run-operation-result', 1)
+@schema_version("poll-remote-run-operation-result", 1)
 class PollRunOperationCompleteResult(
     RemoteRunOperationResult,
     PollResult,
 ):
     state: TaskHandlerState = field(
-        metadata=restrict_to(TaskHandlerState.Success,
-                             TaskHandlerState.Failed),
+        metadata=restrict_to(TaskHandlerState.Success, TaskHandlerState.Failed),
     )

     @classmethod
     def from_result(
-        cls: Type['PollRunOperationCompleteResult'],
+        cls: Type["PollRunOperationCompleteResult"],
         base: RemoteRunOperationResult,
         tags: TaskTags,
         timing: TaskTiming,
         logs: List[LogMessage],
-    ) -> 'PollRunOperationCompleteResult':
+    ) -> "PollRunOperationCompleteResult":
         return cls(
             success=base.success,
             results=base.results,

@@ -647,21 +669,20 @@ class PollRunOperationCompleteResult(


 @dataclass
-@schema_version('poll-remote-catalog-result', 1)
+@schema_version("poll-remote-catalog-result", 1)
 class PollCatalogCompleteResult(RemoteCatalogResults, PollResult):
     state: TaskHandlerState = field(
-        metadata=restrict_to(TaskHandlerState.Success,
-                             TaskHandlerState.Failed),
+        metadata=restrict_to(TaskHandlerState.Success, TaskHandlerState.Failed),
     )

     @classmethod
     def from_result(
-        cls: Type['PollCatalogCompleteResult'],
+        cls: Type["PollCatalogCompleteResult"],
         base: RemoteCatalogResults,
         tags: TaskTags,
         timing: TaskTiming,
         logs: List[LogMessage],
-    ) -> 'PollCatalogCompleteResult':
+    ) -> "PollCatalogCompleteResult":
         return cls(
             nodes=base.nodes,
             sources=base.sources,

@@ -678,27 +699,26 @@ class PollCatalogCompleteResult(RemoteCatalogResults, PollResult):


 @dataclass
-@schema_version('poll-remote-in-progress-result', 1)
+@schema_version("poll-remote-in-progress-result", 1)
 class PollInProgressResult(PollResult):
     pass


 @dataclass
-@schema_version('poll-remote-get-manifest-result', 1)
+@schema_version("poll-remote-get-manifest-result", 1)
 class PollGetManifestResult(GetManifestResult, PollResult):
     state: TaskHandlerState = field(
-        metadata=restrict_to(TaskHandlerState.Success,
-                             TaskHandlerState.Failed),
+        metadata=restrict_to(TaskHandlerState.Success, TaskHandlerState.Failed),
     )

     @classmethod
     def from_result(
-        cls: Type['PollGetManifestResult'],
+        cls: Type["PollGetManifestResult"],
         base: GetManifestResult,
         tags: TaskTags,
         timing: TaskTiming,
         logs: List[LogMessage],
-    ) -> 'PollGetManifestResult':
+    ) -> "PollGetManifestResult":
         return cls(
             manifest=base.manifest,
             logs=logs,

@@ -711,21 +731,20 @@ class PollGetManifestResult(GetManifestResult, PollResult):


 @dataclass
-@schema_version('poll-remote-freshness-result', 1)
+@schema_version("poll-remote-freshness-result", 1)
 class PollFreshnessResult(RemoteFreshnessResult, PollResult):
     state: TaskHandlerState = field(
-        metadata=restrict_to(TaskHandlerState.Success,
-                             TaskHandlerState.Failed),
+        metadata=restrict_to(TaskHandlerState.Success, TaskHandlerState.Failed),
     )

     @classmethod
     def from_result(
-        cls: Type['PollFreshnessResult'],
+        cls: Type["PollFreshnessResult"],
         base: RemoteFreshnessResult,
         tags: TaskTags,
         timing: TaskTiming,
         logs: List[LogMessage],
-    ) -> 'PollFreshnessResult':
+    ) -> "PollFreshnessResult":
         return cls(
             logs=logs,
             tags=tags,

@@ -738,18 +757,19 @@ class PollFreshnessResult(RemoteFreshnessResult, PollResult):
             elapsed_time=base.elapsed_time,
         )


 # Manifest parsing types


 class ManifestStatus(StrEnum):
-    Init = 'init'
-    Compiling = 'compiling'
-    Ready = 'ready'
-    Error = 'error'
+    Init = "init"
+    Compiling = "compiling"
+    Ready = "ready"
+    Error = "error"


 @dataclass
-@schema_version('remote-status-result', 1)
+@schema_version("remote-status-result", 1)
 class LastParse(RemoteResult):
     state: ManifestStatus = ManifestStatus.Init
     logs: List[LogMessage] = field(default_factory=list)

@@ -1,18 +1,18 @@
 from dataclasses import dataclass
-from hologram import JsonSchemaMixin
+from dbt.dataclass_schema import dbtClassMixin

 from typing import List, Dict, Any, Union


 @dataclass
-class SelectorDefinition(JsonSchemaMixin):
+class SelectorDefinition(dbtClassMixin):
     name: str
     definition: Union[str, Dict[str, Any]]
-    description: str = ''
+    description: str = ""


 @dataclass
-class SelectorFile(JsonSchemaMixin):
+class SelectorFile(dbtClassMixin):
     selectors: List[SelectorDefinition]
     version: int = 2

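With the selector contracts now based on dbtClassMixin, a selectors dict round-trips through the mashumaro-generated from_dict/to_dict. A small illustrative sketch (selector name and definition are made-up values):

    raw = {
        "version": 2,
        "selectors": [
            {"name": "nightly", "definition": "tag:nightly", "description": ""},
        ],
    }
    selector_file = SelectorFile.from_dict(raw)
    assert selector_file.selectors[0].definition == "tag:nightly"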
@@ -9,7 +9,7 @@ class PreviousState:
         self.path: Path = path
         self.manifest: Optional[WritableManifest] = None

-        manifest_path = self.path / 'manifest.json'
+        manifest_path = self.path / "manifest.json"
         if manifest_path.exists() and manifest_path.is_file():
             try:
                 self.manifest = WritableManifest.read(str(manifest_path))

@@ -1,19 +1,16 @@
 import dataclasses
 import os
 from datetime import datetime
-from typing import (
-    List, Tuple, ClassVar, Type, TypeVar, Dict, Any, Optional
-)
+from typing import List, Tuple, ClassVar, Type, TypeVar, Dict, Any, Optional

 from dbt.clients.system import write_json, read_json
 from dbt.exceptions import (
-    IncompatibleSchemaException,
     InternalException,
     RuntimeException,
 )
 from dbt.version import __version__
 from dbt.tracking import get_invocation_id
-from hologram import JsonSchemaMixin
+from dbt.dataclass_schema import dbtClassMixin

 MacroKey = Tuple[str, str]
 SourceKey = Tuple[str, str]

@@ -57,8 +54,8 @@ class Mergeable(Replaceable):


 class Writable:
-    def write(self, path: str, omit_none: bool = False):
-        write_json(path, self.to_dict(omit_none=omit_none))  # type: ignore
+    def write(self, path: str):
+        write_json(path, self.to_dict(omit_none=False))  # type: ignore


 class AdditionalPropertiesMixin:

@@ -67,24 +64,44 @@ class AdditionalPropertiesMixin:
     The underlying class definition must include a type definition for a field
     named '_extra' that is of type `Dict[str, Any]`.
     """

     ADDITIONAL_PROPERTIES = True

+    # This takes attributes in the dictionary that are
+    # not in the class definitions and puts them in an
+    # _extra dict in the class
     @classmethod
-    def from_dict(cls, data, validate=True):
-        self = super().from_dict(data=data, validate=validate)
-        keys = self.to_dict(validate=False, omit_none=False)
+    def __pre_deserialize__(cls, data):
+        # dir() did not work because fields with
+        # metadata settings are not found
+        # The original version of this would create the
+        # object first and then update extra with the
+        # extra keys, but that won't work here, so
+        # we're copying the dict so we don't insert the
+        # _extra in the original data. This also requires
+        # that Mashumaro actually build the '_extra' field
+        cls_keys = cls._get_field_names()
+        new_dict = {}
         for key, value in data.items():
-            if key not in keys:
-                self.extra[key] = value
-        return self
+            if key not in cls_keys and key != "_extra":
+                if "_extra" not in new_dict:
+                    new_dict["_extra"] = {}
+                new_dict["_extra"][key] = value
+            else:
+                new_dict[key] = value
+        data = new_dict
+        data = super().__pre_deserialize__(data)
+        return data

-    def to_dict(self, omit_none=True, validate=False):
-        data = super().to_dict(omit_none=omit_none, validate=validate)
+    def __post_serialize__(self, dct):
+        data = super().__post_serialize__(dct)
         data.update(self.extra)
+        if "_extra" in data:
+            del data["_extra"]
         return data

     def replace(self, **kwargs):
-        dct = self.to_dict(omit_none=False, validate=False)
+        dct = self.to_dict(omit_none=False)
         dct.update(kwargs)
         return self.from_dict(dct)

@@ -106,7 +123,8 @@ class Readable:
         return cls.from_dict(data)  # type: ignore


-BASE_SCHEMAS_URL = 'https://schemas.getdbt.com/dbt/{name}/v{version}.json'
+BASE_SCHEMAS_URL = "https://schemas.getdbt.com/"
+SCHEMA_PATH = "dbt/{name}/v{version}.json"


 @dataclasses.dataclass
@@ -114,36 +132,34 @@ class SchemaVersion:
     name: str
     version: int

+    @property
+    def path(self) -> str:
+        return SCHEMA_PATH.format(name=self.name, version=self.version)
+
     def __str__(self) -> str:
-        return BASE_SCHEMAS_URL.format(
-            name=self.name,
-            version=self.version,
-        )
+        return BASE_SCHEMAS_URL + self.path


-SCHEMA_VERSION_KEY = 'dbt_schema_version'
+SCHEMA_VERSION_KEY = "dbt_schema_version"


-METADATA_ENV_PREFIX = 'DBT_ENV_CUSTOM_ENV_'
+METADATA_ENV_PREFIX = "DBT_ENV_CUSTOM_ENV_"


 def get_metadata_env() -> Dict[str, str]:
     return {
-        k[len(METADATA_ENV_PREFIX):]: v for k, v in os.environ.items()
+        k[len(METADATA_ENV_PREFIX) :]: v
+        for k, v in os.environ.items()
         if k.startswith(METADATA_ENV_PREFIX)
     }


 @dataclasses.dataclass
-class BaseArtifactMetadata(JsonSchemaMixin):
+class BaseArtifactMetadata(dbtClassMixin):
     dbt_schema_version: str
     dbt_version: str = __version__
-    generated_at: datetime = dataclasses.field(
-        default_factory=datetime.utcnow
-    )
-    invocation_id: Optional[str] = dataclasses.field(
-        default_factory=get_invocation_id
-    )
+    generated_at: datetime = dataclasses.field(default_factory=datetime.utcnow)
+    invocation_id: Optional[str] = dataclasses.field(default_factory=get_invocation_id)
     env: Dict[str, str] = dataclasses.field(default_factory=get_metadata_env)

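The SchemaVersion change above splits the old single format string into BASE_SCHEMAS_URL plus a SCHEMA_PATH property, but the identifier it produces is the same URL as before. Roughly:

    sv = SchemaVersion(name="remote-compile-result", version=1)
    sv.path   # "dbt/remote-compile-result/v1.json"
    str(sv)   # "https://schemas.getdbt.com/dbt/remote-compile-result/v1.json"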
@@ -154,22 +170,23 @@ def schema_version(name: str, version: int):
             version=version,
         )
         return cls

     return inner


 @dataclasses.dataclass
-class VersionedSchema(JsonSchemaMixin):
+class VersionedSchema(dbtClassMixin):
     dbt_schema_version: ClassVar[SchemaVersion]

     @classmethod
     def json_schema(cls, embeddable: bool = False) -> Dict[str, Any]:
         result = super().json_schema(embeddable=embeddable)
         if not embeddable:
-            result['$id'] = str(cls.dbt_schema_version)
+            result["$id"] = str(cls.dbt_schema_version)
         return result


-T = TypeVar('T', bound='ArtifactMixin')
+T = TypeVar("T", bound="ArtifactMixin")


 # metadata should really be a Generic[T_M] where T_M is a TypeVar bound to
@@ -180,18 +197,7 @@ class ArtifactMixin(VersionedSchema, Writable, Readable):
     metadata: BaseArtifactMetadata

     @classmethod
-    def from_dict(
-        cls: Type[T], data: Dict[str, Any], validate: bool = True
-    ) -> T:
+    def validate(cls, data):
+        super().validate(data)
         if cls.dbt_schema_version is None:
-            raise InternalException(
-                'Cannot call from_dict with no schema version!'
-            )
-
-        if validate:
-            expected = str(cls.dbt_schema_version)
-            found = data.get('metadata', {}).get(SCHEMA_VERSION_KEY)
-            if found != expected:
-                raise IncompatibleSchemaException(expected, found)
-
-        return super().from_dict(data=data, validate=validate)
+            raise InternalException("Cannot call from_dict with no schema version!")
core/dbt/dataclass_schema.py (new file, 165 lines)
@@ -0,0 +1,165 @@
from typing import (
    Type,
    ClassVar,
    cast,
)
import re
from dataclasses import fields
from enum import Enum
from datetime import datetime
from dateutil.parser import parse

from hologram import JsonSchemaMixin, FieldEncoder, ValidationError

# type: ignore
from mashumaro import DataClassDictMixin
from mashumaro.config import TO_DICT_ADD_OMIT_NONE_FLAG, BaseConfig as MashBaseConfig
from mashumaro.types import SerializableType, SerializationStrategy


class DateTimeSerialization(SerializationStrategy):
    def serialize(self, value):
        out = value.isoformat()
        # Assume UTC if timezone is missing
        if value.tzinfo is None:
            out = out + "Z"
        return out

    def deserialize(self, value):
        return value if isinstance(value, datetime) else parse(cast(str, value))


# This class pulls in both JsonSchemaMixin from Hologram and
# DataClassDictMixin from our fork of Mashumaro. The 'to_dict'
# and 'from_dict' methods come from Mashumaro. Building
# jsonschemas for every class and the 'validate' method
# come from Hologram.
class dbtClassMixin(DataClassDictMixin, JsonSchemaMixin):
    """Mixin which adds methods to generate a JSON schema and
    convert to and from JSON encodable dicts with validation
    against the schema
    """

    class Config(MashBaseConfig):
        code_generation_options = [
            TO_DICT_ADD_OMIT_NONE_FLAG,
        ]
        serialization_strategy = {
            datetime: DateTimeSerialization(),
        }

    _hyphenated: ClassVar[bool] = False
    ADDITIONAL_PROPERTIES: ClassVar[bool] = False

    # This is called by the mashumaro to_dict in order to handle
    # nested classes.
    # Munges the dict that's returned.
    def __post_serialize__(self, dct):
        if self._hyphenated:
            new_dict = {}
            for key in dct:
                if "_" in key:
                    new_key = key.replace("_", "-")
                    new_dict[new_key] = dct[key]
                else:
                    new_dict[key] = dct[key]
            dct = new_dict

        return dct

    # This is called by the mashumaro _from_dict method, before
    # performing the conversion to a dict
    @classmethod
    def __pre_deserialize__(cls, data):
        if cls._hyphenated:
            new_dict = {}
            for key in data:
                if "-" in key:
                    new_key = key.replace("-", "_")
                    new_dict[new_key] = data[key]
                else:
                    new_dict[key] = data[key]
            data = new_dict
        return data

    # This is used in the hologram._encode_field method, which calls
    # a 'to_dict' method which does not have the same parameters in
    # hologram and in mashumaro.
    def _local_to_dict(self, **kwargs):
        args = {}
        if "omit_none" in kwargs:
            args["omit_none"] = kwargs["omit_none"]
        return self.to_dict(**args)


class ValidatedStringMixin(str, SerializableType):
    ValidationRegex = ""

    @classmethod
    def _deserialize(cls, value: str) -> "ValidatedStringMixin":
        cls.validate(value)
        return ValidatedStringMixin(value)

    def _serialize(self) -> str:
        return str(self)

    @classmethod
    def validate(cls, value):
        res = re.match(cls.ValidationRegex, value)

        if res is None:
            raise ValidationError(f"Invalid value: {value}")  # TODO


# These classes must be in this order or it doesn't work
class StrEnum(str, SerializableType, Enum):
    def __str__(self):
        return self.value

    # https://docs.python.org/3.6/library/enum.html#using-automatic-values
    def _generate_next_value_(name, *_):
        return name

    def _serialize(self) -> str:
        return self.value

    @classmethod
    def _deserialize(cls, value: str):
        return cls(value)


class HyphenatedDbtClassMixin(dbtClassMixin):
    # used by from_dict/to_dict
    _hyphenated: ClassVar[bool] = True

    # used by jsonschema validation, _get_fields
    @classmethod
    def field_mapping(cls):
        result = {}
        for field in fields(cls):
            skip = field.metadata.get("preserve_underscore")
            if skip:
                continue

            if "_" in field.name:
                result[field.name] = field.name.replace("_", "-")
        return result


class ExtensibleDbtClassMixin(dbtClassMixin):
    ADDITIONAL_PROPERTIES = True


# This is used by Hologram in jsonschema validation
def register_pattern(base_type: Type, pattern: str) -> None:
    """base_type should be a typing.NewType that should always have the given
    regex pattern. That means that its underlying type ('__supertype__') had
    better be a str!
    """

    class PatternEncoder(FieldEncoder):
        @property
        def json_schema(self):
            return {"type": "string", "pattern": pattern}

    dbtClassMixin.register_field_encoders({base_type: PatternEncoder()})
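As the comments in the new module note, dbtClassMixin routes to_dict/from_dict through the forked mashumaro and keeps hologram only for JSON-schema generation and validate(). A hedged usage sketch (the Ping class and its values are invented for illustration):

    from dataclasses import dataclass
    from datetime import datetime
    from dbt.dataclass_schema import dbtClassMixin

    @dataclass
    class Ping(dbtClassMixin):
        sent_at: datetime
        note: str = ""

    p = Ping(sent_at=datetime(2021, 1, 1), note="hello")
    d = p.to_dict()         # roughly {"sent_at": "2021-01-01T00:00:00Z", "note": "hello"}
    p2 = Ping.from_dict(d)  # sent_at parsed back into a (UTC) datetime via dateutil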
@@ -14,39 +14,31 @@ class DBTDeprecation:
     def name(self) -> str:
         if self._name is not None:
             return self._name
-        raise NotImplementedError(
-            'name not implemented for {}'.format(self)
-        )
+        raise NotImplementedError("name not implemented for {}".format(self))

     def track_deprecation_warn(self) -> None:
         if dbt.tracking.active_user is not None:
-            dbt.tracking.track_deprecation_warn({
-                "deprecation_name": self.name
-            })
+            dbt.tracking.track_deprecation_warn({"deprecation_name": self.name})

     @property
     def description(self) -> str:
         if self._description is not None:
             return self._description
-        raise NotImplementedError(
-            'description not implemented for {}'.format(self)
-        )
+        raise NotImplementedError("description not implemented for {}".format(self))

     def show(self, *args, **kwargs) -> None:
         if self.name not in active_deprecations:
             desc = self.description.format(**kwargs)
-            msg = ui.line_wrap_message(
-                desc, prefix='* Deprecation Warning: '
-            )
+            msg = ui.line_wrap_message(desc, prefix="* Deprecation Warning: ")
             dbt.exceptions.warn_or_error(msg)
             self.track_deprecation_warn()
             active_deprecations.add(self.name)


 class MaterializationReturnDeprecation(DBTDeprecation):
-    _name = 'materialization-return'
+    _name = "materialization-return"

-    _description = '''\
+    _description = """\
     The materialization ("{materialization}") did not explicitly return a list
     of relations to add to the cache. By default the target relation will be
     added, but this behavior will be removed in a future version of dbt.

@@ -56,22 +48,22 @@ class MaterializationReturnDeprecation(DBTDeprecation):
     For more information, see:

     https://docs.getdbt.com/v0.15/docs/creating-new-materializations#section-6-returning-relations
-    '''
+    """


 class NotADictionaryDeprecation(DBTDeprecation):
-    _name = 'not-a-dictionary'
+    _name = "not-a-dictionary"

-    _description = '''\
+    _description = """\
     The object ("{obj}") was used as a dictionary. In a future version of dbt
     this capability will be removed from objects of this type.
-    '''
+    """


 class ColumnQuotingDeprecation(DBTDeprecation):
-    _name = 'column-quoting-unset'
+    _name = "column-quoting-unset"

-    _description = '''\
+    _description = """\
     The quote_columns parameter was not set for seeds, so the default value of
     False was chosen. The default will change to True in a future release.

@@ -80,13 +72,13 @@ class ColumnQuotingDeprecation(DBTDeprecation):
     For more information, see:

     https://docs.getdbt.com/v0.15/docs/seeds#section-specify-column-quoting
-    '''
+    """


 class ModelsKeyNonModelDeprecation(DBTDeprecation):
-    _name = 'models-key-mismatch'
+    _name = "models-key-mismatch"

-    _description = '''\
+    _description = """\
     "{node.name}" is a {node.resource_type} node, but it is specified in
     the {patch.yaml_key} section of {patch.original_file_path}.

@@ -96,25 +88,25 @@ class ModelsKeyNonModelDeprecation(DBTDeprecation):
     the {expected_key} key instead.

     This warning will become an error in a future release.
-    '''
+    """


 class ExecuteMacrosReleaseDeprecation(DBTDeprecation):
-    _name = 'execute-macro-release'
-    _description = '''\
+    _name = "execute-macro-release"
+    _description = """\
     The "release" argument to execute_macro is now ignored, and will be removed
     in a future relase of dbt. At that time, providing a `release` argument
     will result in an error.
-    '''
+    """


 class AdapterMacroDeprecation(DBTDeprecation):
-    _name = 'adapter-macro'
-    _description = '''\
+    _name = "adapter-macro"
+    _description = """\
     The "adapter_macro" macro has been deprecated. Instead, use the
     `adapter.dispatch` method to find a macro and call the result.
     adapter_macro was called for: {macro_name}
-    '''
+    """


 _adapter_renamed_description = """\
@@ -128,11 +120,11 @@ Documentation for {new_name} can be found here:


 def renamed_method(old_name: str, new_name: str):

     class AdapterDeprecationWarning(DBTDeprecation):
-        _name = 'adapter:{}'.format(old_name)
-        _description = _adapter_renamed_description.format(old_name=old_name,
-                                                            new_name=new_name)
+        _name = "adapter:{}".format(old_name)
+        _description = _adapter_renamed_description.format(
+            old_name=old_name, new_name=new_name
+        )

     dep = AdapterDeprecationWarning()
     deprecations_list.append(dep)
@@ -142,9 +134,7 @@ def renamed_method(old_name: str, new_name: str):
 def warn(name, *args, **kwargs):
     if name not in deprecations:
         # this should (hopefully) never happen
-        raise RuntimeError(
-            "Error showing deprecation warning: {}".format(name)
-        )
+        raise RuntimeError("Error showing deprecation warning: {}".format(name))

     deprecations[name].show(*args, **kwargs)

@@ -163,9 +153,7 @@ deprecations_list: List[DBTDeprecation] = [
     AdapterMacroDeprecation(),
 ]

-deprecations: Dict[str, DBTDeprecation] = {
-    d.name: d for d in deprecations_list
-}
+deprecations: Dict[str, DBTDeprecation] = {d.name: d for d in deprecations_list}


 def reset_deprecations():

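The module keeps a registry keyed by deprecation name, so callers fire warnings by name rather than by class; roughly, assuming this module is dbt.deprecations (as the symbols suggest) and using a macro name invented for illustration:

    import dbt.deprecations as deprecations

    # looks up "adapter-macro" in the `deprecations` dict and calls .show(...)
    deprecations.warn("adapter-macro", macro_name="get_columns_in_relation")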
@@ -22,12 +22,12 @@ def downloads_directory():
     # the user might have set an environment variable. Set it to that, and do
     # not remove it when finished.
     if DOWNLOADS_PATH is None:
-        DOWNLOADS_PATH = os.getenv('DBT_DOWNLOADS_DIR')
+        DOWNLOADS_PATH = os.getenv("DBT_DOWNLOADS_DIR")
         remove_downloads = False
     # if we are making a per-run temp directory, remove it at the end of
     # successful runs
     if DOWNLOADS_PATH is None:
-        DOWNLOADS_PATH = tempfile.mkdtemp(prefix='dbt-downloads-')
+        DOWNLOADS_PATH = tempfile.mkdtemp(prefix="dbt-downloads-")
         remove_downloads = True

     system.make_directory(DOWNLOADS_PATH)
@@ -62,7 +62,7 @@ class PinnedPackage(BasePackage):
         if not version:
             return self.name

-        return '{}@{}'.format(self.name, version)
+        return "{}@{}".format(self.name, version)

     @abc.abstractmethod
     def get_version(self) -> Optional[str]:
@@ -94,8 +94,8 @@ class PinnedPackage(BasePackage):
         return os.path.join(project.modules_path, dest_dirname)


-SomePinned = TypeVar('SomePinned', bound=PinnedPackage)
-SomeUnpinned = TypeVar('SomeUnpinned', bound='UnpinnedPackage')
+SomePinned = TypeVar("SomePinned", bound=PinnedPackage)
+SomeUnpinned = TypeVar("SomeUnpinned", bound="UnpinnedPackage")


 class UnpinnedPackage(Generic[SomePinned], BasePackage):
@@ -8,18 +8,16 @@ from dbt.contracts.project import (
     ProjectPackageMetadata,
     GitPackage,
 )
-from dbt.deps.base import PinnedPackage, UnpinnedPackage, get_downloads_path
+from dbt.deps import PinnedPackage, UnpinnedPackage, get_downloads_path
-from dbt.exceptions import (
-    ExecutableError, warn_or_error, raise_dependency_error
-)
+from dbt.exceptions import ExecutableError, warn_or_error, raise_dependency_error
 from dbt.logger import GLOBAL_LOGGER as logger
 from dbt import ui

-PIN_PACKAGE_URL = 'https://docs.getdbt.com/docs/package-management#section-specifying-package-versions'  # noqa
+PIN_PACKAGE_URL = "https://docs.getdbt.com/docs/package-management#section-specifying-package-versions"  # noqa


 def md5sum(s: str):
-    return hashlib.md5(s.encode('latin-1')).hexdigest()
+    return hashlib.md5(s.encode("latin-1")).hexdigest()


 class GitPackageMixin:
@@ -32,13 +30,11 @@ class GitPackageMixin:
         return self.git

     def source_type(self) -> str:
-        return 'git'
+        return "git"


 class GitPinnedPackage(GitPackageMixin, PinnedPackage):
-    def __init__(
-        self, git: str, revision: str, warn_unpinned: bool = True
-    ) -> None:
+    def __init__(self, git: str, revision: str, warn_unpinned: bool = True) -> None:
         super().__init__(git)
         self.revision = revision
         self.warn_unpinned = warn_unpinned
@@ -48,7 +44,18 @@ class GitPinnedPackage(GitPackageMixin, PinnedPackage):
         return self.revision

     def nice_version_name(self):
-        return 'revision {}'.format(self.revision)
+        if self.revision == "HEAD":
+            return "HEAD (default branch)"
+        else:
+            return "revision {}".format(self.revision)
+
+    def unpinned_msg(self):
+        if self.revision == "HEAD":
+            return "not pinned, using HEAD (default branch)"
+        elif self.revision in ("main", "master"):
+            return f'pinned to the "{self.revision}" branch'
+        else:
+            return None

     def _checkout(self):
         """Performs a shallow clone of the repository into the downloads
@@ -57,27 +64,31 @@ class GitPinnedPackage(GitPackageMixin, PinnedPackage):
         the path to the checked out directory."""
         try:
             dir_ = git.clone_and_checkout(
-                self.git, get_downloads_path(), branch=self.revision,
-                dirname=self._checkout_name
+                self.git,
+                get_downloads_path(),
+                branch=self.revision,
+                dirname=self._checkout_name,
             )
         except ExecutableError as exc:
-            if exc.cmd and exc.cmd[0] == 'git':
+            if exc.cmd and exc.cmd[0] == "git":
                 logger.error(
-                    'Make sure git is installed on your machine. More '
-                    'information: '
-                    'https://docs.getdbt.com/docs/package-management'
+                    "Make sure git is installed on your machine. More "
+                    "information: "
+                    "https://docs.getdbt.com/docs/package-management"
                 )
             raise
         return os.path.join(get_downloads_path(), dir_)

     def _fetch_metadata(self, project, renderer) -> ProjectPackageMetadata:
         path = self._checkout()
-        if self.revision == 'master' and self.warn_unpinned:
+
+        if self.unpinned_msg() and self.warn_unpinned:
             warn_or_error(
-                'The git package "{}" is not pinned.\n\tThis can introduce '
-                'breaking changes into your project without warning!\n\nSee {}'
-                .format(self.git, PIN_PACKAGE_URL),
-                log_fmt=ui.yellow('WARNING: {}')
+                'The git package "{}" \n\tis {}.\n\tThis can introduce '
+                "breaking changes into your project without warning!\n\nSee {}".format(
+                    self.git, self.unpinned_msg(), PIN_PACKAGE_URL
+                ),
+                log_fmt=ui.yellow("WARNING: {}"),
             )
         loaded = Project.from_project_root(path, renderer)
         return ProjectPackageMetadata.from_project(loaded)
@@ -102,26 +113,21 @@ class GitUnpinnedPackage(GitPackageMixin, UnpinnedPackage[GitPinnedPackage]):
         self.warn_unpinned = warn_unpinned

     @classmethod
-    def from_contract(
-        cls, contract: GitPackage
-    ) -> 'GitUnpinnedPackage':
+    def from_contract(cls, contract: GitPackage) -> "GitUnpinnedPackage":
         revisions = contract.get_revisions()

         # we want to map None -> True
         warn_unpinned = contract.warn_unpinned is not False
-        return cls(git=contract.git, revisions=revisions,
-                   warn_unpinned=warn_unpinned)
+        return cls(git=contract.git, revisions=revisions, warn_unpinned=warn_unpinned)

     def all_names(self) -> List[str]:
-        if self.git.endswith('.git'):
+        if self.git.endswith(".git"):
             other = self.git[:-4]
         else:
-            other = self.git + '.git'
+            other = self.git + ".git"
         return [self.git, other]

-    def incorporate(
-        self, other: 'GitUnpinnedPackage'
-    ) -> 'GitUnpinnedPackage':
+    def incorporate(self, other: "GitUnpinnedPackage") -> "GitUnpinnedPackage":
         warn_unpinned = self.warn_unpinned and other.warn_unpinned

         return GitUnpinnedPackage(
@@ -133,13 +139,13 @@ class GitUnpinnedPackage(GitPackageMixin, UnpinnedPackage[GitPinnedPackage]):
     def resolved(self) -> GitPinnedPackage:
         requested = set(self.revisions)
         if len(requested) == 0:
-            requested = {'master'}
+            requested = {"HEAD"}
         elif len(requested) > 1:
             raise_dependency_error(
-                'git dependencies should contain exactly one version. '
-                '{} contains: {}'.format(self.git, requested))
+                "git dependencies should contain exactly one version. "
+                "{} contains: {}".format(self.git, requested)
+            )

         return GitPinnedPackage(
-            git=self.git, revision=requested.pop(),
-            warn_unpinned=self.warn_unpinned
+            git=self.git, revision=requested.pop(), warn_unpinned=self.warn_unpinned
         )
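The new unpinned_msg/nice_version_name pair distinguishes an unset revision (now resolved to HEAD rather than master) from an explicit pin, and _fetch_metadata only warns when unpinned_msg returns something. A small sketch of the resulting behavior (repository URL and version invented):

    pkg = GitPinnedPackage(git="https://github.com/org/repo.git", revision="HEAD")
    pkg.nice_version_name()   # "HEAD (default branch)"
    pkg.unpinned_msg()        # "not pinned, using HEAD (default branch)" -> warning fires

    pinned = GitPinnedPackage(git="https://github.com/org/repo.git", revision="0.1.0")
    pinned.unpinned_msg()     # None -> no "not pinned" warning during _fetch_metadata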
@@ -1,7 +1,7 @@
 import shutil

 from dbt.clients import system
-from dbt.deps.base import PinnedPackage, UnpinnedPackage
+from dbt.deps import PinnedPackage, UnpinnedPackage
 from dbt.contracts.project import (
     ProjectPackageMetadata,
     LocalPackage,
@@ -19,7 +19,7 @@ class LocalPackageMixin:
         return self.local

     def source_type(self):
-        return 'local'
+        return "local"


 class LocalPinnedPackage(LocalPackageMixin, PinnedPackage):
@@ -30,7 +30,7 @@ class LocalPinnedPackage(LocalPackageMixin, PinnedPackage):
         return None

     def nice_version_name(self):
-        return '<local @ {}>'.format(self.local)
+        return "<local @ {}>".format(self.local)

     def resolve_path(self, project):
         return system.resolve_path_from_base(
@@ -39,9 +39,7 @@ class LocalPinnedPackage(LocalPackageMixin, PinnedPackage):
         )

     def _fetch_metadata(self, project, renderer):
-        loaded = project.from_project_root(
-            self.resolve_path(project), renderer
-        )
+        loaded = project.from_project_root(self.resolve_path(project), renderer)
         return ProjectPackageMetadata.from_project(loaded)

     def install(self, project, renderer):
@@ -57,27 +55,22 @@ class LocalPinnedPackage(LocalPackageMixin, PinnedPackage):
             system.remove_file(dest_path)

         if can_create_symlink:
-            logger.debug(' Creating symlink to local dependency.')
+            logger.debug(" Creating symlink to local dependency.")
             system.make_symlink(src_path, dest_path)

         else:
-            logger.debug(' Symlinks are not available on this '
-                         'OS, copying dependency.')
+            logger.debug(
+                " Symlinks are not available on this " "OS, copying dependency."
+            )
             shutil.copytree(src_path, dest_path)


-class LocalUnpinnedPackage(
-    LocalPackageMixin, UnpinnedPackage[LocalPinnedPackage]
-):
+class LocalUnpinnedPackage(LocalPackageMixin, UnpinnedPackage[LocalPinnedPackage]):
     @classmethod
-    def from_contract(
-        cls, contract: LocalPackage
-    ) -> 'LocalUnpinnedPackage':
+    def from_contract(cls, contract: LocalPackage) -> "LocalUnpinnedPackage":
         return cls(local=contract.local)

-    def incorporate(
-        self, other: 'LocalUnpinnedPackage'
-    ) -> 'LocalUnpinnedPackage':
+    def incorporate(self, other: "LocalUnpinnedPackage") -> "LocalUnpinnedPackage":
         return LocalUnpinnedPackage(local=self.local)

     def resolved(self) -> LocalPinnedPackage:
@@ -7,7 +7,7 @@ from dbt.contracts.project import (
|
|||||||
RegistryPackageMetadata,
|
RegistryPackageMetadata,
|
||||||
RegistryPackage,
|
RegistryPackage,
|
||||||
)
|
)
|
||||||
from dbt.deps.base import PinnedPackage, UnpinnedPackage, get_downloads_path
|
from dbt.deps import PinnedPackage, UnpinnedPackage, get_downloads_path
|
||||||
from dbt.exceptions import (
|
from dbt.exceptions import (
|
||||||
package_version_not_found,
|
package_version_not_found,
|
||||||
VersionsNotCompatibleException,
|
VersionsNotCompatibleException,
|
||||||
@@ -26,7 +26,7 @@ class RegistryPackageMixin:
|
|||||||
return self.package
|
return self.package
|
||||||
|
|
||||||
def source_type(self) -> str:
|
def source_type(self) -> str:
|
||||||
return 'hub'
|
return "hub"
|
||||||
|
|
||||||
|
|
||||||
class RegistryPinnedPackage(RegistryPackageMixin, PinnedPackage):
|
class RegistryPinnedPackage(RegistryPackageMixin, PinnedPackage):
|
||||||
@@ -39,13 +39,13 @@ class RegistryPinnedPackage(RegistryPackageMixin, PinnedPackage):
|
|||||||
return self.package
|
return self.package
|
||||||
|
|
||||||
def source_type(self):
|
def source_type(self):
|
||||||
return 'hub'
|
return "hub"
|
||||||
|
|
||||||
def get_version(self):
|
def get_version(self):
|
||||||
return self.version
|
return self.version
|
||||||
|
|
||||||
def nice_version_name(self):
|
def nice_version_name(self):
|
||||||
return 'version {}'.format(self.version)
|
return "version {}".format(self.version)
|
||||||
|
|
||||||
def _fetch_metadata(self, project, renderer) -> RegistryPackageMetadata:
|
def _fetch_metadata(self, project, renderer) -> RegistryPackageMetadata:
|
||||||
dct = registry.package_version(self.package, self.version)
|
dct = registry.package_version(self.package, self.version)
|
||||||
@@ -54,10 +54,8 @@ class RegistryPinnedPackage(RegistryPackageMixin, PinnedPackage):
|
|||||||
def install(self, project, renderer):
|
def install(self, project, renderer):
|
||||||
metadata = self.fetch_metadata(project, renderer)
|
metadata = self.fetch_metadata(project, renderer)
|
||||||
|
|
||||||
tar_name = '{}.{}.tar.gz'.format(self.package, self.version)
|
tar_name = "{}.{}.tar.gz".format(self.package, self.version)
|
||||||
tar_path = os.path.realpath(
|
tar_path = os.path.realpath(os.path.join(get_downloads_path(), tar_name))
|
||||||
os.path.join(get_downloads_path(), tar_name)
|
|
||||||
)
|
|
||||||
system.make_directory(os.path.dirname(tar_path))
|
system.make_directory(os.path.dirname(tar_path))
|
||||||
|
|
||||||
download_url = metadata.downloads.tarball
|
download_url = metadata.downloads.tarball
|
||||||
@@ -70,9 +68,7 @@ class RegistryPinnedPackage(RegistryPackageMixin, PinnedPackage):
|
|||||||
class RegistryUnpinnedPackage(
|
class RegistryUnpinnedPackage(
|
||||||
RegistryPackageMixin, UnpinnedPackage[RegistryPinnedPackage]
|
RegistryPackageMixin, UnpinnedPackage[RegistryPinnedPackage]
|
||||||
):
|
):
|
||||||
def __init__(
|
def __init__(self, package: str, versions: List[semver.VersionSpecifier]) -> None:
|
||||||
self, package: str, versions: List[semver.VersionSpecifier]
|
|
||||||
) -> None:
|
|
||||||
super().__init__(package)
|
super().__init__(package)
|
||||||
self.versions = versions
|
self.versions = versions
|
||||||
|
|
||||||
@@ -82,20 +78,15 @@ class RegistryUnpinnedPackage(
|
|||||||
package_not_found(self.package)
|
package_not_found(self.package)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_contract(
|
def from_contract(cls, contract: RegistryPackage) -> "RegistryUnpinnedPackage":
|
||||||
cls, contract: RegistryPackage
|
|
||||||
) -> 'RegistryUnpinnedPackage':
|
|
||||||
raw_version = contract.get_versions()
|
raw_version = contract.get_versions()
|
||||||
|
|
||||||
versions = [
|
versions = [semver.VersionSpecifier.from_version_string(v) for v in raw_version]
|
||||||
semver.VersionSpecifier.from_version_string(v)
|
|
||||||
for v in raw_version
|
|
||||||
]
|
|
||||||
return cls(package=contract.package, versions=versions)
|
return cls(package=contract.package, versions=versions)
|
||||||
|
|
||||||
def incorporate(
|
def incorporate(
|
||||||
self, other: 'RegistryUnpinnedPackage'
|
self, other: "RegistryUnpinnedPackage"
|
||||||
) -> 'RegistryUnpinnedPackage':
|
) -> "RegistryUnpinnedPackage":
|
||||||
return RegistryUnpinnedPackage(
|
return RegistryUnpinnedPackage(
|
||||||
package=self.package,
|
package=self.package,
|
||||||
versions=self.versions + other.versions,
|
versions=self.versions + other.versions,
|
||||||
@@ -106,8 +97,7 @@ class RegistryUnpinnedPackage(
|
|||||||
try:
|
try:
|
||||||
range_ = semver.reduce_versions(*self.versions)
|
range_ = semver.reduce_versions(*self.versions)
|
||||||
except VersionsNotCompatibleException as e:
|
except VersionsNotCompatibleException as e:
|
||||||
new_msg = ('Version error for package {}: {}'
|
new_msg = "Version error for package {}: {}".format(self.name, e)
|
||||||
.format(self.name, e))
|
|
||||||
raise DependencyException(new_msg) from e
|
raise DependencyException(new_msg) from e
|
||||||
|
|
||||||
available = registry.get_available_versions(self.package)
|
available = registry.get_available_versions(self.package)
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ from dbt.exceptions import raise_dependency_error, InternalException
|
|||||||
from dbt.context.target import generate_target_context
|
from dbt.context.target import generate_target_context
|
||||||
from dbt.config import Project, RuntimeConfig
|
from dbt.config import Project, RuntimeConfig
|
||||||
from dbt.config.renderer import DbtProjectYamlRenderer
|
from dbt.config.renderer import DbtProjectYamlRenderer
|
||||||
from dbt.deps.base import BasePackage, PinnedPackage, UnpinnedPackage
|
from dbt.deps import BasePackage, PinnedPackage, UnpinnedPackage
|
||||||
from dbt.deps.local import LocalUnpinnedPackage
|
from dbt.deps.local import LocalUnpinnedPackage
|
||||||
from dbt.deps.git import GitUnpinnedPackage
|
from dbt.deps.git import GitUnpinnedPackage
|
||||||
from dbt.deps.registry import RegistryUnpinnedPackage
|
from dbt.deps.registry import RegistryUnpinnedPackage
|
||||||
@@ -49,12 +49,10 @@ class PackageListing:
|
|||||||
key_str: str = self._pick_key(key)
|
key_str: str = self._pick_key(key)
|
||||||
self.packages[key_str] = value
|
self.packages[key_str] = value
|
||||||
|
|
||||||
def _mismatched_types(
|
def _mismatched_types(self, old: UnpinnedPackage, new: UnpinnedPackage) -> NoReturn:
|
||||||
self, old: UnpinnedPackage, new: UnpinnedPackage
|
|
||||||
) -> NoReturn:
|
|
||||||
raise_dependency_error(
|
raise_dependency_error(
|
||||||
f'Cannot incorporate {new} ({new.__class__.__name__}) in {old} '
|
f"Cannot incorporate {new} ({new.__class__.__name__}) in {old} "
|
||||||
f'({old.__class__.__name__}): mismatched types'
|
f"({old.__class__.__name__}): mismatched types"
|
||||||
)
|
)
|
||||||
|
|
||||||
def incorporate(self, package: UnpinnedPackage):
|
def incorporate(self, package: UnpinnedPackage):
|
||||||
@@ -78,14 +76,14 @@ class PackageListing:
|
|||||||
pkg = RegistryUnpinnedPackage.from_contract(contract)
|
pkg = RegistryUnpinnedPackage.from_contract(contract)
|
||||||
else:
|
else:
|
||||||
raise InternalException(
|
raise InternalException(
|
||||||
'Invalid package type {}'.format(type(contract))
|
"Invalid package type {}".format(type(contract))
|
||||||
)
|
)
|
||||||
self.incorporate(pkg)
|
self.incorporate(pkg)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_contracts(
|
def from_contracts(
|
||||||
cls: Type['PackageListing'], src: List[PackageContract]
|
cls: Type["PackageListing"], src: List[PackageContract]
|
||||||
) -> 'PackageListing':
|
) -> "PackageListing":
|
||||||
self = cls({})
|
self = cls({})
|
||||||
self.update_from(src)
|
self.update_from(src)
|
||||||
return self
|
return self
|
||||||
@@ -108,14 +106,14 @@ def _check_for_duplicate_project_names(
|
|||||||
if project_name in seen:
|
if project_name in seen:
|
||||||
raise_dependency_error(
|
raise_dependency_error(
|
||||||
f'Found duplicate project "{project_name}". This occurs when '
|
f'Found duplicate project "{project_name}". This occurs when '
|
||||||
'a dependency has the same project name as some other '
|
"a dependency has the same project name as some other "
|
||||||
'dependency.'
|
"dependency."
|
||||||
)
|
)
|
||||||
elif project_name == config.project_name:
|
elif project_name == config.project_name:
|
||||||
raise_dependency_error(
|
raise_dependency_error(
|
||||||
'Found a dependency with the same name as the root project '
|
"Found a dependency with the same name as the root project "
|
||||||
f'"{project_name}". Package names must be unique in a project.'
|
f'"{project_name}". Package names must be unique in a project.'
|
||||||
' Please rename one of these packages.'
|
" Please rename one of these packages."
|
||||||
)
|
)
|
||||||
seen.add(project_name)
|
seen.add(project_name)
|
||||||
|
|
||||||
|
|||||||
File diff suppressed because it is too large
@@ -1,6 +1,7 @@
 import os
 import multiprocessing
-if os.name != 'nt':
+
+if os.name != "nt":
     # https://bugs.python.org/issue41567
     import multiprocessing.popen_spawn_posix  # type: ignore
 from pathlib import Path
@@ -23,7 +24,7 @@ def env_set_truthy(key: str) -> Optional[str]:
     otherwise.
     """
     value = os.getenv(key)
-    if not value or value.lower() in ('0', 'false', 'f'):
+    if not value or value.lower() in ("0", "false", "f"):
         return None
     return value
 
@@ -36,24 +37,23 @@ def env_set_path(key: str) -> Optional[Path]:
     return Path(value)
 
 
-SINGLE_THREADED_WEBSERVER = env_set_truthy('DBT_SINGLE_THREADED_WEBSERVER')
-SINGLE_THREADED_HANDLER = env_set_truthy('DBT_SINGLE_THREADED_HANDLER')
-MACRO_DEBUGGING = env_set_truthy('DBT_MACRO_DEBUGGING')
-DEFER_MODE = env_set_truthy('DBT_DEFER_TO_STATE')
-ARTIFACT_STATE_PATH = env_set_path('DBT_ARTIFACT_STATE_PATH')
+SINGLE_THREADED_WEBSERVER = env_set_truthy("DBT_SINGLE_THREADED_WEBSERVER")
+SINGLE_THREADED_HANDLER = env_set_truthy("DBT_SINGLE_THREADED_HANDLER")
+MACRO_DEBUGGING = env_set_truthy("DBT_MACRO_DEBUGGING")
+DEFER_MODE = env_set_truthy("DBT_DEFER_TO_STATE")
+ARTIFACT_STATE_PATH = env_set_path("DBT_ARTIFACT_STATE_PATH")
 
 
 def _get_context():
     # TODO: change this back to use fork() on linux when we have made that safe
-    return multiprocessing.get_context('spawn')
+    return multiprocessing.get_context("spawn")
 
 
 MP_CONTEXT = _get_context()
 
 
 def reset():
-    global STRICT_MODE, FULL_REFRESH, USE_CACHE, WARN_ERROR, TEST_NEW_PARSER, \
-        WRITE_JSON, PARTIAL_PARSE, MP_CONTEXT, USE_COLORS
+    global STRICT_MODE, FULL_REFRESH, USE_CACHE, WARN_ERROR, TEST_NEW_PARSER, WRITE_JSON, PARTIAL_PARSE, MP_CONTEXT, USE_COLORS
 
     STRICT_MODE = False
     FULL_REFRESH = False
@@ -67,26 +67,22 @@ def reset():
 
 
 def set_from_args(args):
-    global STRICT_MODE, FULL_REFRESH, USE_CACHE, WARN_ERROR, TEST_NEW_PARSER, \
-        WRITE_JSON, PARTIAL_PARSE, MP_CONTEXT, USE_COLORS
+    global STRICT_MODE, FULL_REFRESH, USE_CACHE, WARN_ERROR, TEST_NEW_PARSER, WRITE_JSON, PARTIAL_PARSE, MP_CONTEXT, USE_COLORS
 
-    USE_CACHE = getattr(args, 'use_cache', USE_CACHE)
+    USE_CACHE = getattr(args, "use_cache", USE_CACHE)
 
-    FULL_REFRESH = getattr(args, 'full_refresh', FULL_REFRESH)
-    STRICT_MODE = getattr(args, 'strict', STRICT_MODE)
-    WARN_ERROR = (
-        STRICT_MODE or
-        getattr(args, 'warn_error', STRICT_MODE or WARN_ERROR)
-    )
+    FULL_REFRESH = getattr(args, "full_refresh", FULL_REFRESH)
+    STRICT_MODE = getattr(args, "strict", STRICT_MODE)
+    WARN_ERROR = STRICT_MODE or getattr(args, "warn_error", STRICT_MODE or WARN_ERROR)
 
-    TEST_NEW_PARSER = getattr(args, 'test_new_parser', TEST_NEW_PARSER)
-    WRITE_JSON = getattr(args, 'write_json', WRITE_JSON)
-    PARTIAL_PARSE = getattr(args, 'partial_parse', None)
+    TEST_NEW_PARSER = getattr(args, "test_new_parser", TEST_NEW_PARSER)
+    WRITE_JSON = getattr(args, "write_json", WRITE_JSON)
+    PARTIAL_PARSE = getattr(args, "partial_parse", None)
     MP_CONTEXT = _get_context()
 
     # The use_colors attribute will always have a value because it is assigned
     # None by default from the add_mutually_exclusive_group function
-    use_colors_override = getattr(args, 'use_colors')
+    use_colors_override = getattr(args, "use_colors")
 
     if use_colors_override is not None:
         USE_COLORS = use_colors_override
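The env_set_truthy helper shown above returns None when the variable is unset or looks falsy ("0", "false", or "f", case-insensitively) and otherwise returns the raw string, so flags such as MACRO_DEBUGGING are either None or a truthy string. A small illustration that simply mirrors the logic in the hunk above (not an import of the real module):

    # Illustration only: reimplements the env_set_truthy logic shown above.
    import os

    def env_set_truthy(key):
        value = os.getenv(key)
        if not value or value.lower() in ("0", "false", "f"):
            return None
        return value

    os.environ["DBT_MACRO_DEBUGGING"] = "0"
    assert env_set_truthy("DBT_MACRO_DEBUGGING") is None

    os.environ["DBT_MACRO_DEBUGGING"] = "write"   # any other non-empty string is kept
    assert env_set_truthy("DBT_MACRO_DEBUGGING") == "write"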
@@ -1,10 +1,8 @@
|
|||||||
# special support for CLI argument parsing.
|
# special support for CLI argument parsing.
|
||||||
import itertools
|
import itertools
|
||||||
import yaml
|
from dbt.clients.yaml_helper import yaml, Loader, Dumper # noqa: F401
|
||||||
|
|
||||||
from typing import (
|
from typing import Dict, List, Optional, Tuple, Any, Union
|
||||||
Dict, List, Optional, Tuple, Any, Union
|
|
||||||
)
|
|
||||||
|
|
||||||
from dbt.contracts.selection import SelectorDefinition, SelectorFile
|
from dbt.contracts.selection import SelectorDefinition, SelectorFile
|
||||||
from dbt.exceptions import InternalException, ValidationException
|
from dbt.exceptions import InternalException, ValidationException
|
||||||
@@ -17,21 +15,17 @@ from .selector_spec import (
|
|||||||
SelectionCriteria,
|
SelectionCriteria,
|
||||||
)
|
)
|
||||||
|
|
||||||
INTERSECTION_DELIMITER = ','
|
INTERSECTION_DELIMITER = ","
|
||||||
|
|
||||||
DEFAULT_INCLUDES: List[str] = ['fqn:*', 'source:*', 'exposure:*']
|
DEFAULT_INCLUDES: List[str] = ["fqn:*", "source:*", "exposure:*"]
|
||||||
DEFAULT_EXCLUDES: List[str] = []
|
DEFAULT_EXCLUDES: List[str] = []
|
||||||
DATA_TEST_SELECTOR: str = 'test_type:data'
|
DATA_TEST_SELECTOR: str = "test_type:data"
|
||||||
SCHEMA_TEST_SELECTOR: str = 'test_type:schema'
|
SCHEMA_TEST_SELECTOR: str = "test_type:schema"
|
||||||
|
|
||||||
|
|
||||||
def parse_union(
|
def parse_union(components: List[str], expect_exists: bool) -> SelectionUnion:
|
||||||
components: List[str], expect_exists: bool
|
|
||||||
) -> SelectionUnion:
|
|
||||||
# turn ['a b', 'c'] -> ['a', 'b', 'c']
|
# turn ['a b', 'c'] -> ['a', 'b', 'c']
|
||||||
raw_specs = itertools.chain.from_iterable(
|
raw_specs = itertools.chain.from_iterable(r.split(" ") for r in components)
|
||||||
r.split(' ') for r in components
|
|
||||||
)
|
|
||||||
union_components: List[SelectionSpec] = []
|
union_components: List[SelectionSpec] = []
|
||||||
|
|
||||||
# ['a', 'b', 'c,d'] -> union('a', 'b', intersection('c', 'd'))
|
# ['a', 'b', 'c,d'] -> union('a', 'b', intersection('c', 'd'))
|
||||||
@@ -40,11 +34,13 @@ def parse_union(
|
|||||||
SelectionCriteria.from_single_spec(part)
|
SelectionCriteria.from_single_spec(part)
|
||||||
for part in raw_spec.split(INTERSECTION_DELIMITER)
|
for part in raw_spec.split(INTERSECTION_DELIMITER)
|
||||||
]
|
]
|
||||||
union_components.append(SelectionIntersection(
|
union_components.append(
|
||||||
components=intersection_components,
|
SelectionIntersection(
|
||||||
expect_exists=expect_exists,
|
components=intersection_components,
|
||||||
raw=raw_spec,
|
expect_exists=expect_exists,
|
||||||
))
|
raw=raw_spec,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
return SelectionUnion(
|
return SelectionUnion(
|
||||||
components=union_components,
|
components=union_components,
|
||||||
@@ -78,9 +74,7 @@ def parse_test_selectors(
|
|||||||
union_components = []
|
union_components = []
|
||||||
|
|
||||||
if data:
|
if data:
|
||||||
union_components.append(
|
union_components.append(SelectionCriteria.from_single_spec(DATA_TEST_SELECTOR))
|
||||||
SelectionCriteria.from_single_spec(DATA_TEST_SELECTOR)
|
|
||||||
)
|
|
||||||
if schema:
|
if schema:
|
||||||
union_components.append(
|
union_components.append(
|
||||||
SelectionCriteria.from_single_spec(SCHEMA_TEST_SELECTOR)
|
SelectionCriteria.from_single_spec(SCHEMA_TEST_SELECTOR)
|
||||||
@@ -98,27 +92,21 @@ def parse_test_selectors(
|
|||||||
raw=[DATA_TEST_SELECTOR, SCHEMA_TEST_SELECTOR],
|
raw=[DATA_TEST_SELECTOR, SCHEMA_TEST_SELECTOR],
|
||||||
)
|
)
|
||||||
|
|
||||||
return SelectionIntersection(
|
return SelectionIntersection(components=[base, intersect_with], expect_exists=True)
|
||||||
components=[base, intersect_with], expect_exists=True
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
RawDefinition = Union[str, Dict[str, Any]]
|
RawDefinition = Union[str, Dict[str, Any]]
|
||||||
|
|
||||||
|
|
||||||
def _get_list_dicts(
|
def _get_list_dicts(dct: Dict[str, Any], key: str) -> List[RawDefinition]:
|
||||||
dct: Dict[str, Any], key: str
|
|
||||||
) -> List[RawDefinition]:
|
|
||||||
result: List[RawDefinition] = []
|
result: List[RawDefinition] = []
|
||||||
if key not in dct:
|
if key not in dct:
|
||||||
raise InternalException(
|
raise InternalException(
|
||||||
f'Expected to find key {key} in dict, only found {list(dct)}'
|
f"Expected to find key {key} in dict, only found {list(dct)}"
|
||||||
)
|
)
|
||||||
values = dct[key]
|
values = dct[key]
|
||||||
if not isinstance(values, list):
|
if not isinstance(values, list):
|
||||||
raise ValidationException(
|
raise ValidationException(f'Invalid value for key "{key}". Expected a list.')
|
||||||
f'Invalid value for key "{key}". Expected a list.'
|
|
||||||
)
|
|
||||||
for value in values:
|
for value in values:
|
||||||
if isinstance(value, dict):
|
if isinstance(value, dict):
|
||||||
for value_key in value:
|
for value_key in value:
|
||||||
@@ -133,36 +121,31 @@ def _get_list_dicts(
|
|||||||
else:
|
else:
|
||||||
raise ValidationException(
|
raise ValidationException(
|
||||||
f'Invalid value type {type(value)} in key "{key}", expected '
|
f'Invalid value type {type(value)} in key "{key}", expected '
|
||||||
f'dict or str (value: {value}).'
|
f"dict or str (value: {value})."
|
||||||
)
|
)
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
def _parse_exclusions(definition) -> Optional[SelectionSpec]:
|
def _parse_exclusions(definition) -> Optional[SelectionSpec]:
|
||||||
exclusions = _get_list_dicts(definition, 'exclude')
|
exclusions = _get_list_dicts(definition, "exclude")
|
||||||
parsed_exclusions = [
|
parsed_exclusions = [parse_from_definition(excl) for excl in exclusions]
|
||||||
parse_from_definition(excl) for excl in exclusions
|
|
||||||
]
|
|
||||||
if len(parsed_exclusions) == 1:
|
if len(parsed_exclusions) == 1:
|
||||||
return parsed_exclusions[0]
|
return parsed_exclusions[0]
|
||||||
elif len(parsed_exclusions) > 1:
|
elif len(parsed_exclusions) > 1:
|
||||||
return SelectionUnion(
|
return SelectionUnion(components=parsed_exclusions, raw=exclusions)
|
||||||
components=parsed_exclusions,
|
|
||||||
raw=exclusions
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def _parse_include_exclude_subdefs(
|
def _parse_include_exclude_subdefs(
|
||||||
definitions: List[RawDefinition]
|
definitions: List[RawDefinition],
|
||||||
) -> Tuple[List[SelectionSpec], Optional[SelectionSpec]]:
|
) -> Tuple[List[SelectionSpec], Optional[SelectionSpec]]:
|
||||||
include_parts: List[SelectionSpec] = []
|
include_parts: List[SelectionSpec] = []
|
||||||
diff_arg: Optional[SelectionSpec] = None
|
diff_arg: Optional[SelectionSpec] = None
|
||||||
|
|
||||||
for definition in definitions:
|
for definition in definitions:
|
||||||
if isinstance(definition, dict) and 'exclude' in definition:
|
if isinstance(definition, dict) and "exclude" in definition:
|
||||||
# do not allow multiple exclude: defs at the same level
|
# do not allow multiple exclude: defs at the same level
|
||||||
if diff_arg is not None:
|
if diff_arg is not None:
|
||||||
yaml_sel_cfg = yaml.dump(definition)
|
yaml_sel_cfg = yaml.dump(definition)
|
||||||
@@ -178,7 +161,7 @@ def _parse_include_exclude_subdefs(
|
|||||||
|
|
||||||
|
|
||||||
def parse_union_definition(definition: Dict[str, Any]) -> SelectionSpec:
|
def parse_union_definition(definition: Dict[str, Any]) -> SelectionSpec:
|
||||||
union_def_parts = _get_list_dicts(definition, 'union')
|
union_def_parts = _get_list_dicts(definition, "union")
|
||||||
include, exclude = _parse_include_exclude_subdefs(union_def_parts)
|
include, exclude = _parse_include_exclude_subdefs(union_def_parts)
|
||||||
|
|
||||||
union = SelectionUnion(components=include)
|
union = SelectionUnion(components=include)
|
||||||
@@ -187,16 +170,11 @@ def parse_union_definition(definition: Dict[str, Any]) -> SelectionSpec:
|
|||||||
union.raw = definition
|
union.raw = definition
|
||||||
return union
|
return union
|
||||||
else:
|
else:
|
||||||
return SelectionDifference(
|
return SelectionDifference(components=[union, exclude], raw=definition)
|
||||||
components=[union, exclude],
|
|
||||||
raw=definition
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def parse_intersection_definition(
|
def parse_intersection_definition(definition: Dict[str, Any]) -> SelectionSpec:
|
||||||
definition: Dict[str, Any]
|
intersection_def_parts = _get_list_dicts(definition, "intersection")
|
||||||
) -> SelectionSpec:
|
|
||||||
intersection_def_parts = _get_list_dicts(definition, 'intersection')
|
|
||||||
include, exclude = _parse_include_exclude_subdefs(intersection_def_parts)
|
include, exclude = _parse_include_exclude_subdefs(intersection_def_parts)
|
||||||
intersection = SelectionIntersection(components=include)
|
intersection = SelectionIntersection(components=include)
|
||||||
|
|
||||||
@@ -204,10 +182,7 @@ def parse_intersection_definition(
|
|||||||
intersection.raw = definition
|
intersection.raw = definition
|
||||||
return intersection
|
return intersection
|
||||||
else:
|
else:
|
||||||
return SelectionDifference(
|
return SelectionDifference(components=[intersection, exclude], raw=definition)
|
||||||
components=[intersection, exclude],
|
|
||||||
raw=definition
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def parse_dict_definition(definition: Dict[str, Any]) -> SelectionSpec:
|
def parse_dict_definition(definition: Dict[str, Any]) -> SelectionSpec:
|
||||||
@@ -221,14 +196,14 @@ def parse_dict_definition(definition: Dict[str, Any]) -> SelectionSpec:
|
|||||||
f'"{type(key)}" ({key})'
|
f'"{type(key)}" ({key})'
|
||||||
)
|
)
|
||||||
dct = {
|
dct = {
|
||||||
'method': key,
|
"method": key,
|
||||||
'value': value,
|
"value": value,
|
||||||
}
|
}
|
||||||
elif 'method' in definition and 'value' in definition:
|
elif "method" in definition and "value" in definition:
|
||||||
dct = definition
|
dct = definition
|
||||||
if 'exclude' in definition:
|
if "exclude" in definition:
|
||||||
diff_arg = _parse_exclusions(definition)
|
diff_arg = _parse_exclusions(definition)
|
||||||
dct = {k: v for k, v in dct.items() if k != 'exclude'}
|
dct = {k: v for k, v in dct.items() if k != "exclude"}
|
||||||
else:
|
else:
|
||||||
raise ValidationException(
|
raise ValidationException(
|
||||||
f'Expected either 1 key or else "method" '
|
f'Expected either 1 key or else "method" '
|
||||||
@@ -236,20 +211,21 @@ def parse_dict_definition(definition: Dict[str, Any]) -> SelectionSpec:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# if key isn't a valid method name, this will raise
|
# if key isn't a valid method name, this will raise
|
||||||
base = SelectionCriteria.from_dict(definition, dct)
|
base = SelectionCriteria.selection_criteria_from_dict(definition, dct)
|
||||||
if diff_arg is None:
|
if diff_arg is None:
|
||||||
return base
|
return base
|
||||||
else:
|
else:
|
||||||
return SelectionDifference(components=[base, diff_arg])
|
return SelectionDifference(components=[base, diff_arg])
|
||||||
|
|
||||||
|
|
||||||
def parse_from_definition(
|
def parse_from_definition(definition: RawDefinition, rootlevel=False) -> SelectionSpec:
|
||||||
definition: RawDefinition, rootlevel=False
|
|
||||||
) -> SelectionSpec:
|
|
||||||
|
|
||||||
if (isinstance(definition, dict) and
|
if (
|
||||||
('union' in definition or 'intersection' in definition) and
|
isinstance(definition, dict)
|
||||||
rootlevel and len(definition) > 1):
|
and ("union" in definition or "intersection" in definition)
|
||||||
|
and rootlevel
|
||||||
|
and len(definition) > 1
|
||||||
|
):
|
||||||
keys = ",".join(definition.keys())
|
keys = ",".join(definition.keys())
|
||||||
raise ValidationException(
|
raise ValidationException(
|
||||||
f"Only a single 'union' or 'intersection' key is allowed "
|
f"Only a single 'union' or 'intersection' key is allowed "
|
||||||
@@ -257,25 +233,24 @@ def parse_from_definition(
|
|||||||
)
|
)
|
||||||
if isinstance(definition, str):
|
if isinstance(definition, str):
|
||||||
return SelectionCriteria.from_single_spec(definition)
|
return SelectionCriteria.from_single_spec(definition)
|
||||||
elif 'union' in definition:
|
elif "union" in definition:
|
||||||
return parse_union_definition(definition)
|
return parse_union_definition(definition)
|
||||||
elif 'intersection' in definition:
|
elif "intersection" in definition:
|
||||||
return parse_intersection_definition(definition)
|
return parse_intersection_definition(definition)
|
||||||
elif isinstance(definition, dict):
|
elif isinstance(definition, dict):
|
||||||
return parse_dict_definition(definition)
|
return parse_dict_definition(definition)
|
||||||
else:
|
else:
|
||||||
raise ValidationException(
|
raise ValidationException(
|
||||||
f'Expected to find union, intersection, str or dict, instead '
|
f"Expected to find union, intersection, str or dict, instead "
|
||||||
f'found {type(definition)}: {definition}'
|
f"found {type(definition)}: {definition}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def parse_from_selectors_definition(
|
def parse_from_selectors_definition(source: SelectorFile) -> Dict[str, SelectionSpec]:
|
||||||
source: SelectorFile
|
|
||||||
) -> Dict[str, SelectionSpec]:
|
|
||||||
result: Dict[str, SelectionSpec] = {}
|
result: Dict[str, SelectionSpec] = {}
|
||||||
selector: SelectorDefinition
|
selector: SelectorDefinition
|
||||||
for selector in source.selectors:
|
for selector in source.selectors:
|
||||||
result[selector.name] = parse_from_definition(selector.definition,
|
result[selector.name] = parse_from_definition(
|
||||||
rootlevel=True)
|
selector.definition, rootlevel=True
|
||||||
|
)
|
||||||
return result
|
return result
|
||||||
|
|||||||
@@ -1,17 +1,16 @@
-from typing import (
-    Set, Iterable, Iterator, Optional, NewType
-)
+from typing import Set, Iterable, Iterator, Optional, NewType
 import networkx as nx  # type: ignore
 
 from dbt.exceptions import InternalException
 
-UniqueId = NewType('UniqueId', str)
+UniqueId = NewType("UniqueId", str)
 
 
 class Graph:
     """A wrapper around the networkx graph that understands SelectionCriteria
     and how they interact with the graph.
     """
+
     def __init__(self, graph):
         self.graph = graph
 
@@ -29,12 +28,11 @@ class Graph:
     ) -> Set[UniqueId]:
         """Returns all nodes having a path to `node` in `graph`"""
         if not self.graph.has_node(node):
-            raise InternalException(f'Node {node} not found in the graph!')
+            raise InternalException(f"Node {node} not found in the graph!")
         with nx.utils.reversed(self.graph):
-            anc = nx.single_source_shortest_path_length(G=self.graph,
-                                                        source=node,
-                                                        cutoff=max_depth)\
-                .keys()
+            anc = nx.single_source_shortest_path_length(
+                G=self.graph, source=node, cutoff=max_depth
+            ).keys()
         return anc - {node}
 
     def descendants(
@@ -42,16 +40,13 @@ class Graph:
     ) -> Set[UniqueId]:
         """Returns all nodes reachable from `node` in `graph`"""
         if not self.graph.has_node(node):
-            raise InternalException(f'Node {node} not found in the graph!')
-        des = nx.single_source_shortest_path_length(G=self.graph,
-                                                    source=node,
-                                                    cutoff=max_depth)\
-            .keys()
+            raise InternalException(f"Node {node} not found in the graph!")
+        des = nx.single_source_shortest_path_length(
+            G=self.graph, source=node, cutoff=max_depth
+        ).keys()
         return des - {node}
 
-    def select_childrens_parents(
-        self, selected: Set[UniqueId]
-    ) -> Set[UniqueId]:
+    def select_childrens_parents(self, selected: Set[UniqueId]) -> Set[UniqueId]:
         ancestors_for = self.select_children(selected) | selected
         return self.select_parents(ancestors_for) | ancestors_for
 
@@ -77,7 +72,7 @@ class Graph:
         successors.update(self.graph.successors(node))
         return successors
 
-    def get_subset_graph(self, selected: Iterable[UniqueId]) -> 'Graph':
+    def get_subset_graph(self, selected: Iterable[UniqueId]) -> "Graph":
         """Create and return a new graph that is a shallow copy of the graph,
         but with only the nodes in include_nodes. Transitive edges across
         removed nodes are preserved as explicit new edges.
@@ -98,7 +93,7 @@ class Graph:
         )
         return Graph(new_graph)
 
-    def subgraph(self, nodes: Iterable[UniqueId]) -> 'Graph':
+    def subgraph(self, nodes: Iterable[UniqueId]) -> "Graph":
         return Graph(self.graph.subgraph(nodes))
 
     def get_dependent_nodes(self, node: UniqueId):
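The ancestors and descendants methods in the hunk above both delegate to networkx's single_source_shortest_path_length and drop the start node from the result; ancestors additionally walks the graph reversed (the dbt code uses the nx.utils.reversed context manager, while the sketch below reverses a copy instead). A standalone sketch of the same traversal on a toy DAG:

    # Standalone sketch of the traversal used by Graph.ancestors / Graph.descendants above.
    import networkx as nx

    g = nx.DiGraph()
    g.add_edges_from([("raw_orders", "stg_orders"), ("stg_orders", "orders")])

    # descendants of "raw_orders"; cutoff plays the role of max_depth
    reachable = nx.single_source_shortest_path_length(G=g, source="raw_orders", cutoff=None).keys()
    print(set(reachable) - {"raw_orders"})   # {'stg_orders', 'orders'}

    # ancestors: the same call on a reversed view of the graph
    upstream = nx.single_source_shortest_path_length(G=g.reverse(), source="orders").keys()
    print(set(upstream) - {"orders"})        # {'stg_orders', 'raw_orders'}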
@@ -1,8 +1,6 @@
 import threading
 from queue import PriorityQueue
-from typing import (
-    Dict, Set, Optional
-)
+from typing import Dict, Set, Optional
 
 import networkx as nx  # type: ignore
 
@@ -21,9 +19,8 @@ class GraphQueue:
     that separate threads do not call `.empty()` or `__len__()` and `.get()` at
     the same time, as there is an unlocked race!
     """
-    def __init__(
-        self, graph: nx.DiGraph, manifest: Manifest, selected: Set[UniqueId]
-    ):
+
+    def __init__(self, graph: nx.DiGraph, manifest: Manifest, selected: Set[UniqueId]):
         self.graph = graph
         self.manifest = manifest
         self._selected = selected
@@ -75,10 +72,13 @@ class GraphQueue:
         """
         scores = {}
         for node in self.graph.nodes():
-            score = -1 * len([
-                d for d in nx.descendants(self.graph, node)
-                if self._include_in_cost(d)
-            ])
+            score = -1 * len(
+                [
+                    d
+                    for d in nx.descendants(self.graph, node)
+                    if self._include_in_cost(d)
+                ]
+            )
             scores[node] = score
         return scores
 
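The score computed in the last hunk above is minus the number of descendants that count toward cost, so nodes with more downstream work sort first in the PriorityQueue (smaller score, higher priority). A toy illustration of that ordering, leaving out the _include_in_cost filter the real code applies:

    # Toy illustration of the scoring above: more descendants -> more negative score
    # -> popped earlier from a PriorityQueue.
    import networkx as nx

    g = nx.DiGraph()
    g.add_edges_from([("a", "b"), ("b", "c"), ("b", "d")])

    scores = {node: -1 * len(nx.descendants(g, node)) for node in g.nodes()}
    print(scores)  # {'a': -3, 'b': -2, 'c': 0, 'd': 0}
    # "a" has the most downstream dependents, so it gets the smallest score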
@@ -1,4 +1,3 @@
|
|||||||
|
|
||||||
from typing import Set, List, Optional
|
from typing import Set, List, Optional
|
||||||
|
|
||||||
from .graph import Graph, UniqueId
|
from .graph import Graph, UniqueId
|
||||||
@@ -25,14 +24,13 @@ def get_package_names(nodes):
|
|||||||
def alert_non_existence(raw_spec, nodes):
|
def alert_non_existence(raw_spec, nodes):
|
||||||
if len(nodes) == 0:
|
if len(nodes) == 0:
|
||||||
warn_or_error(
|
warn_or_error(
|
||||||
f"The selector '{str(raw_spec)}' does not match any nodes and will"
|
f"The selection criterion '{str(raw_spec)}' does not match" f" any nodes"
|
||||||
f" be ignored"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class NodeSelector(MethodManager):
|
class NodeSelector(MethodManager):
|
||||||
"""The node selector is aware of the graph and manifest,
|
"""The node selector is aware of the graph and manifest,"""
|
||||||
"""
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
graph: Graph,
|
graph: Graph,
|
||||||
@@ -45,13 +43,16 @@ class NodeSelector(MethodManager):
|
|||||||
# build a subgraph containing only non-empty, enabled nodes and enabled
|
# build a subgraph containing only non-empty, enabled nodes and enabled
|
||||||
# sources.
|
# sources.
|
||||||
graph_members = {
|
graph_members = {
|
||||||
unique_id for unique_id in self.full_graph.nodes()
|
unique_id
|
||||||
|
for unique_id in self.full_graph.nodes()
|
||||||
if self._is_graph_member(unique_id)
|
if self._is_graph_member(unique_id)
|
||||||
}
|
}
|
||||||
self.graph = self.full_graph.subgraph(graph_members)
|
self.graph = self.full_graph.subgraph(graph_members)
|
||||||
|
|
||||||
def select_included(
|
def select_included(
|
||||||
self, included_nodes: Set[UniqueId], spec: SelectionCriteria,
|
self,
|
||||||
|
included_nodes: Set[UniqueId],
|
||||||
|
spec: SelectionCriteria,
|
||||||
) -> Set[UniqueId]:
|
) -> Set[UniqueId]:
|
||||||
"""Select the explicitly included nodes, using the given spec. Return
|
"""Select the explicitly included nodes, using the given spec. Return
|
||||||
the selected set of unique IDs.
|
the selected set of unique IDs.
|
||||||
@@ -116,10 +117,7 @@ class NodeSelector(MethodManager):
|
|||||||
if isinstance(spec, SelectionCriteria):
|
if isinstance(spec, SelectionCriteria):
|
||||||
result = self.get_nodes_from_criteria(spec)
|
result = self.get_nodes_from_criteria(spec)
|
||||||
else:
|
else:
|
||||||
node_selections = [
|
node_selections = [self.select_nodes(component) for component in spec]
|
||||||
self.select_nodes(component)
|
|
||||||
for component in spec
|
|
||||||
]
|
|
||||||
result = spec.combined(node_selections)
|
result = spec.combined(node_selections)
|
||||||
if spec.expect_exists:
|
if spec.expect_exists:
|
||||||
alert_non_existence(spec.raw, result)
|
alert_non_existence(spec.raw, result)
|
||||||
@@ -149,18 +147,14 @@ class NodeSelector(MethodManager):
|
|||||||
elif unique_id in self.manifest.exposures:
|
elif unique_id in self.manifest.exposures:
|
||||||
node = self.manifest.exposures[unique_id]
|
node = self.manifest.exposures[unique_id]
|
||||||
else:
|
else:
|
||||||
raise InternalException(
|
raise InternalException(f"Node {unique_id} not found in the manifest!")
|
||||||
f'Node {unique_id} not found in the manifest!'
|
|
||||||
)
|
|
||||||
return self.node_is_match(node)
|
return self.node_is_match(node)
|
||||||
|
|
||||||
def filter_selection(self, selected: Set[UniqueId]) -> Set[UniqueId]:
|
def filter_selection(self, selected: Set[UniqueId]) -> Set[UniqueId]:
|
||||||
"""Return the subset of selected nodes that is a match for this
|
"""Return the subset of selected nodes that is a match for this
|
||||||
selector.
|
selector.
|
||||||
"""
|
"""
|
||||||
return {
|
return {unique_id for unique_id in selected if self._is_match(unique_id)}
|
||||||
unique_id for unique_id in selected if self._is_match(unique_id)
|
|
||||||
}
|
|
||||||
|
|
||||||
def expand_selection(self, selected: Set[UniqueId]) -> Set[UniqueId]:
|
def expand_selection(self, selected: Set[UniqueId]) -> Set[UniqueId]:
|
||||||
"""Perform selector-specific expansion."""
|
"""Perform selector-specific expansion."""
|
||||||
@@ -169,14 +163,14 @@ class NodeSelector(MethodManager):
|
|||||||
def get_selected(self, spec: SelectionSpec) -> Set[UniqueId]:
|
def get_selected(self, spec: SelectionSpec) -> Set[UniqueId]:
|
||||||
"""get_selected runs trhough the node selection process:
|
"""get_selected runs trhough the node selection process:
|
||||||
|
|
||||||
- node selection. Based on the include/exclude sets, the set
|
- node selection. Based on the include/exclude sets, the set
|
||||||
of matched unique IDs is returned
|
of matched unique IDs is returned
|
||||||
- expand the graph at each leaf node, before combination
|
- expand the graph at each leaf node, before combination
|
||||||
- selectors might override this. for example, this is where
|
- selectors might override this. for example, this is where
|
||||||
tests are added
|
tests are added
|
||||||
- filtering:
|
- filtering:
|
||||||
- selectors can filter the nodes after all of them have been
|
- selectors can filter the nodes after all of them have been
|
||||||
selected
|
selected
|
||||||
"""
|
"""
|
||||||
selected_nodes = self.select_nodes(spec)
|
selected_nodes = self.select_nodes(spec)
|
||||||
filtered_nodes = self.filter_selection(selected_nodes)
|
filtered_nodes = self.filter_selection(selected_nodes)
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ from itertools import chain
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Set, List, Dict, Iterator, Tuple, Any, Union, Type, Optional
|
from typing import Set, List, Dict, Iterator, Tuple, Any, Union, Type, Optional
|
||||||
|
|
||||||
from hologram.helpers import StrEnum
|
from dbt.dataclass_schema import StrEnum
|
||||||
|
|
||||||
from .graph import UniqueId
|
from .graph import UniqueId
|
||||||
|
|
||||||
@@ -31,28 +31,28 @@ from dbt.node_types import NodeType
|
|||||||
from dbt.ui import warning_tag
|
from dbt.ui import warning_tag
|
||||||
|
|
||||||
|
|
||||||
SELECTOR_GLOB = '*'
|
SELECTOR_GLOB = "*"
|
||||||
SELECTOR_DELIMITER = ':'
|
SELECTOR_DELIMITER = ":"
|
||||||
|
|
||||||
|
|
||||||
class MethodName(StrEnum):
|
class MethodName(StrEnum):
|
||||||
FQN = 'fqn'
|
FQN = "fqn"
|
||||||
Tag = 'tag'
|
Tag = "tag"
|
||||||
Source = 'source'
|
Source = "source"
|
||||||
Path = 'path'
|
Path = "path"
|
||||||
Package = 'package'
|
Package = "package"
|
||||||
Config = 'config'
|
Config = "config"
|
||||||
TestName = 'test_name'
|
TestName = "test_name"
|
||||||
TestType = 'test_type'
|
TestType = "test_type"
|
||||||
ResourceType = 'resource_type'
|
ResourceType = "resource_type"
|
||||||
State = 'state'
|
State = "state"
|
||||||
Exposure = 'exposure'
|
Exposure = "exposure"
|
||||||
|
|
||||||
|
|
||||||
def is_selected_node(real_node, node_selector):
|
def is_selected_node(real_node, node_selector):
|
||||||
for i, selector_part in enumerate(node_selector):
|
for i, selector_part in enumerate(node_selector):
|
||||||
|
|
||||||
is_last = (i == len(node_selector) - 1)
|
is_last = i == len(node_selector) - 1
|
||||||
|
|
||||||
# if we hit a GLOB, then this node is selected
|
# if we hit a GLOB, then this node is selected
|
||||||
if selector_part == SELECTOR_GLOB:
|
if selector_part == SELECTOR_GLOB:
|
||||||
@@ -83,15 +83,14 @@ class SelectorMethod(metaclass=abc.ABCMeta):
|
|||||||
self,
|
self,
|
||||||
manifest: Manifest,
|
manifest: Manifest,
|
||||||
previous_state: Optional[PreviousState],
|
previous_state: Optional[PreviousState],
|
||||||
arguments: List[str]
|
arguments: List[str],
|
||||||
):
|
):
|
||||||
self.manifest: Manifest = manifest
|
self.manifest: Manifest = manifest
|
||||||
self.previous_state = previous_state
|
self.previous_state = previous_state
|
||||||
self.arguments: List[str] = arguments
|
self.arguments: List[str] = arguments
|
||||||
|
|
||||||
def parsed_nodes(
|
def parsed_nodes(
|
||||||
self,
|
self, included_nodes: Set[UniqueId]
|
||||||
included_nodes: Set[UniqueId]
|
|
||||||
) -> Iterator[Tuple[UniqueId, ManifestNode]]:
|
) -> Iterator[Tuple[UniqueId, ManifestNode]]:
|
||||||
|
|
||||||
for key, node in self.manifest.nodes.items():
|
for key, node in self.manifest.nodes.items():
|
||||||
@@ -101,8 +100,7 @@ class SelectorMethod(metaclass=abc.ABCMeta):
|
|||||||
yield unique_id, node
|
yield unique_id, node
|
||||||
|
|
||||||
def source_nodes(
|
def source_nodes(
|
||||||
self,
|
self, included_nodes: Set[UniqueId]
|
||||||
included_nodes: Set[UniqueId]
|
|
||||||
) -> Iterator[Tuple[UniqueId, ParsedSourceDefinition]]:
|
) -> Iterator[Tuple[UniqueId, ParsedSourceDefinition]]:
|
||||||
|
|
||||||
for key, source in self.manifest.sources.items():
|
for key, source in self.manifest.sources.items():
|
||||||
@@ -112,8 +110,7 @@ class SelectorMethod(metaclass=abc.ABCMeta):
|
|||||||
yield unique_id, source
|
yield unique_id, source
|
||||||
|
|
||||||
def exposure_nodes(
|
def exposure_nodes(
|
||||||
self,
|
self, included_nodes: Set[UniqueId]
|
||||||
included_nodes: Set[UniqueId]
|
|
||||||
) -> Iterator[Tuple[UniqueId, ParsedExposure]]:
|
) -> Iterator[Tuple[UniqueId, ParsedExposure]]:
|
||||||
|
|
||||||
for key, exposure in self.manifest.exposures.items():
|
for key, exposure in self.manifest.exposures.items():
|
||||||
@@ -123,26 +120,28 @@ class SelectorMethod(metaclass=abc.ABCMeta):
|
|||||||
yield unique_id, exposure
|
yield unique_id, exposure
|
||||||
|
|
||||||
def all_nodes(
|
def all_nodes(
|
||||||
self,
|
self, included_nodes: Set[UniqueId]
|
||||||
included_nodes: Set[UniqueId]
|
|
||||||
) -> Iterator[Tuple[UniqueId, SelectorTarget]]:
|
) -> Iterator[Tuple[UniqueId, SelectorTarget]]:
|
||||||
yield from chain(self.parsed_nodes(included_nodes),
|
yield from chain(
|
||||||
self.source_nodes(included_nodes),
|
self.parsed_nodes(included_nodes),
|
||||||
self.exposure_nodes(included_nodes))
|
self.source_nodes(included_nodes),
|
||||||
|
self.exposure_nodes(included_nodes),
|
||||||
|
)
|
||||||
|
|
||||||
def configurable_nodes(
|
def configurable_nodes(
|
||||||
self,
|
self, included_nodes: Set[UniqueId]
|
||||||
included_nodes: Set[UniqueId]
|
|
||||||
) -> Iterator[Tuple[UniqueId, CompileResultNode]]:
|
) -> Iterator[Tuple[UniqueId, CompileResultNode]]:
|
||||||
yield from chain(self.parsed_nodes(included_nodes),
|
yield from chain(
|
||||||
self.source_nodes(included_nodes))
|
self.parsed_nodes(included_nodes), self.source_nodes(included_nodes)
|
||||||
|
)
|
||||||
|
|
||||||
def non_source_nodes(
|
def non_source_nodes(
|
||||||
self,
|
self,
|
||||||
included_nodes: Set[UniqueId],
|
included_nodes: Set[UniqueId],
|
||||||
) -> Iterator[Tuple[UniqueId, Union[ParsedExposure, ManifestNode]]]:
|
) -> Iterator[Tuple[UniqueId, Union[ParsedExposure, ManifestNode]]]:
|
||||||
yield from chain(self.parsed_nodes(included_nodes),
|
yield from chain(
|
||||||
self.exposure_nodes(included_nodes))
|
self.parsed_nodes(included_nodes), self.exposure_nodes(included_nodes)
|
||||||
|
)
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def search(
|
def search(
|
||||||
@@ -150,7 +149,7 @@ class SelectorMethod(metaclass=abc.ABCMeta):
|
|||||||
included_nodes: Set[UniqueId],
|
included_nodes: Set[UniqueId],
|
||||||
selector: str,
|
selector: str,
|
||||||
) -> Iterator[UniqueId]:
|
) -> Iterator[UniqueId]:
|
||||||
raise NotImplementedError('subclasses should implement this')
|
raise NotImplementedError("subclasses should implement this")
|
||||||
|
|
||||||
|
|
||||||
class QualifiedNameSelectorMethod(SelectorMethod):
|
class QualifiedNameSelectorMethod(SelectorMethod):
|
||||||
@@ -216,7 +215,7 @@ class SourceSelectorMethod(SelectorMethod):
|
|||||||
self, included_nodes: Set[UniqueId], selector: str
|
self, included_nodes: Set[UniqueId], selector: str
|
||||||
) -> Iterator[UniqueId]:
|
) -> Iterator[UniqueId]:
|
||||||
"""yields nodes from included are the specified source."""
|
"""yields nodes from included are the specified source."""
|
||||||
parts = selector.split('.')
|
parts = selector.split(".")
|
||||||
target_package = SELECTOR_GLOB
|
target_package = SELECTOR_GLOB
|
||||||
if len(parts) == 1:
|
if len(parts) == 1:
|
||||||
target_source, target_table = parts[0], None
|
target_source, target_table = parts[0], None
|
||||||
@@ -227,9 +226,9 @@ class SourceSelectorMethod(SelectorMethod):
|
|||||||
else: # len(parts) > 3 or len(parts) == 0
|
else: # len(parts) > 3 or len(parts) == 0
|
||||||
msg = (
|
msg = (
|
||||||
'Invalid source selector value "{}". Sources must be of the '
|
'Invalid source selector value "{}". Sources must be of the '
|
||||||
'form `${{source_name}}`, '
|
"form `${{source_name}}`, "
|
||||||
'`${{source_name}}.${{target_name}}`, or '
|
"`${{source_name}}.${{target_name}}`, or "
|
||||||
'`${{package_name}}.${{source_name}}.${{target_name}}'
|
"`${{package_name}}.${{source_name}}.${{target_name}}"
|
||||||
).format(selector)
|
).format(selector)
|
||||||
raise RuntimeException(msg)
|
raise RuntimeException(msg)
|
||||||
|
|
||||||
@@ -248,7 +247,7 @@ class ExposureSelectorMethod(SelectorMethod):
|
|||||||
def search(
|
def search(
|
||||||
self, included_nodes: Set[UniqueId], selector: str
|
self, included_nodes: Set[UniqueId], selector: str
|
||||||
) -> Iterator[UniqueId]:
|
) -> Iterator[UniqueId]:
|
||||||
parts = selector.split('.')
|
parts = selector.split(".")
|
||||||
target_package = SELECTOR_GLOB
|
target_package = SELECTOR_GLOB
|
||||||
if len(parts) == 1:
|
if len(parts) == 1:
|
||||||
target_name = parts[0]
|
target_name = parts[0]
|
||||||
@@ -257,8 +256,8 @@ class ExposureSelectorMethod(SelectorMethod):
|
|||||||
else:
|
else:
|
||||||
msg = (
|
msg = (
|
||||||
'Invalid exposure selector value "{}". Exposures must be of '
|
'Invalid exposure selector value "{}". Exposures must be of '
|
||||||
'the form ${{exposure_name}} or '
|
"the form ${{exposure_name}} or "
|
||||||
'${{exposure_package.exposure_name}}'
|
"${{exposure_package.exposure_name}}"
|
||||||
).format(selector)
|
).format(selector)
|
||||||
raise RuntimeException(msg)
|
raise RuntimeException(msg)
|
||||||
|
|
||||||
@@ -275,9 +274,7 @@ class PathSelectorMethod(SelectorMethod):
|
|||||||
def search(
|
def search(
|
||||||
self, included_nodes: Set[UniqueId], selector: str
|
self, included_nodes: Set[UniqueId], selector: str
|
||||||
) -> Iterator[UniqueId]:
|
) -> Iterator[UniqueId]:
|
||||||
"""Yields nodes from inclucded that match the given path.
|
"""Yields nodes from inclucded that match the given path."""
|
||||||
|
|
||||||
"""
|
|
||||||
# use '.' and not 'root' for easy comparison
|
# use '.' and not 'root' for easy comparison
|
||||||
root = Path.cwd()
|
root = Path.cwd()
|
||||||
paths = set(p.relative_to(root) for p in root.glob(selector))
|
paths = set(p.relative_to(root) for p in root.glob(selector))
|
||||||
@@ -336,7 +333,7 @@ class ConfigSelectorMethod(SelectorMethod):
|
|||||||
parts = self.arguments
|
parts = self.arguments
|
||||||
# special case: if the user wanted to compare test severity,
|
# special case: if the user wanted to compare test severity,
|
||||||
# make the comparison case-insensitive
|
# make the comparison case-insensitive
|
||||||
if parts == ['severity']:
|
if parts == ["severity"]:
|
||||||
selector = CaseInsensitive(selector)
|
selector = CaseInsensitive(selector)
|
||||||
|
|
||||||
# search sources is kind of useless now source configs only have
|
# search sources is kind of useless now source configs only have
|
||||||
@@ -382,14 +379,13 @@ class TestTypeSelectorMethod(SelectorMethod):
|
|||||||
self, included_nodes: Set[UniqueId], selector: str
|
self, included_nodes: Set[UniqueId], selector: str
|
||||||
) -> Iterator[UniqueId]:
|
) -> Iterator[UniqueId]:
|
||||||
search_types: Tuple[Type, ...]
|
search_types: Tuple[Type, ...]
|
||||||
if selector == 'schema':
|
if selector == "schema":
|
||||||
search_types = (ParsedSchemaTestNode, CompiledSchemaTestNode)
|
search_types = (ParsedSchemaTestNode, CompiledSchemaTestNode)
|
||||||
elif selector == 'data':
|
elif selector == "data":
|
||||||
search_types = (ParsedDataTestNode, CompiledDataTestNode)
|
search_types = (ParsedDataTestNode, CompiledDataTestNode)
|
||||||
else:
|
else:
|
||||||
raise RuntimeException(
|
raise RuntimeException(
|
||||||
f'Invalid test type selector {selector}: expected "data" or '
|
f'Invalid test type selector {selector}: expected "data" or ' '"schema"'
|
||||||
'"schema"'
|
|
||||||
)
|
)
|
||||||
|
|
||||||
for node, real_node in self.parsed_nodes(included_nodes):
|
for node, real_node in self.parsed_nodes(included_nodes):
|
||||||
@@ -405,25 +401,23 @@ class StateSelectorMethod(SelectorMethod):
|
|||||||
def _macros_modified(self) -> List[str]:
|
def _macros_modified(self) -> List[str]:
|
||||||
# we checked in the caller!
|
# we checked in the caller!
|
||||||
if self.previous_state is None or self.previous_state.manifest is None:
|
if self.previous_state is None or self.previous_state.manifest is None:
|
||||||
raise InternalException(
|
raise InternalException("No comparison manifest in _macros_modified")
|
||||||
'No comparison manifest in _macros_modified'
|
|
||||||
)
|
|
||||||
old_macros = self.previous_state.manifest.macros
|
old_macros = self.previous_state.manifest.macros
|
||||||
new_macros = self.manifest.macros
|
new_macros = self.manifest.macros
|
||||||
|
|
||||||
modified = []
|
modified = []
|
||||||
for uid, macro in new_macros.items():
|
for uid, macro in new_macros.items():
|
||||||
name = f'{macro.package_name}.{macro.name}'
|
name = f"{macro.package_name}.{macro.name}"
|
||||||
if uid in old_macros:
|
if uid in old_macros:
|
||||||
old_macro = old_macros[uid]
|
old_macro = old_macros[uid]
|
||||||
if macro.macro_sql != old_macro.macro_sql:
|
if macro.macro_sql != old_macro.macro_sql:
|
||||||
modified.append(f'{name} changed')
|
modified.append(f"{name} changed")
|
||||||
else:
|
else:
|
||||||
modified.append(f'{name} added')
|
modified.append(f"{name} added")
|
||||||
|
|
||||||
for uid, macro in old_macros.items():
|
for uid, macro in old_macros.items():
|
||||||
if uid not in new_macros:
|
if uid not in new_macros:
|
||||||
modified.append(f'{macro.package_name}.{macro.name} removed')
|
modified.append(f"{macro.package_name}.{macro.name} removed")
|
||||||
|
|
||||||
return modified[:3]
|
return modified[:3]
|
||||||
|
|
||||||
@@ -437,12 +431,14 @@ class StateSelectorMethod(SelectorMethod):
|
|||||||
if self.macros_were_modified is None:
|
if self.macros_were_modified is None:
|
||||||
self.macros_were_modified = self._macros_modified()
|
self.macros_were_modified = self._macros_modified()
|
||||||
if self.macros_were_modified:
|
if self.macros_were_modified:
|
||||||
log_str = ', '.join(self.macros_were_modified)
|
log_str = ", ".join(self.macros_were_modified)
|
||||||
logger.warning(warning_tag(
|
logger.warning(
|
||||||
f'During a state comparison, dbt detected a change in '
|
warning_tag(
|
||||||
f'macros. This will not be marked as a modification. Some '
|
f"During a state comparison, dbt detected a change in "
|
||||||
f'macros: {log_str}'
|
f"macros. This will not be marked as a modification. Some "
|
||||||
))
|
f"macros: {log_str}"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
return not new.same_contents(old) # type: ignore
|
return not new.same_contents(old) # type: ignore
|
||||||
|
|
||||||
@@ -458,12 +454,12 @@ class StateSelectorMethod(SelectorMethod):
|
|||||||
) -> Iterator[UniqueId]:
|
) -> Iterator[UniqueId]:
|
||||||
if self.previous_state is None or self.previous_state.manifest is None:
|
if self.previous_state is None or self.previous_state.manifest is None:
|
||||||
raise RuntimeException(
|
raise RuntimeException(
|
||||||
'Got a state selector method, but no comparison manifest'
|
"Got a state selector method, but no comparison manifest"
|
||||||
)
|
)
|
||||||
|
|
||||||
state_checks = {
|
state_checks = {
|
||||||
'modified': self.check_modified,
|
"modified": self.check_modified,
|
||||||
'new': self.check_new,
|
"new": self.check_new,
|
||||||
}
|
}
|
||||||
if selector in state_checks:
|
if selector in state_checks:
|
||||||
checker = state_checks[selector]
|
checker = state_checks[selector]
|
||||||
@@ -517,7 +513,7 @@ class MethodManager:
         if method not in self.SELECTOR_METHODS:
             raise InternalException(
                 f'Method name "{method}" is a valid node selection '
-                f'method name, but it is not handled'
+                f"method name, but it is not handled"
             )
         cls: Type[SelectorMethod] = self.SELECTOR_METHODS[method]
         return cls(self.manifest, self.previous_state, method_arguments)
|
|||||||
@@ -3,23 +3,21 @@ import re
 from abc import ABCMeta, abstractmethod
 from dataclasses import dataclass

-from typing import (
-    Set, Iterator, List, Optional, Dict, Union, Any, Iterable, Tuple
-)
+from typing import Set, Iterator, List, Optional, Dict, Union, Any, Iterable, Tuple
 from .graph import UniqueId
 from .selector_methods import MethodName
 from dbt.exceptions import RuntimeException, InvalidSelectorException


 RAW_SELECTOR_PATTERN = re.compile(
-    r'\A'
-    r'(?P<childrens_parents>(\@))?'
-    r'(?P<parents>((?P<parents_depth>(\d*))\+))?'
-    r'((?P<method>([\w.]+)):)?(?P<value>(.*?))'
-    r'(?P<children>(\+(?P<children_depth>(\d*))))?'
-    r'\Z'
+    r"\A"
+    r"(?P<childrens_parents>(\@))?"
+    r"(?P<parents>((?P<parents_depth>(\d*))\+))?"
+    r"((?P<method>([\w.]+)):)?(?P<value>(.*?))"
+    r"(?P<children>(\+(?P<children_depth>(\d*))))?"
+    r"\Z"
 )
-SELECTOR_METHOD_SEPARATOR = '.'
+SELECTOR_METHOD_SEPARATOR = "."


 def _probably_path(value: str):
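Since the selector grammar is easiest to see by example, here is a small standalone sketch that exercises the same pattern as `RAW_SELECTOR_PATTERN` above. The pattern string is copied from the diff; the sample specs are illustrative only.

```python
import re

# Same pattern as RAW_SELECTOR_PATTERN above.
RAW_SELECTOR_PATTERN = re.compile(
    r"\A"
    r"(?P<childrens_parents>(\@))?"
    r"(?P<parents>((?P<parents_depth>(\d*))\+))?"
    r"((?P<method>([\w.]+)):)?(?P<value>(.*?))"
    r"(?P<children>(\+(?P<children_depth>(\d*))))?"
    r"\Z"
)

groups = RAW_SELECTOR_PATTERN.match("3+config.materialized:table+2").groupdict()
assert groups["parents_depth"] == "3"             # "3+" prefix: up to 3 levels of ancestors
assert groups["method"] == "config.materialized"  # method name plus dotted arguments
assert groups["value"] == "table"
assert groups["children_depth"] == "2"            # "+2" suffix: up to 2 levels of descendants

# "@" is the childrens_parents operator; a bare value carries no method and
# falls back to cls.default_method (see default_method further down).
groups = RAW_SELECTOR_PATTERN.match("@my_model").groupdict()
assert groups["childrens_parents"] == "@"
assert groups["method"] is None and groups["value"] == "my_model"
```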
@@ -43,15 +41,15 @@ def _match_to_int(match: Dict[str, str], key: str) -> Optional[int]:
|
|||||||
return int(raw)
|
return int(raw)
|
||||||
except ValueError as exc:
|
except ValueError as exc:
|
||||||
raise RuntimeException(
|
raise RuntimeException(
|
||||||
f'Invalid node spec - could not handle parent depth {raw}'
|
f"Invalid node spec - could not handle parent depth {raw}"
|
||||||
) from exc
|
) from exc
|
||||||
|
|
||||||
|
|
||||||
SelectionSpec = Union[
|
SelectionSpec = Union[
|
||||||
'SelectionCriteria',
|
"SelectionCriteria",
|
||||||
'SelectionIntersection',
|
"SelectionIntersection",
|
||||||
'SelectionDifference',
|
"SelectionDifference",
|
||||||
'SelectionUnion',
|
"SelectionUnion",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@@ -71,7 +69,7 @@ class SelectionCriteria:
|
|||||||
if self.children and self.childrens_parents:
|
if self.children and self.childrens_parents:
|
||||||
raise RuntimeException(
|
raise RuntimeException(
|
||||||
f'Invalid node spec {self.raw} - "@" prefix and "+" suffix '
|
f'Invalid node spec {self.raw} - "@" prefix and "+" suffix '
|
||||||
'are incompatible'
|
"are incompatible"
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -82,12 +80,10 @@ class SelectionCriteria:
|
|||||||
return MethodName.FQN
|
return MethodName.FQN
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def parse_method(
|
def parse_method(cls, groupdict: Dict[str, Any]) -> Tuple[MethodName, List[str]]:
|
||||||
cls, groupdict: Dict[str, Any]
|
raw_method = groupdict.get("method")
|
||||||
) -> Tuple[MethodName, List[str]]:
|
|
||||||
raw_method = groupdict.get('method')
|
|
||||||
if raw_method is None:
|
if raw_method is None:
|
||||||
return cls.default_method(groupdict['value']), []
|
return cls.default_method(groupdict["value"]), []
|
||||||
|
|
||||||
method_parts: List[str] = raw_method.split(SELECTOR_METHOD_SEPARATOR)
|
method_parts: List[str] = raw_method.split(SELECTOR_METHOD_SEPARATOR)
|
||||||
try:
|
try:
|
||||||
@@ -102,24 +98,24 @@ class SelectionCriteria:
|
|||||||
return method_name, method_arguments
|
return method_name, method_arguments
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_dict(cls, raw: Any, dct: Dict[str, Any]) -> 'SelectionCriteria':
|
def selection_criteria_from_dict(
|
||||||
if 'value' not in dct:
|
cls, raw: Any, dct: Dict[str, Any]
|
||||||
raise RuntimeException(
|
) -> "SelectionCriteria":
|
||||||
f'Invalid node spec "{raw}" - no search value!'
|
if "value" not in dct:
|
||||||
)
|
raise RuntimeException(f'Invalid node spec "{raw}" - no search value!')
|
||||||
method_name, method_arguments = cls.parse_method(dct)
|
method_name, method_arguments = cls.parse_method(dct)
|
||||||
|
|
||||||
parents_depth = _match_to_int(dct, 'parents_depth')
|
parents_depth = _match_to_int(dct, "parents_depth")
|
||||||
children_depth = _match_to_int(dct, 'children_depth')
|
children_depth = _match_to_int(dct, "children_depth")
|
||||||
return cls(
|
return cls(
|
||||||
raw=raw,
|
raw=raw,
|
||||||
method=method_name,
|
method=method_name,
|
||||||
method_arguments=method_arguments,
|
method_arguments=method_arguments,
|
||||||
value=dct['value'],
|
value=dct["value"],
|
||||||
childrens_parents=bool(dct.get('childrens_parents')),
|
childrens_parents=bool(dct.get("childrens_parents")),
|
||||||
parents=bool(dct.get('parents')),
|
parents=bool(dct.get("parents")),
|
||||||
parents_depth=parents_depth,
|
parents_depth=parents_depth,
|
||||||
children=bool(dct.get('children')),
|
children=bool(dct.get("children")),
|
||||||
children_depth=children_depth,
|
children_depth=children_depth,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -127,30 +123,30 @@ class SelectionCriteria:
|
|||||||
def dict_from_single_spec(cls, raw: str):
|
def dict_from_single_spec(cls, raw: str):
|
||||||
result = RAW_SELECTOR_PATTERN.match(raw)
|
result = RAW_SELECTOR_PATTERN.match(raw)
|
||||||
if result is None:
|
if result is None:
|
||||||
return {'error': 'Invalid selector spec'}
|
return {"error": "Invalid selector spec"}
|
||||||
dct: Dict[str, Any] = result.groupdict()
|
dct: Dict[str, Any] = result.groupdict()
|
||||||
method_name, method_arguments = cls.parse_method(dct)
|
method_name, method_arguments = cls.parse_method(dct)
|
||||||
meth_name = str(method_name)
|
meth_name = str(method_name)
|
||||||
if method_arguments:
|
if method_arguments:
|
||||||
meth_name = meth_name + '.' + '.'.join(method_arguments)
|
meth_name = meth_name + "." + ".".join(method_arguments)
|
||||||
dct['method'] = meth_name
|
dct["method"] = meth_name
|
||||||
dct = {k: v for k, v in dct.items() if (v is not None and v != '')}
|
dct = {k: v for k, v in dct.items() if (v is not None and v != "")}
|
||||||
if 'childrens_parents' in dct:
|
if "childrens_parents" in dct:
|
||||||
dct['childrens_parents'] = bool(dct.get('childrens_parents'))
|
dct["childrens_parents"] = bool(dct.get("childrens_parents"))
|
||||||
if 'parents' in dct:
|
if "parents" in dct:
|
||||||
dct['parents'] = bool(dct.get('parents'))
|
dct["parents"] = bool(dct.get("parents"))
|
||||||
if 'children' in dct:
|
if "children" in dct:
|
||||||
dct['children'] = bool(dct.get('children'))
|
dct["children"] = bool(dct.get("children"))
|
||||||
return dct
|
return dct
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_single_spec(cls, raw: str) -> 'SelectionCriteria':
|
def from_single_spec(cls, raw: str) -> "SelectionCriteria":
|
||||||
result = RAW_SELECTOR_PATTERN.match(raw)
|
result = RAW_SELECTOR_PATTERN.match(raw)
|
||||||
if result is None:
|
if result is None:
|
||||||
# bad spec!
|
# bad spec!
|
||||||
raise RuntimeException(f'Invalid selector spec "{raw}"')
|
raise RuntimeException(f'Invalid selector spec "{raw}"')
|
||||||
|
|
||||||
return cls.from_dict(raw, result.groupdict())
|
return cls.selection_criteria_from_dict(raw, result.groupdict())
|
||||||
|
|
||||||
|
|
||||||
class BaseSelectionGroup(Iterable[SelectionSpec], metaclass=ABCMeta):
|
class BaseSelectionGroup(Iterable[SelectionSpec], metaclass=ABCMeta):
|
||||||
@@ -173,9 +169,7 @@ class BaseSelectionGroup(Iterable[SelectionSpec], metaclass=ABCMeta):
|
|||||||
self,
|
self,
|
||||||
selections: List[Set[UniqueId]],
|
selections: List[Set[UniqueId]],
|
||||||
) -> Set[UniqueId]:
|
) -> Set[UniqueId]:
|
||||||
raise NotImplementedError(
|
raise NotImplementedError("_combine_selections not implemented!")
|
||||||
'_combine_selections not implemented!'
|
|
||||||
)
|
|
||||||
|
|
||||||
def combined(self, selections: List[Set[UniqueId]]) -> Set[UniqueId]:
|
def combined(self, selections: List[Set[UniqueId]]) -> Set[UniqueId]:
|
||||||
if not selections:
|
if not selections:
|
||||||
|
|||||||
@@ -2,20 +2,35 @@
 from dataclasses import dataclass
 from datetime import timedelta
 from pathlib import Path
-from typing import NewType, Tuple, AbstractSet
+from typing import Tuple, AbstractSet, Union

-from hologram import (
-    FieldEncoder, JsonSchemaMixin, JsonDict, ValidationError
-)
-from hologram.helpers import StrEnum
+from dbt.dataclass_schema import (
+    dbtClassMixin,
+    ValidationError,
+    StrEnum,
+)
+from hologram import FieldEncoder, JsonDict
+from mashumaro.types import SerializableType

-Port = NewType('Port', int)
+
+class Port(int, SerializableType):
+    @classmethod
+    def _deserialize(cls, value: Union[int, str]) -> "Port":
+        try:
+            value = int(value)
+        except ValueError:
+            raise ValidationError(f"Cannot encode {value} into port number")
+
+        return Port(value)
+
+    def _serialize(self) -> int:
+        return self


 class PortEncoder(FieldEncoder):
     @property
     def json_schema(self):
-        return {'type': 'integer', 'minimum': 0, 'maximum': 65535}
+        return {"type": "integer", "minimum": 0, "maximum": 65535}


 class TimeDeltaFieldEncoder(FieldEncoder[timedelta]):
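Swapping the `NewType` alias for a mashumaro `SerializableType` means port values are coerced during (de)serialization instead of only being type-checked. A minimal sketch of the round trip, assuming the `Port` class above is in scope and that mashumaro's `DataClassDictMixin` drives the `_serialize`/`_deserialize` hooks; the `DummyCredentials` dataclass is illustrative and not part of dbt.

```python
from dataclasses import dataclass

from mashumaro import DataClassDictMixin

# The hooks defined on Port above do the coercion.
assert Port._deserialize("5432") == 5432
assert Port(5432)._serialize() == 5432


@dataclass
class DummyCredentials(DataClassDictMixin):  # illustrative only
    host: str
    port: Port


# mashumaro routes the "port" field through Port._deserialize / Port._serialize.
creds = DummyCredentials.from_dict({"host": "localhost", "port": "5432"})
assert isinstance(creds.port, int) and creds.port == 5432
assert creds.to_dict() == {"host": "localhost", "port": 5432}
```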
@@ -31,12 +46,12 @@ class TimeDeltaFieldEncoder(FieldEncoder[timedelta]):
|
|||||||
return timedelta(seconds=value)
|
return timedelta(seconds=value)
|
||||||
except TypeError:
|
except TypeError:
|
||||||
raise ValidationError(
|
raise ValidationError(
|
||||||
'cannot encode {} into timedelta'.format(value)
|
"cannot encode {} into timedelta".format(value)
|
||||||
) from None
|
) from None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def json_schema(self) -> JsonDict:
|
def json_schema(self) -> JsonDict:
|
||||||
return {'type': 'number'}
|
return {"type": "number"}
|
||||||
|
|
||||||
|
|
||||||
class PathEncoder(FieldEncoder):
|
class PathEncoder(FieldEncoder):
|
||||||
@@ -50,32 +65,35 @@ class PathEncoder(FieldEncoder):
|
|||||||
return Path(value)
|
return Path(value)
|
||||||
except TypeError:
|
except TypeError:
|
||||||
raise ValidationError(
|
raise ValidationError(
|
||||||
'cannot encode {} into timedelta'.format(value)
|
"cannot encode {} into timedelta".format(value)
|
||||||
) from None
|
) from None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def json_schema(self) -> JsonDict:
|
def json_schema(self) -> JsonDict:
|
||||||
return {'type': 'string'}
|
return {"type": "string"}
|
||||||
|
|
||||||
|
|
||||||
class NVEnum(StrEnum):
|
class NVEnum(StrEnum):
|
||||||
novalue = 'novalue'
|
novalue = "novalue"
|
||||||
|
|
||||||
def __eq__(self, other):
|
def __eq__(self, other):
|
||||||
return isinstance(other, NVEnum)
|
return isinstance(other, NVEnum)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class NoValue(JsonSchemaMixin):
|
class NoValue(dbtClassMixin):
|
||||||
"""Sometimes, you want a way to say none that isn't None"""
|
"""Sometimes, you want a way to say none that isn't None"""
|
||||||
|
|
||||||
novalue: NVEnum = NVEnum.novalue
|
novalue: NVEnum = NVEnum.novalue
|
||||||
|
|
||||||
|
|
||||||
JsonSchemaMixin.register_field_encoders({
|
dbtClassMixin.register_field_encoders(
|
||||||
Port: PortEncoder(),
|
{
|
||||||
timedelta: TimeDeltaFieldEncoder(),
|
Port: PortEncoder(),
|
||||||
Path: PathEncoder(),
|
timedelta: TimeDeltaFieldEncoder(),
|
||||||
})
|
Path: PathEncoder(),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
FQNPath = Tuple[str, ...]
|
FQNPath = Tuple[str, ...]
|
||||||
|
|||||||
@@ -1,12 +1,12 @@
-from hologram.helpers import StrEnum
+from dbt.dataclass_schema import StrEnum
 import json

 from typing import Union, Dict, Any


 class ModelHookType(StrEnum):
-    PreHook = 'pre-hook'
-    PostHook = 'post-hook'
+    PreHook = "pre-hook"
+    PostHook = "post-hook"


 def get_hook_dict(source: Union[str, Dict[str, Any]]) -> Dict[str, Any]:
@@ -18,4 +18,4 @@ def get_hook_dict(source: Union[str, Dict[str, Any]]) -> Dict[str, Any]:
     try:
         return json.loads(source)
     except ValueError:
-        return {'sql': source}
+        return {"sql": source}
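The fallback in `get_hook_dict` is what allows a hook to be written either as a JSON object or as a bare SQL string. A short illustration, assuming an installed dbt provides `dbt.hooks`; the grant statement is just a placeholder.

```python
from dbt import hooks

# A JSON string decodes into its dict form...
hook = hooks.get_hook_dict('{"sql": "grant select on my_table to reporter", "transaction": false}')
assert hook == {"sql": "grant select on my_table to reporter", "transaction": False}

# ...while a plain SQL string falls back to {"sql": source}.
assert hooks.get_hook_dict("grant select on my_table to reporter") == {
    "sql": "grant select on my_table to reporter"
}
```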
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
import os
|
import os
|
||||||
|
|
||||||
PACKAGE_PATH = os.path.dirname(__file__)
|
PACKAGE_PATH = os.path.dirname(__file__)
|
||||||
PROJECT_NAME = 'dbt'
|
PROJECT_NAME = "dbt"
|
||||||
|
|
||||||
DOCS_INDEX_FILE_PATH = os.path.normpath(
|
DOCS_INDEX_FILE_PATH = os.path.normpath(os.path.join(PACKAGE_PATH, "..", "index.html"))
|
||||||
os.path.join(PACKAGE_PATH, '..', "index.html"))
|
|
||||||
|
|||||||
@@ -287,4 +287,3 @@
|
|||||||
{% macro set_sql_header(config) -%}
|
{% macro set_sql_header(config) -%}
|
||||||
{{ config.set('sql_header', caller()) }}
|
{{ config.set('sql_header', caller()) }}
|
||||||
{%- endmacro %}
|
{%- endmacro %}
|
||||||
|
|
||||||
|
|||||||
@@ -70,7 +70,7 @@
|
|||||||
|
|
||||||
deletes_source_data as (
|
deletes_source_data as (
|
||||||
|
|
||||||
select
|
select
|
||||||
*,
|
*,
|
||||||
{{ strategy.unique_key }} as dbt_unique_key
|
{{ strategy.unique_key }} as dbt_unique_key
|
||||||
from snapshot_query
|
from snapshot_query
|
||||||
@@ -113,7 +113,7 @@
|
|||||||
,
|
,
|
||||||
|
|
||||||
deletes as (
|
deletes as (
|
||||||
|
|
||||||
select
|
select
|
||||||
'delete' as dbt_change_type,
|
'delete' as dbt_change_type,
|
||||||
source_data.*,
|
source_data.*,
|
||||||
@@ -121,7 +121,7 @@
|
|||||||
{{ snapshot_get_time() }} as dbt_updated_at,
|
{{ snapshot_get_time() }} as dbt_updated_at,
|
||||||
{{ snapshot_get_time() }} as dbt_valid_to,
|
{{ snapshot_get_time() }} as dbt_valid_to,
|
||||||
snapshotted_data.dbt_scd_id
|
snapshotted_data.dbt_scd_id
|
||||||
|
|
||||||
from snapshotted_data
|
from snapshotted_data
|
||||||
left join deletes_source_data as source_data on snapshotted_data.dbt_unique_key = source_data.dbt_unique_key
|
left join deletes_source_data as source_data on snapshotted_data.dbt_unique_key = source_data.dbt_unique_key
|
||||||
where source_data.dbt_unique_key is null
|
where source_data.dbt_unique_key is null
|
||||||
|
|||||||
@@ -23,5 +23,3 @@
|
|||||||
values ({{ insert_cols_csv }})
|
values ({{ insert_cols_csv }})
|
||||||
;
|
;
|
||||||
{% endmacro %}
|
{% endmacro %}
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -134,7 +134,7 @@
|
|||||||
{% set check_cols_config = config['check_cols'] %}
|
{% set check_cols_config = config['check_cols'] %}
|
||||||
{% set primary_key = config['unique_key'] %}
|
{% set primary_key = config['unique_key'] %}
|
||||||
{% set invalidate_hard_deletes = config.get('invalidate_hard_deletes', false) %}
|
{% set invalidate_hard_deletes = config.get('invalidate_hard_deletes', false) %}
|
||||||
|
|
||||||
{% set select_current_time -%}
|
{% set select_current_time -%}
|
||||||
select {{ snapshot_get_time() }} as snapshot_start
|
select {{ snapshot_get_time() }} as snapshot_start
|
||||||
{%- endset %}
|
{%- endset %}
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
ProfileConfigDocs = 'https://docs.getdbt.com/docs/configure-your-profile'
|
ProfileConfigDocs = "https://docs.getdbt.com/docs/configure-your-profile"
|
||||||
SnowflakeQuotingDocs = 'https://docs.getdbt.com/v0.10/docs/configuring-quoting'
|
SnowflakeQuotingDocs = "https://docs.getdbt.com/v0.10/docs/configuring-quoting"
|
||||||
IncrementalDocs = 'https://docs.getdbt.com/docs/configuring-incremental-models'
|
IncrementalDocs = "https://docs.getdbt.com/docs/configuring-incremental-models"
|
||||||
BigQueryNewPartitionBy = 'https://docs.getdbt.com/docs/upgrading-to-0-16-0'
|
BigQueryNewPartitionBy = "https://docs.getdbt.com/docs/upgrading-to-0-16-0"
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ from typing import Optional, List, ContextManager, Callable, Dict, Any, Set
|
|||||||
|
|
||||||
import colorama
|
import colorama
|
||||||
import logbook
|
import logbook
|
||||||
from hologram import JsonSchemaMixin
|
from dbt.dataclass_schema import dbtClassMixin
|
||||||
|
|
||||||
# Colorama needs some help on windows because we're using logger.info
|
# Colorama needs some help on windows because we're using logger.info
|
||||||
# intead of print(). If the Windows env doesn't have a TERM var set,
|
# intead of print(). If the Windows env doesn't have a TERM var set,
|
||||||
@@ -26,30 +26,29 @@ colorama_wrap = True
|
|||||||
colorama.init(wrap=colorama_wrap)
|
colorama.init(wrap=colorama_wrap)
|
||||||
|
|
||||||
|
|
||||||
if sys.platform == 'win32' and not os.getenv('TERM'):
|
if sys.platform == "win32" and not os.getenv("TERM"):
|
||||||
colorama_wrap = False
|
colorama_wrap = False
|
||||||
colorama_stdout = colorama.AnsiToWin32(sys.stdout).stream
|
colorama_stdout = colorama.AnsiToWin32(sys.stdout).stream
|
||||||
|
|
||||||
elif sys.platform == 'win32':
|
elif sys.platform == "win32":
|
||||||
colorama_wrap = False
|
colorama_wrap = False
|
||||||
|
|
||||||
colorama.init(wrap=colorama_wrap)
|
colorama.init(wrap=colorama_wrap)
|
||||||
|
|
||||||
|
|
||||||
STDOUT_LOG_FORMAT = '{record.message}'
|
STDOUT_LOG_FORMAT = "{record.message}"
|
||||||
DEBUG_LOG_FORMAT = (
|
DEBUG_LOG_FORMAT = (
|
||||||
'{record.time:%Y-%m-%d %H:%M:%S.%f%z} '
|
"{record.time:%Y-%m-%d %H:%M:%S.%f%z} "
|
||||||
'({record.thread_name}): '
|
"({record.thread_name}): "
|
||||||
'{record.message}'
|
"{record.message}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
ExceptionInformation = str
|
ExceptionInformation = str
|
||||||
Extras = Dict[str, Any]
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class LogMessage(JsonSchemaMixin):
|
class LogMessage(dbtClassMixin):
|
||||||
timestamp: datetime
|
timestamp: datetime
|
||||||
message: str
|
message: str
|
||||||
channel: str
|
channel: str
|
||||||
@@ -57,7 +56,7 @@ class LogMessage(JsonSchemaMixin):
|
|||||||
levelname: str
|
levelname: str
|
||||||
thread_name: str
|
thread_name: str
|
||||||
process: int
|
process: int
|
||||||
extra: Optional[Extras] = None
|
extra: Optional[Dict[str, Any]] = None
|
||||||
exc_info: Optional[ExceptionInformation] = None
|
exc_info: Optional[ExceptionInformation] = None
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -95,8 +94,10 @@ class JsonFormatter(LogMessageFormatter):
|
|||||||
"""Return a the record converted to LogMessage's JSON form"""
|
"""Return a the record converted to LogMessage's JSON form"""
|
||||||
# utils imports exceptions which imports logger...
|
# utils imports exceptions which imports logger...
|
||||||
import dbt.utils
|
import dbt.utils
|
||||||
|
|
||||||
log_message = super().__call__(record, handler)
|
log_message = super().__call__(record, handler)
|
||||||
return json.dumps(log_message.to_dict(), cls=dbt.utils.JSONEncoder)
|
dct = log_message.to_dict(omit_none=True)
|
||||||
|
return json.dumps(dct, cls=dbt.utils.JSONEncoder)
|
||||||
|
|
||||||
|
|
||||||
class FormatterMixin:
|
class FormatterMixin:
|
||||||
@@ -117,9 +118,7 @@ class FormatterMixin:
|
|||||||
self.format_string = self._text_format_string
|
self.format_string = self._text_format_string
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
raise NotImplementedError(
|
raise NotImplementedError("reset() not implemented in FormatterMixin subclass")
|
||||||
'reset() not implemented in FormatterMixin subclass'
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class OutputHandler(logbook.StreamHandler, FormatterMixin):
|
class OutputHandler(logbook.StreamHandler, FormatterMixin):
|
||||||
@@ -128,6 +127,7 @@ class OutputHandler(logbook.StreamHandler, FormatterMixin):
|
|||||||
The `format_string` parameter only changes the default text output, not
|
The `format_string` parameter only changes the default text output, not
|
||||||
debug mode or json.
|
debug mode or json.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
stream,
|
stream,
|
||||||
@@ -163,9 +163,9 @@ class OutputHandler(logbook.StreamHandler, FormatterMixin):
|
|||||||
if record.level < self.level:
|
if record.level < self.level:
|
||||||
return False
|
return False
|
||||||
text_mode = self.formatter_class is logbook.StringFormatter
|
text_mode = self.formatter_class is logbook.StringFormatter
|
||||||
if text_mode and record.extra.get('json_only', False):
|
if text_mode and record.extra.get("json_only", False):
|
||||||
return False
|
return False
|
||||||
elif not text_mode and record.extra.get('text_only', False):
|
elif not text_mode and record.extra.get("text_only", False):
|
||||||
return False
|
return False
|
||||||
else:
|
else:
|
||||||
return True
|
return True
|
||||||
@@ -176,7 +176,7 @@ def _redirect_std_logging():
|
|||||||
|
|
||||||
|
|
||||||
def _root_channel(record: logbook.LogRecord) -> str:
|
def _root_channel(record: logbook.LogRecord) -> str:
|
||||||
return record.channel.split('.')[0]
|
return record.channel.split(".")[0]
|
||||||
|
|
||||||
|
|
||||||
class Relevel(logbook.Processor):
|
class Relevel(logbook.Processor):
|
||||||
@@ -194,7 +194,7 @@ class Relevel(logbook.Processor):
|
|||||||
def process(self, record):
|
def process(self, record):
|
||||||
if _root_channel(record) in self.allowed:
|
if _root_channel(record) in self.allowed:
|
||||||
return
|
return
|
||||||
record.extra['old_level'] = record.level
|
record.extra["old_level"] = record.level
|
||||||
# suppress logs at/below our min level by lowering them to NOTSET
|
# suppress logs at/below our min level by lowering them to NOTSET
|
||||||
if record.level < self.min_level:
|
if record.level < self.min_level:
|
||||||
record.level = logbook.NOTSET
|
record.level = logbook.NOTSET
|
||||||
@@ -206,22 +206,22 @@ class Relevel(logbook.Processor):
|
|||||||
|
|
||||||
class JsonOnly(logbook.Processor):
|
class JsonOnly(logbook.Processor):
|
||||||
def process(self, record):
|
def process(self, record):
|
||||||
record.extra['json_only'] = True
|
record.extra["json_only"] = True
|
||||||
|
|
||||||
|
|
||||||
class TextOnly(logbook.Processor):
|
class TextOnly(logbook.Processor):
|
||||||
def process(self, record):
|
def process(self, record):
|
||||||
record.extra['text_only'] = True
|
record.extra["text_only"] = True
|
||||||
|
|
||||||
|
|
||||||
class TimingProcessor(logbook.Processor):
|
class TimingProcessor(logbook.Processor):
|
||||||
def __init__(self, timing_info: Optional[JsonSchemaMixin] = None):
|
def __init__(self, timing_info: Optional[dbtClassMixin] = None):
|
||||||
self.timing_info = timing_info
|
self.timing_info = timing_info
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
def process(self, record):
|
def process(self, record):
|
||||||
if self.timing_info is not None:
|
if self.timing_info is not None:
|
||||||
record.extra['timing_info'] = self.timing_info.to_dict()
|
record.extra["timing_info"] = self.timing_info.to_dict(omit_none=True)
|
||||||
|
|
||||||
|
|
||||||
class DbtProcessState(logbook.Processor):
|
class DbtProcessState(logbook.Processor):
|
||||||
@@ -231,11 +231,10 @@ class DbtProcessState(logbook.Processor):
|
|||||||
|
|
||||||
def process(self, record):
|
def process(self, record):
|
||||||
overwrite = (
|
overwrite = (
|
||||||
'run_state' not in record.extra or
|
"run_state" not in record.extra or record.extra["run_state"] == "internal"
|
||||||
record.extra['run_state'] == 'internal'
|
|
||||||
)
|
)
|
||||||
if overwrite:
|
if overwrite:
|
||||||
record.extra['run_state'] = self.value
|
record.extra["run_state"] = self.value
|
||||||
|
|
||||||
|
|
||||||
class DbtModelState(logbook.Processor):
|
class DbtModelState(logbook.Processor):
|
||||||
@@ -249,7 +248,7 @@ class DbtModelState(logbook.Processor):
|
|||||||
|
|
||||||
class DbtStatusMessage(logbook.Processor):
|
class DbtStatusMessage(logbook.Processor):
|
||||||
def process(self, record):
|
def process(self, record):
|
||||||
record.extra['is_status_message'] = True
|
record.extra["is_status_message"] = True
|
||||||
|
|
||||||
|
|
||||||
class UniqueID(logbook.Processor):
|
class UniqueID(logbook.Processor):
|
||||||
@@ -258,7 +257,7 @@ class UniqueID(logbook.Processor):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
def process(self, record):
|
def process(self, record):
|
||||||
record.extra['unique_id'] = self.unique_id
|
record.extra["unique_id"] = self.unique_id
|
||||||
|
|
||||||
|
|
||||||
class NodeCount(logbook.Processor):
|
class NodeCount(logbook.Processor):
|
||||||
@@ -267,7 +266,7 @@ class NodeCount(logbook.Processor):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
def process(self, record):
|
def process(self, record):
|
||||||
record.extra['node_count'] = self.node_count
|
record.extra["node_count"] = self.node_count
|
||||||
|
|
||||||
|
|
||||||
class NodeMetadata(logbook.Processor):
|
class NodeMetadata(logbook.Processor):
|
||||||
@@ -287,26 +286,26 @@ class NodeMetadata(logbook.Processor):
|
|||||||
|
|
||||||
def process(self, record):
|
def process(self, record):
|
||||||
self.process_keys(record)
|
self.process_keys(record)
|
||||||
record.extra['node_index'] = self.index
|
record.extra["node_index"] = self.index
|
||||||
|
|
||||||
|
|
||||||
class ModelMetadata(NodeMetadata):
|
class ModelMetadata(NodeMetadata):
|
||||||
def mapping_keys(self):
|
def mapping_keys(self):
|
||||||
return [
|
return [
|
||||||
('alias', 'node_alias'),
|
("alias", "node_alias"),
|
||||||
('schema', 'node_schema'),
|
("schema", "node_schema"),
|
||||||
('database', 'node_database'),
|
("database", "node_database"),
|
||||||
('original_file_path', 'node_path'),
|
("original_file_path", "node_path"),
|
||||||
('name', 'node_name'),
|
("name", "node_name"),
|
||||||
('resource_type', 'resource_type'),
|
("resource_type", "resource_type"),
|
||||||
('depends_on_nodes', 'depends_on'),
|
("depends_on_nodes", "depends_on"),
|
||||||
]
|
]
|
||||||
|
|
||||||
def process_config(self, record):
|
def process_config(self, record):
|
||||||
if hasattr(self.node, 'config'):
|
if hasattr(self.node, "config"):
|
||||||
materialized = getattr(self.node.config, 'materialized', None)
|
materialized = getattr(self.node.config, "materialized", None)
|
||||||
if materialized is not None:
|
if materialized is not None:
|
||||||
record.extra['node_materialized'] = materialized
|
record.extra["node_materialized"] = materialized
|
||||||
|
|
||||||
def process(self, record):
|
def process(self, record):
|
||||||
super().process(record)
|
super().process(record)
|
||||||
@@ -316,8 +315,8 @@ class ModelMetadata(NodeMetadata):
|
|||||||
class HookMetadata(NodeMetadata):
|
class HookMetadata(NodeMetadata):
|
||||||
def mapping_keys(self):
|
def mapping_keys(self):
|
||||||
return [
|
return [
|
||||||
('name', 'node_name'),
|
("name", "node_name"),
|
||||||
('resource_type', 'resource_type'),
|
("resource_type", "resource_type"),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@@ -331,29 +330,31 @@ class TimestampNamed(logbook.Processor):
|
|||||||
record.extra[self.name] = datetime.utcnow().isoformat()
|
record.extra[self.name] = datetime.utcnow().isoformat()
|
||||||
|
|
||||||
|
|
||||||
logger = logbook.Logger('dbt')
|
logger = logbook.Logger("dbt")
|
||||||
# provide this for the cache, disabled by default
|
# provide this for the cache, disabled by default
|
||||||
CACHE_LOGGER = logbook.Logger('dbt.cache')
|
CACHE_LOGGER = logbook.Logger("dbt.cache")
|
||||||
CACHE_LOGGER.disable()
|
CACHE_LOGGER.disable()
|
||||||
|
|
||||||
warnings.filterwarnings("ignore", category=ResourceWarning,
|
warnings.filterwarnings(
|
||||||
message="unclosed.*<socket.socket.*>")
|
"ignore", category=ResourceWarning, message="unclosed.*<socket.socket.*>"
|
||||||
|
)
|
||||||
|
|
||||||
initialized = False
|
initialized = False
|
||||||
|
|
||||||
|
|
||||||
def make_log_dir_if_missing(log_dir):
|
def make_log_dir_if_missing(log_dir):
|
||||||
import dbt.clients.system
|
import dbt.clients.system
|
||||||
|
|
||||||
dbt.clients.system.make_directory(log_dir)
|
dbt.clients.system.make_directory(log_dir)
|
||||||
|
|
||||||
|
|
||||||
class DebugWarnings(logbook.compat.redirected_warnings):
|
class DebugWarnings(logbook.compat.redirected_warnings):
|
||||||
"""Log warnings, except send them to 'debug' instead of 'warning' level.
|
"""Log warnings, except send them to 'debug' instead of 'warning' level."""
|
||||||
"""
|
|
||||||
def make_record(self, message, exception, filename, lineno):
|
def make_record(self, message, exception, filename, lineno):
|
||||||
rv = super().make_record(message, exception, filename, lineno)
|
rv = super().make_record(message, exception, filename, lineno)
|
||||||
rv.level = logbook.DEBUG
|
rv.level = logbook.DEBUG
|
||||||
rv.extra['from_warnings'] = True
|
rv.extra["from_warnings"] = True
|
||||||
return rv
|
return rv
|
||||||
|
|
||||||
|
|
||||||
@@ -405,14 +406,14 @@ class DelayedFileHandler(logbook.RotatingFileHandler, FormatterMixin):
|
|||||||
if self.disabled:
|
if self.disabled:
|
||||||
return
|
return
|
||||||
|
|
||||||
assert not self.initialized, 'set_path called after being set'
|
assert not self.initialized, "set_path called after being set"
|
||||||
|
|
||||||
if log_dir is None:
|
if log_dir is None:
|
||||||
self.disabled = True
|
self.disabled = True
|
||||||
return
|
return
|
||||||
|
|
||||||
make_log_dir_if_missing(log_dir)
|
make_log_dir_if_missing(log_dir)
|
||||||
log_path = os.path.join(log_dir, 'dbt.log')
|
log_path = os.path.join(log_dir, "dbt.log")
|
||||||
self._super_init(log_path)
|
self._super_init(log_path)
|
||||||
self._replay_buffered()
|
self._replay_buffered()
|
||||||
self._log_path = log_path
|
self._log_path = log_path
|
||||||
@@ -432,8 +433,9 @@ class DelayedFileHandler(logbook.RotatingFileHandler, FormatterMixin):
|
|||||||
FormatterMixin.__init__(self, DEBUG_LOG_FORMAT)
|
FormatterMixin.__init__(self, DEBUG_LOG_FORMAT)
|
||||||
|
|
||||||
def _replay_buffered(self):
|
def _replay_buffered(self):
|
||||||
assert self._msg_buffer is not None, \
|
assert (
|
||||||
'_msg_buffer should never be None in _replay_buffered'
|
self._msg_buffer is not None
|
||||||
|
), "_msg_buffer should never be None in _replay_buffered"
|
||||||
for record in self._msg_buffer:
|
for record in self._msg_buffer:
|
||||||
super().emit(record)
|
super().emit(record)
|
||||||
self._msg_buffer = None
|
self._msg_buffer = None
|
||||||
@@ -442,7 +444,7 @@ class DelayedFileHandler(logbook.RotatingFileHandler, FormatterMixin):
|
|||||||
msg = super().format(record)
|
msg = super().format(record)
|
||||||
subbed = str(msg)
|
subbed = str(msg)
|
||||||
for escape_sequence in dbt.ui.COLORS.values():
|
for escape_sequence in dbt.ui.COLORS.values():
|
||||||
subbed = subbed.replace(escape_sequence, '')
|
subbed = subbed.replace(escape_sequence, "")
|
||||||
return subbed
|
return subbed
|
||||||
|
|
||||||
def emit(self, record: logbook.LogRecord):
|
def emit(self, record: logbook.LogRecord):
|
||||||
@@ -454,11 +456,13 @@ class DelayedFileHandler(logbook.RotatingFileHandler, FormatterMixin):
|
|||||||
elif self.initialized:
|
elif self.initialized:
|
||||||
super().emit(record)
|
super().emit(record)
|
||||||
else:
|
else:
|
||||||
assert self._msg_buffer is not None, \
|
assert (
|
||||||
'_msg_buffer should never be None if _log_path is set'
|
self._msg_buffer is not None
|
||||||
|
), "_msg_buffer should never be None if _log_path is set"
|
||||||
self._msg_buffer.append(record)
|
self._msg_buffer.append(record)
|
||||||
assert len(self._msg_buffer) < self._bufmax, \
|
assert (
|
||||||
'too many messages received before initilization!'
|
len(self._msg_buffer) < self._bufmax
|
||||||
|
), "too many messages received before initilization!"
|
||||||
|
|
||||||
|
|
||||||
class LogManager(logbook.NestedSetup):
|
class LogManager(logbook.NestedSetup):
|
||||||
@@ -468,19 +472,21 @@ class LogManager(logbook.NestedSetup):
|
|||||||
self._null_handler = logbook.NullHandler()
|
self._null_handler = logbook.NullHandler()
|
||||||
self._output_handler = OutputHandler(self.stdout)
|
self._output_handler = OutputHandler(self.stdout)
|
||||||
self._file_handler = DelayedFileHandler()
|
self._file_handler = DelayedFileHandler()
|
||||||
self._relevel_processor = Relevel(allowed=['dbt', 'werkzeug'])
|
self._relevel_processor = Relevel(allowed=["dbt", "werkzeug"])
|
||||||
self._state_processor = DbtProcessState('internal')
|
self._state_processor = DbtProcessState("internal")
|
||||||
# keep track of wheter we've already entered to decide if we should
|
# keep track of wheter we've already entered to decide if we should
|
||||||
# be actually pushing. This allows us to log in main() and also
|
# be actually pushing. This allows us to log in main() and also
|
||||||
# support entering dbt execution via handle_and_check.
|
# support entering dbt execution via handle_and_check.
|
||||||
self._stack_depth = 0
|
self._stack_depth = 0
|
||||||
super().__init__([
|
super().__init__(
|
||||||
self._null_handler,
|
[
|
||||||
self._output_handler,
|
self._null_handler,
|
||||||
self._file_handler,
|
self._output_handler,
|
||||||
self._relevel_processor,
|
self._file_handler,
|
||||||
self._state_processor,
|
self._relevel_processor,
|
||||||
])
|
self._state_processor,
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
def push_application(self):
|
def push_application(self):
|
||||||
self._stack_depth += 1
|
self._stack_depth += 1
|
||||||
@@ -496,8 +502,7 @@ class LogManager(logbook.NestedSetup):
|
|||||||
self.add_handler(logbook.NullHandler())
|
self.add_handler(logbook.NullHandler())
|
||||||
|
|
||||||
def add_handler(self, handler):
|
def add_handler(self, handler):
|
||||||
"""add an handler to the log manager that runs before the file handler.
|
"""add an handler to the log manager that runs before the file handler."""
|
||||||
"""
|
|
||||||
self.objects.append(handler)
|
self.objects.append(handler)
|
||||||
|
|
||||||
# this is used by `dbt ls` to allow piping stdout to jq, etc
|
# this is used by `dbt ls` to allow piping stdout to jq, etc
|
||||||
@@ -555,8 +560,7 @@ log_manager = LogManager()
|
|||||||
|
|
||||||
|
|
||||||
def log_cache_events(flag):
|
def log_cache_events(flag):
|
||||||
"""Set the cache logger to propagate its messages based on the given flag.
|
"""Set the cache logger to propagate its messages based on the given flag."""
|
||||||
"""
|
|
||||||
# the flag is True if we should log, and False if we shouldn't, so disabled
|
# the flag is True if we should log, and False if we shouldn't, so disabled
|
||||||
# is the inverse.
|
# is the inverse.
|
||||||
CACHE_LOGGER.disabled = not flag
|
CACHE_LOGGER.disabled = not flag
|
||||||
@@ -580,7 +584,7 @@ class ListLogHandler(LogMessageHandler):
|
|||||||
level: int = logbook.NOTSET,
|
level: int = logbook.NOTSET,
|
||||||
filter: Callable = None,
|
filter: Callable = None,
|
||||||
bubble: bool = False,
|
bubble: bool = False,
|
||||||
lst: Optional[List[LogMessage]] = None
|
lst: Optional[List[LogMessage]] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__(level, filter, bubble)
|
super().__init__(level, filter, bubble)
|
||||||
if lst is None:
|
if lst is None:
|
||||||
@@ -589,7 +593,7 @@ class ListLogHandler(LogMessageHandler):
|
|||||||
|
|
||||||
def should_handle(self, record):
|
def should_handle(self, record):
|
||||||
"""Only ever emit dbt-sourced log messages to the ListHandler."""
|
"""Only ever emit dbt-sourced log messages to the ListHandler."""
|
||||||
if _root_channel(record) != 'dbt':
|
if _root_channel(record) != "dbt":
|
||||||
return False
|
return False
|
||||||
return super().should_handle(record)
|
return super().should_handle(record)
|
||||||
|
|
||||||
@@ -606,28 +610,27 @@ def _env_log_level(var_name: str) -> int:
     return logging.ERROR


-LOG_LEVEL_GOOGLE = _env_log_level('DBT_GOOGLE_DEBUG_LOGGING')
-LOG_LEVEL_SNOWFLAKE = _env_log_level('DBT_SNOWFLAKE_CONNECTOR_DEBUG_LOGGING')
-LOG_LEVEL_BOTOCORE = _env_log_level('DBT_BOTOCORE_DEBUG_LOGGING')
-LOG_LEVEL_HTTP = _env_log_level('DBT_HTTP_DEBUG_LOGGING')
-LOG_LEVEL_WERKZEUG = _env_log_level('DBT_WERKZEUG_DEBUG_LOGGING')
+LOG_LEVEL_GOOGLE = _env_log_level("DBT_GOOGLE_DEBUG_LOGGING")
+LOG_LEVEL_SNOWFLAKE = _env_log_level("DBT_SNOWFLAKE_CONNECTOR_DEBUG_LOGGING")
+LOG_LEVEL_BOTOCORE = _env_log_level("DBT_BOTOCORE_DEBUG_LOGGING")
+LOG_LEVEL_HTTP = _env_log_level("DBT_HTTP_DEBUG_LOGGING")
+LOG_LEVEL_WERKZEUG = _env_log_level("DBT_WERKZEUG_DEBUG_LOGGING")

-logging.getLogger('botocore').setLevel(LOG_LEVEL_BOTOCORE)
-logging.getLogger('requests').setLevel(LOG_LEVEL_HTTP)
-logging.getLogger('urllib3').setLevel(LOG_LEVEL_HTTP)
-logging.getLogger('google').setLevel(LOG_LEVEL_GOOGLE)
-logging.getLogger('snowflake.connector').setLevel(LOG_LEVEL_SNOWFLAKE)
+logging.getLogger("botocore").setLevel(LOG_LEVEL_BOTOCORE)
+logging.getLogger("requests").setLevel(LOG_LEVEL_HTTP)
+logging.getLogger("urllib3").setLevel(LOG_LEVEL_HTTP)
+logging.getLogger("google").setLevel(LOG_LEVEL_GOOGLE)
+logging.getLogger("snowflake.connector").setLevel(LOG_LEVEL_SNOWFLAKE)

-logging.getLogger('parsedatetime').setLevel(logging.ERROR)
-logging.getLogger('werkzeug').setLevel(LOG_LEVEL_WERKZEUG)
+logging.getLogger("parsedatetime").setLevel(logging.ERROR)
+logging.getLogger("werkzeug").setLevel(LOG_LEVEL_WERKZEUG)


 def list_handler(
     lst: Optional[List[LogMessage]],
     level=logbook.NOTSET,
 ) -> ContextManager:
-    """Return a context manager that temporarly attaches a list to the logger.
-    """
+    """Return a context manager that temporarly attaches a list to the logger."""
     return ListLogHandler(lst=lst, level=level, bubble=True)

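Going by the docstring above, `list_handler` is the hook for capturing dbt's log records in memory, for example in tests. A rough sketch of how it could be used; the message text is arbitrary and the exact capture behaviour depends on logbook's handler stack.

```python
from typing import List

import logbook

from dbt.logger import LogMessage, list_handler, logger

messages: List[LogMessage] = []
with list_handler(messages, level=logbook.INFO):
    logger.info("hello from dbt")

# Each captured record is a LogMessage (fields shown earlier in this diff).
for msg in messages:
    print(msg.levelname, msg.message)
```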
|
|||||||
core/dbt/main.py (672 changed lines): file diff suppressed because it is too large
@@ -1,23 +1,23 @@
|
|||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
from hologram.helpers import StrEnum
|
from dbt.dataclass_schema import StrEnum
|
||||||
|
|
||||||
|
|
||||||
class NodeType(StrEnum):
|
class NodeType(StrEnum):
|
||||||
Model = 'model'
|
Model = "model"
|
||||||
Analysis = 'analysis'
|
Analysis = "analysis"
|
||||||
Test = 'test'
|
Test = "test"
|
||||||
Snapshot = 'snapshot'
|
Snapshot = "snapshot"
|
||||||
Operation = 'operation'
|
Operation = "operation"
|
||||||
Seed = 'seed'
|
Seed = "seed"
|
||||||
RPCCall = 'rpc'
|
RPCCall = "rpc"
|
||||||
Documentation = 'docs'
|
Documentation = "docs"
|
||||||
Source = 'source'
|
Source = "source"
|
||||||
Macro = 'macro'
|
Macro = "macro"
|
||||||
Exposure = 'exposure'
|
Exposure = "exposure"
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def executable(cls) -> List['NodeType']:
|
def executable(cls) -> List["NodeType"]:
|
||||||
return [
|
return [
|
||||||
cls.Model,
|
cls.Model,
|
||||||
cls.Test,
|
cls.Test,
|
||||||
@@ -30,7 +30,7 @@ class NodeType(StrEnum):
|
|||||||
]
|
]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def refable(cls) -> List['NodeType']:
|
def refable(cls) -> List["NodeType"]:
|
||||||
return [
|
return [
|
||||||
cls.Model,
|
cls.Model,
|
||||||
cls.Seed,
|
cls.Seed,
|
||||||
@@ -38,7 +38,7 @@ class NodeType(StrEnum):
|
|||||||
]
|
]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def documentable(cls) -> List['NodeType']:
|
def documentable(cls) -> List["NodeType"]:
|
||||||
return [
|
return [
|
||||||
cls.Model,
|
cls.Model,
|
||||||
cls.Seed,
|
cls.Seed,
|
||||||
@@ -46,16 +46,16 @@ class NodeType(StrEnum):
             cls.Source,
             cls.Macro,
             cls.Analysis,
-            cls.Exposure
+            cls.Exposure,
         ]

     def pluralize(self) -> str:
-        if self == 'analysis':
-            return 'analyses'
+        if self == "analysis":
+            return "analyses"
         else:
-            return f'{self}s'
+            return f"{self}s"


 class RunHookType(StrEnum):
-    Start = 'on-run-start'
-    End = 'on-run-end'
+    Start = "on-run-start"
+    End = "on-run-end"
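Because `NodeType` is a `StrEnum`, members compare equal to their raw string values, which is exactly what the `self == "analysis"` check above relies on. A quick illustration:

```python
from dbt.node_types import NodeType, RunHookType

assert NodeType.Model == "model"
assert NodeType.Model.pluralize() == "models"
assert NodeType.Analysis.pluralize() == "analyses"  # the one irregular plural
assert RunHookType.Start == "on-run-start"
```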
|
|||||||
@@ -11,6 +11,14 @@ from .seeds import SeedParser # noqa
|
|||||||
from .snapshots import SnapshotParser # noqa
|
from .snapshots import SnapshotParser # noqa
|
||||||
|
|
||||||
from . import ( # noqa
|
from . import ( # noqa
|
||||||
analysis, base, data_test, docs, hooks, macros, models, results, schemas,
|
analysis,
|
||||||
snapshots
|
base,
|
||||||
|
data_test,
|
||||||
|
docs,
|
||||||
|
hooks,
|
||||||
|
macros,
|
||||||
|
models,
|
||||||
|
results,
|
||||||
|
schemas,
|
||||||
|
snapshots,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -8,12 +8,12 @@ from dbt.parser.search import FilesystemSearcher, FileBlock
|
|||||||
|
|
||||||
class AnalysisParser(SimpleSQLParser[ParsedAnalysisNode]):
|
class AnalysisParser(SimpleSQLParser[ParsedAnalysisNode]):
|
||||||
def get_paths(self):
|
def get_paths(self):
|
||||||
return FilesystemSearcher(
|
return FilesystemSearcher(self.project, self.project.analysis_paths, ".sql")
|
||||||
self.project, self.project.analysis_paths, '.sql'
|
|
||||||
)
|
|
||||||
|
|
||||||
def parse_from_dict(self, dct, validate=True) -> ParsedAnalysisNode:
|
def parse_from_dict(self, dct, validate=True) -> ParsedAnalysisNode:
|
||||||
return ParsedAnalysisNode.from_dict(dct, validate=validate)
|
if validate:
|
||||||
|
ParsedAnalysisNode.validate(dct)
|
||||||
|
return ParsedAnalysisNode.from_dict(dct)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def resource_type(self) -> NodeType:
|
def resource_type(self) -> NodeType:
|
||||||
@@ -21,4 +21,4 @@ class AnalysisParser(SimpleSQLParser[ParsedAnalysisNode]):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_compiled_path(cls, block: FileBlock):
|
def get_compiled_path(cls, block: FileBlock):
|
||||||
return os.path.join('analysis', block.path.relative_path)
|
return os.path.join("analysis", block.path.relative_path)
|
||||||
|
|||||||
@@ -1,11 +1,9 @@
|
|||||||
import abc
|
import abc
|
||||||
import itertools
|
import itertools
|
||||||
import os
|
import os
|
||||||
from typing import (
|
from typing import List, Dict, Any, Iterable, Generic, TypeVar
|
||||||
List, Dict, Any, Iterable, Generic, TypeVar
|
|
||||||
)
|
|
||||||
|
|
||||||
from hologram import ValidationError
|
from dbt.dataclass_schema import ValidationError
|
||||||
|
|
||||||
from dbt import utils
|
from dbt import utils
|
||||||
from dbt.clients.jinja import MacroGenerator
|
from dbt.clients.jinja import MacroGenerator
|
||||||
@@ -17,17 +15,15 @@ from dbt.context.providers import (
|
|||||||
from dbt.adapters.factory import get_adapter
|
from dbt.adapters.factory import get_adapter
|
||||||
from dbt.clients.jinja import get_rendered
|
from dbt.clients.jinja import get_rendered
|
||||||
from dbt.config import Project, RuntimeConfig
|
from dbt.config import Project, RuntimeConfig
|
||||||
from dbt.context.context_config import (
|
from dbt.context.context_config import ContextConfig
|
||||||
ContextConfig
|
from dbt.contracts.files import SourceFile, FilePath, FileHash
|
||||||
)
|
from dbt.contracts.graph.manifest import MacroManifest
|
||||||
from dbt.contracts.files import (
|
|
||||||
SourceFile, FilePath, FileHash
|
|
||||||
)
|
|
||||||
from dbt.contracts.graph.manifest import Manifest
|
|
||||||
from dbt.contracts.graph.parsed import HasUniqueID
|
from dbt.contracts.graph.parsed import HasUniqueID
|
||||||
from dbt.contracts.graph.unparsed import UnparsedNode
|
from dbt.contracts.graph.unparsed import UnparsedNode
|
||||||
from dbt.exceptions import (
|
from dbt.exceptions import (
|
||||||
CompilationException, validator_error_message, InternalException
|
CompilationException,
|
||||||
|
validator_error_message,
|
||||||
|
InternalException,
|
||||||
)
|
)
|
||||||
from dbt import hooks
|
from dbt import hooks
|
||||||
from dbt.node_types import NodeType
|
from dbt.node_types import NodeType
|
||||||
@@ -37,14 +33,14 @@ from dbt.parser.search import FileBlock
|
|||||||
# internally, the parser may store a less-restrictive type that will be
|
# internally, the parser may store a less-restrictive type that will be
|
||||||
# transformed into the final type. But it will have to be derived from
|
# transformed into the final type. But it will have to be derived from
|
||||||
# ParsedNode to be operable.
|
# ParsedNode to be operable.
|
||||||
FinalValue = TypeVar('FinalValue', bound=HasUniqueID)
|
FinalValue = TypeVar("FinalValue", bound=HasUniqueID)
|
||||||
IntermediateValue = TypeVar('IntermediateValue', bound=HasUniqueID)
|
IntermediateValue = TypeVar("IntermediateValue", bound=HasUniqueID)
|
||||||
|
|
||||||
IntermediateNode = TypeVar('IntermediateNode', bound=Any)
|
IntermediateNode = TypeVar("IntermediateNode", bound=Any)
|
||||||
FinalNode = TypeVar('FinalNode', bound=ManifestNodes)
|
FinalNode = TypeVar("FinalNode", bound=ManifestNodes)
|
||||||
|
|
||||||
|
|
||||||
ConfiguredBlockType = TypeVar('ConfiguredBlockType', bound=FileBlock)
|
ConfiguredBlockType = TypeVar("ConfiguredBlockType", bound=FileBlock)
|
||||||
|
|
||||||
|
|
||||||
class BaseParser(Generic[FinalValue]):
|
class BaseParser(Generic[FinalValue]):
|
||||||
@@ -73,9 +69,9 @@ class BaseParser(Generic[FinalValue]):
|
|||||||
|
|
||||||
def generate_unique_id(self, resource_name: str) -> str:
|
def generate_unique_id(self, resource_name: str) -> str:
|
||||||
"""Returns a unique identifier for a resource"""
|
"""Returns a unique identifier for a resource"""
|
||||||
return "{}.{}.{}".format(self.resource_type,
|
return "{}.{}.{}".format(
|
||||||
self.project.project_name,
|
self.resource_type, self.project.project_name, resource_name
|
||||||
resource_name)
|
)
|
||||||
|
|
||||||
def load_file(
|
def load_file(
|
||||||
self,
|
self,
|
||||||
@@ -89,7 +85,7 @@ class BaseParser(Generic[FinalValue]):
|
|||||||
if set_contents:
|
if set_contents:
|
||||||
source_file.contents = file_contents.strip()
|
source_file.contents = file_contents.strip()
|
||||||
else:
|
else:
|
||||||
source_file.contents = ''
|
source_file.contents = ""
|
||||||
return source_file
|
return source_file
|
||||||
|
|
||||||
|
|
||||||
@@ -99,7 +95,7 @@ class Parser(BaseParser[FinalValue], Generic[FinalValue]):
|
|||||||
results: ParseResult,
|
results: ParseResult,
|
||||||
project: Project,
|
project: Project,
|
||||||
root_project: RuntimeConfig,
|
root_project: RuntimeConfig,
|
||||||
macro_manifest: Manifest,
|
macro_manifest: MacroManifest,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__(results, project)
|
super().__init__(results, project)
|
||||||
self.root_project = root_project
|
self.root_project = root_project
|
||||||
@@ -108,26 +104,24 @@ class Parser(BaseParser[FinalValue], Generic[FinalValue]):
|
|||||||
|
|
||||||
class RelationUpdate:
|
class RelationUpdate:
|
||||||
def __init__(
|
def __init__(
|
||||||
self, config: RuntimeConfig, manifest: Manifest, component: str
|
self, config: RuntimeConfig, macro_manifest: MacroManifest, component: str
|
||||||
) -> None:
|
) -> None:
|
||||||
macro = manifest.find_generate_macro_by_name(
|
macro = macro_manifest.find_generate_macro_by_name(
|
||||||
component=component,
|
component=component,
|
||||||
root_project_name=config.project_name,
|
root_project_name=config.project_name,
|
||||||
)
|
)
|
||||||
if macro is None:
|
if macro is None:
|
||||||
raise InternalException(
|
raise InternalException(
|
||||||
f'No macro with name generate_{component}_name found'
|
f"No macro with name generate_{component}_name found"
|
||||||
)
|
)
|
||||||
|
|
||||||
root_context = generate_generate_component_name_macro(
|
root_context = generate_generate_component_name_macro(
|
||||||
macro, config, manifest
|
macro, config, macro_manifest
|
||||||
)
|
)
|
||||||
self.updater = MacroGenerator(macro, root_context)
|
self.updater = MacroGenerator(macro, root_context)
|
||||||
self.component = component
|
self.component = component
|
||||||
|
|
||||||
def __call__(
|
def __call__(self, parsed_node: Any, config_dict: Dict[str, Any]) -> None:
|
||||||
self, parsed_node: Any, config_dict: Dict[str, Any]
|
|
||||||
) -> None:
|
|
||||||
override = config_dict.get(self.component)
|
override = config_dict.get(self.component)
|
||||||
new_value = self.updater(override, parsed_node)
|
new_value = self.updater(override, parsed_node)
|
||||||
if isinstance(new_value, str):
|
if isinstance(new_value, str):
|
||||||
@@ -144,18 +138,18 @@ class ConfiguredParser(
|
|||||||
results: ParseResult,
|
results: ParseResult,
|
||||||
project: Project,
|
project: Project,
|
||||||
root_project: RuntimeConfig,
|
root_project: RuntimeConfig,
|
||||||
macro_manifest: Manifest,
|
macro_manifest: MacroManifest,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__(results, project, root_project, macro_manifest)
|
super().__init__(results, project, root_project, macro_manifest)
|
||||||
|
|
||||||
self._update_node_database = RelationUpdate(
|
self._update_node_database = RelationUpdate(
|
||||||
manifest=macro_manifest, config=root_project, component='database'
|
macro_manifest=macro_manifest, config=root_project, component="database"
|
||||||
)
|
)
|
||||||
self._update_node_schema = RelationUpdate(
|
self._update_node_schema = RelationUpdate(
|
||||||
manifest=macro_manifest, config=root_project, component='schema'
|
macro_manifest=macro_manifest, config=root_project, component="schema"
|
||||||
)
|
)
|
||||||
self._update_node_alias = RelationUpdate(
|
self._update_node_alias = RelationUpdate(
|
||||||
manifest=macro_manifest, config=root_project, component='alias'
|
macro_manifest=macro_manifest, config=root_project, component="alias"
|
||||||
)
|
)
|
||||||
|
|
||||||
@abc.abstractclassmethod
|
@abc.abstractclassmethod
|
||||||
@@ -202,7 +196,11 @@ class ConfiguredParser(
|
|||||||
config[key] = [hooks.get_hook_dict(h) for h in config[key]]
|
config[key] = [hooks.get_hook_dict(h) for h in config[key]]
|
||||||
|
|
||||||
def _create_error_node(
|
def _create_error_node(
|
||||||
self, name: str, path: str, original_file_path: str, raw_sql: str,
|
self,
|
||||||
|
name: str,
|
||||||
|
path: str,
|
||||||
|
original_file_path: str,
|
||||||
|
raw_sql: str,
|
||||||
) -> UnparsedNode:
|
) -> UnparsedNode:
|
||||||
"""If we hit an error before we've actually parsed a node, provide some
|
"""If we hit an error before we've actually parsed a node, provide some
|
||||||
level of useful information by attaching this to the exception.
|
level of useful information by attaching this to the exception.
|
||||||
@@ -235,24 +233,24 @@ class ConfiguredParser(
|
|||||||
if name is None:
|
if name is None:
|
||||||
name = block.name
|
name = block.name
|
||||||
dct = {
|
dct = {
|
||||||
'alias': name,
|
"alias": name,
|
||||||
'schema': self.default_schema,
|
"schema": self.default_schema,
|
||||||
'database': self.default_database,
|
"database": self.default_database,
|
||||||
'fqn': fqn,
|
"fqn": fqn,
|
||||||
'name': name,
|
"name": name,
|
||||||
'root_path': self.project.project_root,
|
"root_path": self.project.project_root,
|
||||||
'resource_type': self.resource_type,
|
"resource_type": self.resource_type,
|
||||||
'path': path,
|
"path": path,
|
||||||
'original_file_path': block.path.original_file_path,
|
"original_file_path": block.path.original_file_path,
|
||||||
'package_name': self.project.project_name,
|
"package_name": self.project.project_name,
|
||||||
'raw_sql': block.contents,
|
"raw_sql": block.contents,
|
||||||
'unique_id': self.generate_unique_id(name),
|
"unique_id": self.generate_unique_id(name),
|
||||||
'config': self.config_dict(config),
|
"config": self.config_dict(config),
|
||||||
'checksum': block.file.checksum.to_dict(),
|
"checksum": block.file.checksum.to_dict(omit_none=True),
|
||||||
}
|
}
|
||||||
dct.update(kwargs)
|
dct.update(kwargs)
|
||||||
try:
|
try:
|
||||||
return self.parse_from_dict(dct)
|
return self.parse_from_dict(dct, validate=True)
|
||||||
except ValidationError as exc:
|
except ValidationError as exc:
|
||||||
msg = validator_error_message(exc)
|
msg = validator_error_message(exc)
|
||||||
# this is a bit silly, but build an UnparsedNode just for error
|
# this is a bit silly, but build an UnparsedNode just for error
|
||||||
@@ -275,25 +273,27 @@ class ConfiguredParser(
|
|||||||
def render_with_context(
|
def render_with_context(
|
||||||
self, parsed_node: IntermediateNode, config: ContextConfig
|
self, parsed_node: IntermediateNode, config: ContextConfig
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Given the parsed node and a ContextConfig to use during parsing,
|
# Given the parsed node and a ContextConfig to use during parsing,
|
||||||
render the node's sql wtih macro capture enabled.
|
# render the node's sql wtih macro capture enabled.
|
||||||
|
# Note: this mutates the config object when config calls are rendered.
|
||||||
|
|
||||||
Note: this mutates the config object when config() calls are rendered.
|
|
||||||
"""
|
|
||||||
# during parsing, we don't have a connection, but we might need one, so
|
# during parsing, we don't have a connection, but we might need one, so
|
||||||
# we have to acquire it.
|
# we have to acquire it.
|
||||||
with get_adapter(self.root_project).connection_for(parsed_node):
|
with get_adapter(self.root_project).connection_for(parsed_node):
|
||||||
context = self._context_for(parsed_node, config)
|
context = self._context_for(parsed_node, config)
|
||||||
|
|
||||||
get_rendered(
|
# this goes through the process of rendering, but just throws away
|
||||||
parsed_node.raw_sql, context, parsed_node, capture_macros=True
|
# the rendered result. The "macro capture" is the point?
|
||||||
)
|
get_rendered(parsed_node.raw_sql, context, parsed_node, capture_macros=True)
|
||||||
|
|
||||||
|
# This is taking the original config for the node, converting it to a dict,
|
||||||
|
# updating the config with new config passed in, then re-creating the
|
||||||
|
# config from the dict in the node.
|
||||||
def update_parsed_node_config(
|
def update_parsed_node_config(
|
||||||
self, parsed_node: IntermediateNode, config_dict: Dict[str, Any]
|
self, parsed_node: IntermediateNode, config_dict: Dict[str, Any]
|
||||||
) -> None:
|
) -> None:
|
||||||
# Overwrite node config
|
# Overwrite node config
|
||||||
final_config_dict = parsed_node.config.to_dict()
|
final_config_dict = parsed_node.config.to_dict(omit_none=True)
|
||||||
final_config_dict.update(config_dict)
|
final_config_dict.update(config_dict)
|
||||||
# re-mangle hooks, in case we got new ones
|
# re-mangle hooks, in case we got new ones
|
||||||
self._mangle_hooks(final_config_dict)
|
self._mangle_hooks(final_config_dict)
|
||||||
@@ -316,12 +316,10 @@ class ConfiguredParser(
|
|||||||
config_dict = config.build_config_dict()
|
config_dict = config.build_config_dict()
|
||||||
|
|
||||||
# Set tags on node provided in config blocks
|
# Set tags on node provided in config blocks
|
||||||
model_tags = config_dict.get('tags', [])
|
model_tags = config_dict.get("tags", [])
|
||||||
parsed_node.tags.extend(model_tags)
|
parsed_node.tags.extend(model_tags)
|
||||||
|
|
||||||
parsed_node.unrendered_config = config.build_config_dict(
|
parsed_node.unrendered_config = config.build_config_dict(rendered=False)
|
||||||
rendered=False
|
|
||||||
)
|
|
||||||
|
|
||||||
# do this once before we parse the node database/schema/alias, so
|
# do this once before we parse the node database/schema/alias, so
|
||||||
# parsed_node.config is what it would be if they did nothing
|
# parsed_node.config is what it would be if they did nothing
|
||||||
@@ -330,8 +328,9 @@ class ConfiguredParser(
|
|||||||
|
|
||||||
# at this point, we've collected our hooks. Use the node context to
|
# at this point, we've collected our hooks. Use the node context to
|
||||||
# render each hook and collect refs/sources
|
# render each hook and collect refs/sources
|
||||||
hooks = list(itertools.chain(parsed_node.config.pre_hook,
|
hooks = list(
|
||||||
parsed_node.config.post_hook))
|
itertools.chain(parsed_node.config.pre_hook, parsed_node.config.post_hook)
|
||||||
|
)
|
||||||
# skip context rebuilding if there aren't any hooks
|
# skip context rebuilding if there aren't any hooks
|
||||||
if not hooks:
|
if not hooks:
|
||||||
return
|
return
|
||||||
@@ -354,20 +353,18 @@ class ConfiguredParser(
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise InternalException(
|
raise InternalException(
|
||||||
f'Got an unexpected project version={config_version}, '
|
f"Got an unexpected project version={config_version}, " f"expected 2"
|
||||||
f'expected 2'
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def config_dict(
|
def config_dict(
|
||||||
self, config: ContextConfig,
|
self,
|
||||||
|
config: ContextConfig,
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
config_dict = config.build_config_dict(base=True)
|
config_dict = config.build_config_dict(base=True)
|
||||||
self._mangle_hooks(config_dict)
|
self._mangle_hooks(config_dict)
|
||||||
return config_dict
|
return config_dict
|
||||||
|
|
||||||
def render_update(
|
def render_update(self, node: IntermediateNode, config: ContextConfig) -> None:
|
||||||
self, node: IntermediateNode, config: ContextConfig
|
|
||||||
) -> None:
|
|
||||||
try:
|
try:
|
||||||
self.render_with_context(node, config)
|
self.render_with_context(node, config)
|
||||||
self.update_parsed_node(node, config)
|
self.update_parsed_node(node, config)
|
||||||
@@ -410,7 +407,7 @@ class ConfiguredParser(
|
|||||||
|
|
||||||
class SimpleParser(
|
class SimpleParser(
|
||||||
ConfiguredParser[ConfiguredBlockType, FinalNode, FinalNode],
|
ConfiguredParser[ConfiguredBlockType, FinalNode, FinalNode],
|
||||||
Generic[ConfiguredBlockType, FinalNode]
|
Generic[ConfiguredBlockType, FinalNode],
|
||||||
):
|
):
|
||||||
def transform(self, node):
|
def transform(self, node):
|
||||||
return node
|
return node
|
||||||
@@ -418,14 +415,12 @@ class SimpleParser(
|
|||||||
|
|
||||||
class SQLParser(
|
class SQLParser(
|
||||||
ConfiguredParser[FileBlock, IntermediateNode, FinalNode],
|
ConfiguredParser[FileBlock, IntermediateNode, FinalNode],
|
||||||
Generic[IntermediateNode, FinalNode]
|
Generic[IntermediateNode, FinalNode],
|
||||||
):
|
):
|
||||||
def parse_file(self, file_block: FileBlock) -> None:
|
def parse_file(self, file_block: FileBlock) -> None:
|
||||||
self.parse_node(file_block)
|
self.parse_node(file_block)
|
||||||
|
|
||||||
|
|
||||||
class SimpleSQLParser(
|
class SimpleSQLParser(SQLParser[FinalNode, FinalNode]):
|
||||||
SQLParser[FinalNode, FinalNode]
|
|
||||||
):
|
|
||||||
def transform(self, node):
|
def transform(self, node):
|
||||||
return node
|
return node
|
||||||
|
|||||||
@@ -7,23 +7,22 @@ from dbt.utils import get_pseudo_test_path
|
|||||||
|
|
||||||
class DataTestParser(SimpleSQLParser[ParsedDataTestNode]):
|
class DataTestParser(SimpleSQLParser[ParsedDataTestNode]):
|
||||||
def get_paths(self):
|
def get_paths(self):
|
||||||
return FilesystemSearcher(
|
return FilesystemSearcher(self.project, self.project.test_paths, ".sql")
|
||||||
self.project, self.project.test_paths, '.sql'
|
|
||||||
)
|
|
||||||
|
|
||||||
def parse_from_dict(self, dct, validate=True) -> ParsedDataTestNode:
|
def parse_from_dict(self, dct, validate=True) -> ParsedDataTestNode:
|
||||||
return ParsedDataTestNode.from_dict(dct, validate=validate)
|
if validate:
|
||||||
|
ParsedDataTestNode.validate(dct)
|
||||||
|
return ParsedDataTestNode.from_dict(dct)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def resource_type(self) -> NodeType:
|
def resource_type(self) -> NodeType:
|
||||||
return NodeType.Test
|
return NodeType.Test
|
||||||
|
|
||||||
def transform(self, node):
|
def transform(self, node):
|
||||||
if 'data' not in node.tags:
|
if "data" not in node.tags:
|
||||||
node.tags.append('data')
|
node.tags.append("data")
|
||||||
return node
|
return node
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_compiled_path(cls, block: FileBlock):
|
def get_compiled_path(cls, block: FileBlock):
|
||||||
return get_pseudo_test_path(block.name, block.path.relative_path,
|
return get_pseudo_test_path(block.name, block.path.relative_path, "data_test")
|
||||||
'data_test')
|
|
||||||
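
Several `parse_from_dict` implementations in this diff move from `from_dict(dct, validate=validate)` to an explicit `validate(dct)` call followed by a plain `from_dict(dct)`. A minimal sketch of that split on an ordinary dataclass, with a hypothetical `validate` standing in for the schema check that `dbtClassMixin` provides:

from dataclasses import dataclass, fields
from typing import Any, Dict


@dataclass
class ParsedNodeStub:
    name: str
    raw_sql: str

    @classmethod
    def validate(cls, dct: Dict[str, Any]) -> None:
        # stand-in for a JSON-schema check: require every field to be present
        missing = [f.name for f in fields(cls) if f.name not in dct]
        if missing:
            raise ValueError(f"missing keys: {missing}")

    @classmethod
    def from_dict(cls, dct: Dict[str, Any]) -> "ParsedNodeStub":
        return cls(**{f.name: dct[f.name] for f in fields(cls)})


def parse_from_dict(dct: Dict[str, Any], validate: bool = True) -> ParsedNodeStub:
    # validation is now a separate, optional step rather than a from_dict kwarg
    if validate:
        ParsedNodeStub.validate(dct)
    return ParsedNodeStub.from_dict(dct)


print(parse_from_dict({"name": "my_test", "raw_sql": "select 1"}))
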
|
|||||||
@@ -7,11 +7,14 @@ from dbt.contracts.graph.parsed import ParsedDocumentation
|
|||||||
from dbt.node_types import NodeType
|
from dbt.node_types import NodeType
|
||||||
from dbt.parser.base import Parser
|
from dbt.parser.base import Parser
|
||||||
from dbt.parser.search import (
|
from dbt.parser.search import (
|
||||||
BlockContents, FileBlock, FilesystemSearcher, BlockSearcher
|
BlockContents,
|
||||||
|
FileBlock,
|
||||||
|
FilesystemSearcher,
|
||||||
|
BlockSearcher,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
SHOULD_PARSE_RE = re.compile(r'{[{%]')
|
SHOULD_PARSE_RE = re.compile(r"{[{%]")
|
||||||
|
|
||||||
|
|
||||||
class DocumentationParser(Parser[ParsedDocumentation]):
|
class DocumentationParser(Parser[ParsedDocumentation]):
|
||||||
@@ -19,7 +22,7 @@ class DocumentationParser(Parser[ParsedDocumentation]):
|
|||||||
return FilesystemSearcher(
|
return FilesystemSearcher(
|
||||||
project=self.project,
|
project=self.project,
|
||||||
relative_dirs=self.project.docs_paths,
|
relative_dirs=self.project.docs_paths,
|
||||||
extension='.md',
|
extension=".md",
|
||||||
)
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@@ -33,11 +36,9 @@ class DocumentationParser(Parser[ParsedDocumentation]):
|
|||||||
def generate_unique_id(self, resource_name: str) -> str:
|
def generate_unique_id(self, resource_name: str) -> str:
|
||||||
# because docs are in their own graph namespace, node type doesn't
|
# because docs are in their own graph namespace, node type doesn't
|
||||||
# need to be part of the unique ID.
|
# need to be part of the unique ID.
|
||||||
return '{}.{}'.format(self.project.project_name, resource_name)
|
return "{}.{}".format(self.project.project_name, resource_name)
|
||||||
|
|
||||||
def parse_block(
|
def parse_block(self, block: BlockContents) -> Iterable[ParsedDocumentation]:
|
||||||
self, block: BlockContents
|
|
||||||
) -> Iterable[ParsedDocumentation]:
|
|
||||||
unique_id = self.generate_unique_id(block.name)
|
unique_id = self.generate_unique_id(block.name)
|
||||||
contents = get_rendered(block.contents, {}).strip()
|
contents = get_rendered(block.contents, {}).strip()
|
||||||
|
|
||||||
@@ -55,7 +56,7 @@ class DocumentationParser(Parser[ParsedDocumentation]):
|
|||||||
def parse_file(self, file_block: FileBlock):
|
def parse_file(self, file_block: FileBlock):
|
||||||
searcher: Iterable[BlockContents] = BlockSearcher(
|
searcher: Iterable[BlockContents] = BlockSearcher(
|
||||||
source=[file_block],
|
source=[file_block],
|
||||||
allowed_blocks={'docs'},
|
allowed_blocks={"docs"},
|
||||||
source_tag_factory=BlockContents,
|
source_tag_factory=BlockContents,
|
||||||
)
|
)
|
||||||
for block in searcher:
|
for block in searcher:
|
||||||
|
|||||||
@@ -24,7 +24,7 @@ class HookBlock(FileBlock):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def name(self):
|
def name(self):
|
||||||
return '{}-{!s}-{!s}'.format(self.project, self.hook_type, self.index)
|
return "{}-{!s}-{!s}".format(self.project, self.hook_type, self.index)
|
||||||
|
|
||||||
|
|
||||||
class HookSearcher(Iterable[HookBlock]):
|
class HookSearcher(Iterable[HookBlock]):
|
||||||
@@ -33,9 +33,7 @@ class HookSearcher(Iterable[HookBlock]):
|
|||||||
self.source_file = source_file
|
self.source_file = source_file
|
||||||
self.hook_type = hook_type
|
self.hook_type = hook_type
|
||||||
|
|
||||||
def _hook_list(
|
def _hook_list(self, hooks: Union[str, List[str], Tuple[str, ...]]) -> List[str]:
|
||||||
self, hooks: Union[str, List[str], Tuple[str, ...]]
|
|
||||||
) -> List[str]:
|
|
||||||
if isinstance(hooks, tuple):
|
if isinstance(hooks, tuple):
|
||||||
hooks = list(hooks)
|
hooks = list(hooks)
|
||||||
elif not isinstance(hooks, list):
|
elif not isinstance(hooks, list):
|
||||||
@@ -49,8 +47,9 @@ class HookSearcher(Iterable[HookBlock]):
|
|||||||
hooks = self.project.on_run_end
|
hooks = self.project.on_run_end
|
||||||
else:
|
else:
|
||||||
raise InternalException(
|
raise InternalException(
|
||||||
'hook_type must be one of "{}" or "{}" (got {})'
|
'hook_type must be one of "{}" or "{}" (got {})'.format(
|
||||||
.format(RunHookType.Start, RunHookType.End, self.hook_type)
|
RunHookType.Start, RunHookType.End, self.hook_type
|
||||||
|
)
|
||||||
)
|
)
|
||||||
return self._hook_list(hooks)
|
return self._hook_list(hooks)
|
||||||
|
|
||||||
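
`_hook_list` and `get_hooks` just normalize whatever shape `on-run-start` / `on-run-end` take in the project file (a bare string, a list, or a tuple) into a list of strings, and reject unknown hook types. The normalization on its own looks roughly like:

from typing import List, Tuple, Union


def hook_list(hooks: Union[str, List[str], Tuple[str, ...]]) -> List[str]:
    # tuples show up after YAML-to-dataclass round trips; single strings
    # come from a bare `on-run-start: "..."` entry
    if isinstance(hooks, tuple):
        return list(hooks)
    if not isinstance(hooks, list):
        return [hooks]
    return hooks


assert hook_list("grant usage on schema analytics to role reporter") == [
    "grant usage on schema analytics to role reporter"
]
assert hook_list(("a", "b")) == ["a", "b"]
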
@@ -73,13 +72,15 @@ class HookParser(SimpleParser[HookBlock, ParsedHookNode]):
|
|||||||
def get_paths(self) -> List[FilePath]:
|
def get_paths(self) -> List[FilePath]:
|
||||||
path = FilePath(
|
path = FilePath(
|
||||||
project_root=self.project.project_root,
|
project_root=self.project.project_root,
|
||||||
searched_path='.',
|
searched_path=".",
|
||||||
relative_path='dbt_project.yml',
|
relative_path="dbt_project.yml",
|
||||||
)
|
)
|
||||||
return [path]
|
return [path]
|
||||||
|
|
||||||
def parse_from_dict(self, dct, validate=True) -> ParsedHookNode:
|
def parse_from_dict(self, dct, validate=True) -> ParsedHookNode:
|
||||||
return ParsedHookNode.from_dict(dct, validate=validate)
|
if validate:
|
||||||
|
ParsedHookNode.validate(dct)
|
||||||
|
return ParsedHookNode.from_dict(dct)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_compiled_path(cls, block: HookBlock):
|
def get_compiled_path(cls, block: HookBlock):
|
||||||
@@ -96,9 +97,13 @@ class HookParser(SimpleParser[HookBlock, ParsedHookNode]):
|
|||||||
) -> ParsedHookNode:
|
) -> ParsedHookNode:
|
||||||
|
|
||||||
return super()._create_parsetime_node(
|
return super()._create_parsetime_node(
|
||||||
block=block, path=path, config=config, fqn=fqn,
|
block=block,
|
||||||
index=block.index, name=name,
|
path=path,
|
||||||
tags=[str(block.hook_type)]
|
config=config,
|
||||||
|
fqn=fqn,
|
||||||
|
index=block.index,
|
||||||
|
name=name,
|
||||||
|
tags=[str(block.hook_type)],
|
||||||
)
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ class MacroParser(BaseParser[ParsedMacro]):
|
|||||||
return FilesystemSearcher(
|
return FilesystemSearcher(
|
||||||
project=self.project,
|
project=self.project,
|
||||||
relative_dirs=self.project.macro_paths,
|
relative_dirs=self.project.macro_paths,
|
||||||
extension='.sql',
|
extension=".sql",
|
||||||
)
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@@ -45,15 +45,13 @@ class MacroParser(BaseParser[ParsedMacro]):
|
|||||||
unique_id=unique_id,
|
unique_id=unique_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
def parse_unparsed_macros(
|
def parse_unparsed_macros(self, base_node: UnparsedMacro) -> Iterable[ParsedMacro]:
|
||||||
self, base_node: UnparsedMacro
|
|
||||||
) -> Iterable[ParsedMacro]:
|
|
||||||
try:
|
try:
|
||||||
blocks: List[jinja.BlockTag] = [
|
blocks: List[jinja.BlockTag] = [
|
||||||
t for t in
|
t
|
||||||
jinja.extract_toplevel_blocks(
|
for t in jinja.extract_toplevel_blocks(
|
||||||
base_node.raw_sql,
|
base_node.raw_sql,
|
||||||
allowed_blocks={'macro', 'materialization'},
|
allowed_blocks={"macro", "materialization"},
|
||||||
collect_raw_data=False,
|
collect_raw_data=False,
|
||||||
)
|
)
|
||||||
if isinstance(t, jinja.BlockTag)
|
if isinstance(t, jinja.BlockTag)
|
||||||
@@ -75,8 +73,8 @@ class MacroParser(BaseParser[ParsedMacro]):
|
|||||||
# things have gone disastrously wrong, we thought we only
|
# things have gone disastrously wrong, we thought we only
|
||||||
# parsed one block!
|
# parsed one block!
|
||||||
raise CompilationException(
|
raise CompilationException(
|
||||||
f'Found multiple macros in {block.full_block}, expected 1',
|
f"Found multiple macros in {block.full_block}, expected 1",
|
||||||
node=base_node
|
node=base_node,
|
||||||
)
|
)
|
||||||
|
|
||||||
macro_name = macro_nodes[0].name
|
macro_name = macro_nodes[0].name
|
||||||
@@ -84,7 +82,7 @@ class MacroParser(BaseParser[ParsedMacro]):
|
|||||||
if not macro_name.startswith(MACRO_PREFIX):
|
if not macro_name.startswith(MACRO_PREFIX):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
name: str = macro_name.replace(MACRO_PREFIX, '')
|
name: str = macro_name.replace(MACRO_PREFIX, "")
|
||||||
node = self.parse_macro(block, base_node, name)
|
node = self.parse_macro(block, base_node, name)
|
||||||
yield node
|
yield node
|
||||||
|
|
||||||
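
For context on `parse_unparsed_macros`: it asks dbt's Jinja block extractor for top-level `macro` and `materialization` blocks and parses each one. A self-contained approximation that uses a regex in place of `extract_toplevel_blocks` (the real extractor is a proper tokenizer; this only shows the shape of the loop):

import re
from typing import Iterator, Tuple

MACRO_RE = re.compile(
    r"{%-?\s*macro\s+(?P<name>\w+)\s*\(.*?\)\s*-?%}(?P<body>.*?){%-?\s*endmacro\s*-?%}",
    re.DOTALL,
)


def extract_macros(raw_sql: str) -> Iterator[Tuple[str, str]]:
    # yield (name, full_block) for each top-level macro definition
    for match in MACRO_RE.finditer(raw_sql):
        yield match.group("name"), match.group(0)


sql = """
{% macro current_ts() %} current_timestamp {% endmacro %}
{% macro limit_zero() %} limit 0 {% endmacro %}
"""
for name, block in extract_macros(sql):
    print(name, len(block))
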
|
|||||||
@@ -3,7 +3,15 @@ from dataclasses import field
|
|||||||
import os
|
import os
|
||||||
import pickle
|
import pickle
|
||||||
from typing import (
|
from typing import (
|
||||||
Dict, Optional, Mapping, Callable, Any, List, Type, Union, MutableMapping
|
Dict,
|
||||||
|
Optional,
|
||||||
|
Mapping,
|
||||||
|
Callable,
|
||||||
|
Any,
|
||||||
|
List,
|
||||||
|
Type,
|
||||||
|
Union,
|
||||||
|
MutableMapping,
|
||||||
)
|
)
|
||||||
import time
|
import time
|
||||||
|
|
||||||
@@ -23,9 +31,13 @@ from dbt.config import Project, RuntimeConfig
|
|||||||
from dbt.context.docs import generate_runtime_docs
|
from dbt.context.docs import generate_runtime_docs
|
||||||
from dbt.contracts.files import FilePath, FileHash
|
from dbt.contracts.files import FilePath, FileHash
|
||||||
from dbt.contracts.graph.compiled import ManifestNode
|
from dbt.contracts.graph.compiled import ManifestNode
|
||||||
from dbt.contracts.graph.manifest import Manifest, Disabled
|
from dbt.contracts.graph.manifest import Manifest, MacroManifest, AnyManifest, Disabled
|
||||||
from dbt.contracts.graph.parsed import (
|
from dbt.contracts.graph.parsed import (
|
||||||
ParsedSourceDefinition, ParsedNode, ParsedMacro, ColumnInfo, ParsedExposure
|
ParsedSourceDefinition,
|
||||||
|
ParsedNode,
|
||||||
|
ParsedMacro,
|
||||||
|
ColumnInfo,
|
||||||
|
ParsedExposure,
|
||||||
)
|
)
|
||||||
from dbt.contracts.util import Writable
|
from dbt.contracts.util import Writable
|
||||||
from dbt.exceptions import (
|
from dbt.exceptions import (
|
||||||
@@ -51,22 +63,22 @@ from dbt.parser.sources import patch_sources
|
|||||||
from dbt.ui import warning_tag
|
from dbt.ui import warning_tag
|
||||||
from dbt.version import __version__
|
from dbt.version import __version__
|
||||||
|
|
||||||
from hologram import JsonSchemaMixin
|
from dbt.dataclass_schema import dbtClassMixin
|
||||||
|
|
||||||
PARTIAL_PARSE_FILE_NAME = 'partial_parse.pickle'
|
PARTIAL_PARSE_FILE_NAME = "partial_parse.pickle"
|
||||||
PARSING_STATE = DbtProcessState('parsing')
|
PARSING_STATE = DbtProcessState("parsing")
|
||||||
DEFAULT_PARTIAL_PARSE = False
|
DEFAULT_PARTIAL_PARSE = False
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ParserInfo(JsonSchemaMixin):
|
class ParserInfo(dbtClassMixin):
|
||||||
parser: str
|
parser: str
|
||||||
elapsed: float
|
elapsed: float
|
||||||
path_count: int = 0
|
path_count: int = 0
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ProjectLoaderInfo(JsonSchemaMixin):
|
class ProjectLoaderInfo(dbtClassMixin):
|
||||||
project_name: str
|
project_name: str
|
||||||
elapsed: float
|
elapsed: float
|
||||||
parsers: List[ParserInfo]
|
parsers: List[ParserInfo]
|
||||||
@@ -74,7 +86,7 @@ class ProjectLoaderInfo(JsonSchemaMixin):
|
|||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ManifestLoaderInfo(JsonSchemaMixin, Writable):
|
class ManifestLoaderInfo(dbtClassMixin, Writable):
|
||||||
path_count: int = 0
|
path_count: int = 0
|
||||||
is_partial_parse_enabled: Optional[bool] = None
|
is_partial_parse_enabled: Optional[bool] = None
|
||||||
parse_project_elapsed: Optional[float] = None
|
parse_project_elapsed: Optional[float] = None
|
||||||
@@ -108,20 +120,22 @@ def make_parse_result(
|
|||||||
"""Make a ParseResult from the project configuration and the profile."""
|
"""Make a ParseResult from the project configuration and the profile."""
|
||||||
# if any of these change, we need to reject the parser
|
# if any of these change, we need to reject the parser
|
||||||
vars_hash = FileHash.from_contents(
|
vars_hash = FileHash.from_contents(
|
||||||
'\x00'.join([
|
"\x00".join(
|
||||||
getattr(config.args, 'vars', '{}') or '{}',
|
[
|
||||||
getattr(config.args, 'profile', '') or '',
|
getattr(config.args, "vars", "{}") or "{}",
|
||||||
getattr(config.args, 'target', '') or '',
|
getattr(config.args, "profile", "") or "",
|
||||||
__version__
|
getattr(config.args, "target", "") or "",
|
||||||
])
|
__version__,
|
||||||
|
]
|
||||||
|
)
|
||||||
)
|
)
|
||||||
profile_path = os.path.join(config.args.profiles_dir, 'profiles.yml')
|
profile_path = os.path.join(config.args.profiles_dir, "profiles.yml")
|
||||||
with open(profile_path) as fp:
|
with open(profile_path) as fp:
|
||||||
profile_hash = FileHash.from_contents(fp.read())
|
profile_hash = FileHash.from_contents(fp.read())
|
||||||
|
|
||||||
project_hashes = {}
|
project_hashes = {}
|
||||||
for name, project in all_projects.items():
|
for name, project in all_projects.items():
|
||||||
path = os.path.join(project.project_root, 'dbt_project.yml')
|
path = os.path.join(project.project_root, "dbt_project.yml")
|
||||||
with open(path) as fp:
|
with open(path) as fp:
|
||||||
project_hashes[name] = FileHash.from_contents(fp.read())
|
project_hashes[name] = FileHash.from_contents(fp.read())
|
||||||
|
|
||||||
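
`make_parse_result` decides up front what a cached partial parse is keyed on: the CLI vars/profile/target plus the dbt version, `profiles.yml`, and every `dbt_project.yml`. A stripped-down version of that fingerprinting, with md5 hex digests standing in for dbt's `FileHash` objects:

import hashlib
from typing import Any, Dict, Iterable


def contents_hash(text: str) -> str:
    return hashlib.md5(text.encode("utf-8")).hexdigest()


def file_hash(path: str) -> str:
    with open(path) as fp:
        return contents_hash(fp.read())


def fingerprint(cli_vars: str, profile: str, target: str, version: str,
                profiles_yml: str, project_ymls: Iterable[str]) -> Dict[str, Any]:
    # a change in any of these inputs means the cached parse result can't be reused
    return {
        "vars_hash": contents_hash("\x00".join([cli_vars, profile, target, version])),
        "profile_hash": file_hash(profiles_yml),
        "project_hashes": {path: file_hash(path) for path in project_ymls},
    }
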
@@ -137,18 +151,22 @@ class ManifestLoader:
|
|||||||
self,
|
self,
|
||||||
root_project: RuntimeConfig,
|
root_project: RuntimeConfig,
|
||||||
all_projects: Mapping[str, Project],
|
all_projects: Mapping[str, Project],
|
||||||
macro_hook: Optional[Callable[[Manifest], Any]] = None,
|
macro_hook: Optional[Callable[[AnyManifest], Any]] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.root_project: RuntimeConfig = root_project
|
self.root_project: RuntimeConfig = root_project
|
||||||
self.all_projects: Mapping[str, Project] = all_projects
|
self.all_projects: Mapping[str, Project] = all_projects
|
||||||
self.macro_hook: Callable[[Manifest], Any]
|
self.macro_hook: Callable[[AnyManifest], Any]
|
||||||
if macro_hook is None:
|
if macro_hook is None:
|
||||||
self.macro_hook = lambda m: None
|
self.macro_hook = lambda m: None
|
||||||
else:
|
else:
|
||||||
self.macro_hook = macro_hook
|
self.macro_hook = macro_hook
|
||||||
|
|
||||||
|
# results holds all of the nodes created by parsing,
|
||||||
|
# in dictionaries: nodes, sources, docs, macros, exposures,
|
||||||
|
# macro_patches, patches, source_patches, files, etc
|
||||||
self.results: ParseResult = make_parse_result(
|
self.results: ParseResult = make_parse_result(
|
||||||
root_project, all_projects,
|
root_project,
|
||||||
|
all_projects,
|
||||||
)
|
)
|
||||||
self._loaded_file_cache: Dict[str, FileBlock] = {}
|
self._loaded_file_cache: Dict[str, FileBlock] = {}
|
||||||
self._perf_info = ManifestLoaderInfo(
|
self._perf_info = ManifestLoaderInfo(
|
||||||
@@ -157,20 +175,18 @@ class ManifestLoader:
|
|||||||
|
|
||||||
def track_project_load(self):
|
def track_project_load(self):
|
||||||
invocation_id = dbt.tracking.active_user.invocation_id
|
invocation_id = dbt.tracking.active_user.invocation_id
|
||||||
dbt.tracking.track_project_load({
|
dbt.tracking.track_project_load(
|
||||||
"invocation_id": invocation_id,
|
{
|
||||||
"project_id": self.root_project.hashed_name(),
|
"invocation_id": invocation_id,
|
||||||
"path_count": self._perf_info.path_count,
|
"project_id": self.root_project.hashed_name(),
|
||||||
"parse_project_elapsed": self._perf_info.parse_project_elapsed,
|
"path_count": self._perf_info.path_count,
|
||||||
"patch_sources_elapsed": self._perf_info.patch_sources_elapsed,
|
"parse_project_elapsed": self._perf_info.parse_project_elapsed,
|
||||||
"process_manifest_elapsed": (
|
"patch_sources_elapsed": self._perf_info.patch_sources_elapsed,
|
||||||
self._perf_info.process_manifest_elapsed
|
"process_manifest_elapsed": (self._perf_info.process_manifest_elapsed),
|
||||||
),
|
"load_all_elapsed": self._perf_info.load_all_elapsed,
|
||||||
"load_all_elapsed": self._perf_info.load_all_elapsed,
|
"is_partial_parse_enabled": (self._perf_info.is_partial_parse_enabled),
|
||||||
"is_partial_parse_enabled": (
|
}
|
||||||
self._perf_info.is_partial_parse_enabled
|
)
|
||||||
),
|
|
||||||
})
|
|
||||||
|
|
||||||
def parse_with_cache(
|
def parse_with_cache(
|
||||||
self,
|
self,
|
||||||
@@ -210,13 +226,12 @@ class ManifestLoader:
|
|||||||
def parse_project(
|
def parse_project(
|
||||||
self,
|
self,
|
||||||
project: Project,
|
project: Project,
|
||||||
macro_manifest: Manifest,
|
macro_manifest: MacroManifest,
|
||||||
old_results: Optional[ParseResult],
|
old_results: Optional[ParseResult],
|
||||||
) -> None:
|
) -> None:
|
||||||
parsers: List[Parser] = []
|
parsers: List[Parser] = []
|
||||||
for cls in _parser_types:
|
for cls in _parser_types:
|
||||||
parser = cls(self.results, project, self.root_project,
|
parser = cls(self.results, project, self.root_project, macro_manifest)
|
||||||
macro_manifest)
|
|
||||||
parsers.append(parser)
|
parsers.append(parser)
|
||||||
|
|
||||||
# per-project cache.
|
# per-project cache.
|
||||||
@@ -233,11 +248,13 @@ class ManifestLoader:
|
|||||||
parser_path_count = parser_path_count + 1
|
parser_path_count = parser_path_count + 1
|
||||||
|
|
||||||
if parser_path_count > 0:
|
if parser_path_count > 0:
|
||||||
project_parser_info.append(ParserInfo(
|
project_parser_info.append(
|
||||||
parser=parser.resource_type,
|
ParserInfo(
|
||||||
path_count=parser_path_count,
|
parser=parser.resource_type,
|
||||||
elapsed=time.perf_counter() - parser_start_timer
|
path_count=parser_path_count,
|
||||||
))
|
elapsed=time.perf_counter() - parser_start_timer,
|
||||||
|
)
|
||||||
|
)
|
||||||
total_path_count = total_path_count + parser_path_count
|
total_path_count = total_path_count + parser_path_count
|
||||||
|
|
||||||
elapsed = time.perf_counter() - start_timer
|
elapsed = time.perf_counter() - start_timer
|
||||||
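
The `ParserInfo` records assembled here are simple performance bookkeeping: time each parser with `perf_counter`, count the paths it handled, and keep the entry only if it did any work. In isolation the pattern is:

import time
from dataclasses import dataclass
from typing import Callable, Iterable, List


@dataclass
class ParserTiming:
    parser: str
    elapsed: float
    path_count: int = 0


def time_parsers(parsers: Iterable[Callable[[], int]]) -> List[ParserTiming]:
    timings: List[ParserTiming] = []
    for parse in parsers:
        start = time.perf_counter()
        path_count = parse()  # each parser reports how many paths it processed
        if path_count > 0:
            timings.append(
                ParserTiming(parse.__name__, time.perf_counter() - start, path_count)
            )
    return timings


def model_parser() -> int:
    return 3


print(time_parsers([model_parser]))
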
@@ -245,14 +262,12 @@ class ManifestLoader:
|
|||||||
project_name=project.project_name,
|
project_name=project.project_name,
|
||||||
path_count=total_path_count,
|
path_count=total_path_count,
|
||||||
elapsed=elapsed,
|
elapsed=elapsed,
|
||||||
parsers=project_parser_info
|
parsers=project_parser_info,
|
||||||
)
|
)
|
||||||
self._perf_info.projects.append(project_info)
|
self._perf_info.projects.append(project_info)
|
||||||
self._perf_info.path_count = (
|
self._perf_info.path_count = self._perf_info.path_count + total_path_count
|
||||||
self._perf_info.path_count + total_path_count
|
|
||||||
)
|
|
||||||
|
|
||||||
def load_only_macros(self) -> Manifest:
|
def load_only_macros(self) -> MacroManifest:
|
||||||
old_results = self.read_parse_results()
|
old_results = self.read_parse_results()
|
||||||
|
|
||||||
for project in self.all_projects.values():
|
for project in self.all_projects.values():
|
||||||
@@ -261,17 +276,19 @@ class ManifestLoader:
|
|||||||
self.parse_with_cache(path, parser, old_results)
|
self.parse_with_cache(path, parser, old_results)
|
||||||
|
|
||||||
# make a manifest with just the macros to get the context
|
# make a manifest with just the macros to get the context
|
||||||
macro_manifest = Manifest.from_macros(
|
macro_manifest = MacroManifest(
|
||||||
macros=self.results.macros,
|
macros=self.results.macros, files=self.results.files
|
||||||
files=self.results.files
|
|
||||||
)
|
)
|
||||||
self.macro_hook(macro_manifest)
|
self.macro_hook(macro_manifest)
|
||||||
return macro_manifest
|
return macro_manifest
|
||||||
|
|
||||||
def load(self, macro_manifest: Manifest):
|
# This is where the main action happens
|
||||||
|
def load(self, macro_manifest: MacroManifest):
|
||||||
|
# if partial parse is enabled, load old results
|
||||||
old_results = self.read_parse_results()
|
old_results = self.read_parse_results()
|
||||||
if old_results is not None:
|
if old_results is not None:
|
||||||
logger.debug('Got an acceptable cached parse result')
|
logger.debug("Got an acceptable cached parse result")
|
||||||
|
# store the macros & files from the adapter macro manifest
|
||||||
self.results.macros.update(macro_manifest.macros)
|
self.results.macros.update(macro_manifest.macros)
|
||||||
self.results.files.update(macro_manifest.files)
|
self.results.files.update(macro_manifest.files)
|
||||||
|
|
||||||
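
`load_only_macros` and `load` implement a two-phase parse: macros are parsed first into a macro-only manifest (handed to the adapter via `macro_hook`), and the full parse then starts from those macros plus any cached results. Schematically, under assumed names (`MacroOnlyManifest` and `load_full` are illustrative, not dbt's API):

from dataclasses import dataclass, field
from typing import Callable, Dict


@dataclass
class MacroOnlyManifest:
    macros: Dict[str, str] = field(default_factory=dict)
    files: Dict[str, str] = field(default_factory=dict)


def load_only_macros(macro_sources: Dict[str, str]) -> MacroOnlyManifest:
    # phase 1: parse just the macro files
    return MacroOnlyManifest(macros=dict(macro_sources))


def load_full(macro_manifest: MacroOnlyManifest,
              macro_hook: Callable[[MacroOnlyManifest], None]) -> Dict[str, str]:
    # phase 2: seed the full parse with the macros from phase 1,
    # letting the adapter see them via macro_hook first
    macro_hook(macro_manifest)
    results = dict(macro_manifest.macros)
    results["model.my_model"] = "select 1"  # real node parsing would happen here
    return results


macros = load_only_macros(
    {"macro.current_ts": "{% macro current_ts() %}now(){% endmacro %}"}
)
print(load_full(macros, macro_hook=lambda m: None))
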
@@ -281,15 +298,12 @@ class ManifestLoader:
|
|||||||
# parse a single project
|
# parse a single project
|
||||||
self.parse_project(project, macro_manifest, old_results)
|
self.parse_project(project, macro_manifest, old_results)
|
||||||
|
|
||||||
self._perf_info.parse_project_elapsed = (
|
self._perf_info.parse_project_elapsed = time.perf_counter() - start_timer
|
||||||
time.perf_counter() - start_timer
|
|
||||||
)
|
|
||||||
|
|
||||||
def write_parse_results(self):
|
def write_parse_results(self):
|
||||||
path = os.path.join(self.root_project.target_path,
|
path = os.path.join(self.root_project.target_path, PARTIAL_PARSE_FILE_NAME)
|
||||||
PARTIAL_PARSE_FILE_NAME)
|
|
||||||
make_directory(self.root_project.target_path)
|
make_directory(self.root_project.target_path)
|
||||||
with open(path, 'wb') as fp:
|
with open(path, "wb") as fp:
|
||||||
pickle.dump(self.results, fp)
|
pickle.dump(self.results, fp)
|
||||||
|
|
||||||
def matching_parse_results(self, result: ParseResult) -> bool:
|
def matching_parse_results(self, result: ParseResult) -> bool:
|
||||||
@@ -299,31 +313,32 @@ class ManifestLoader:
|
|||||||
try:
|
try:
|
||||||
if result.dbt_version != __version__:
|
if result.dbt_version != __version__:
|
||||||
logger.debug(
|
logger.debug(
|
||||||
'dbt version mismatch: {} != {}, cache invalidated'
|
"dbt version mismatch: {} != {}, cache invalidated".format(
|
||||||
.format(result.dbt_version, __version__)
|
result.dbt_version, __version__
|
||||||
|
)
|
||||||
)
|
)
|
||||||
return False
|
return False
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
logger.debug('malformed result file, cache invalidated')
|
logger.debug("malformed result file, cache invalidated")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
valid = True
|
valid = True
|
||||||
|
|
||||||
if self.results.vars_hash != result.vars_hash:
|
if self.results.vars_hash != result.vars_hash:
|
||||||
logger.debug('vars hash mismatch, cache invalidated')
|
logger.debug("vars hash mismatch, cache invalidated")
|
||||||
valid = False
|
valid = False
|
||||||
if self.results.profile_hash != result.profile_hash:
|
if self.results.profile_hash != result.profile_hash:
|
||||||
logger.debug('profile hash mismatch, cache invalidated')
|
logger.debug("profile hash mismatch, cache invalidated")
|
||||||
valid = False
|
valid = False
|
||||||
|
|
||||||
missing_keys = {
|
missing_keys = {
|
||||||
k for k in self.results.project_hashes
|
k for k in self.results.project_hashes if k not in result.project_hashes
|
||||||
if k not in result.project_hashes
|
|
||||||
}
|
}
|
||||||
if missing_keys:
|
if missing_keys:
|
||||||
logger.debug(
|
logger.debug(
|
||||||
'project hash mismatch: values missing, cache invalidated: {}'
|
"project hash mismatch: values missing, cache invalidated: {}".format(
|
||||||
.format(missing_keys)
|
missing_keys
|
||||||
|
)
|
||||||
)
|
)
|
||||||
valid = False
|
valid = False
|
||||||
|
|
||||||
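
`matching_parse_results` is a straight comparison of the fingerprints computed above against the ones stored with the pickled result, with the dbt version checked first. Reduced to a standalone decision function (the real method logs each mismatch and keeps checking rather than returning early):

from typing import Dict


def cache_is_valid(
    current_version: str,
    cached_version: str,
    current_hashes: Dict[str, str],
    cached_hashes: Dict[str, str],
) -> bool:
    # a dbt version bump always invalidates the cached parse result
    if cached_version != current_version:
        return False
    # every hash computed now (vars, profile, one per project) must match
    # the hash stored alongside the cached result
    for key, value in current_hashes.items():
        if cached_hashes.get(key) != value:
            return False
    return True


print(cache_is_valid(
    "0.19.0", "0.19.0",
    {"vars": "abc", "profile": "def", "my_project": "123"},
    {"vars": "abc", "profile": "def", "my_project": "123"},
))
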
@@ -332,9 +347,8 @@ class ManifestLoader:
|
|||||||
old_value = result.project_hashes[key]
|
old_value = result.project_hashes[key]
|
||||||
if new_value != old_value:
|
if new_value != old_value:
|
||||||
logger.debug(
|
logger.debug(
|
||||||
'For key {}, hash mismatch ({} -> {}), cache '
|
"For key {}, hash mismatch ({} -> {}), cache "
|
||||||
'invalidated'
|
"invalidated".format(key, old_value, new_value)
|
||||||
.format(key, old_value, new_value)
|
|
||||||
)
|
)
|
||||||
valid = False
|
valid = False
|
||||||
return valid
|
return valid
|
||||||
@@ -351,14 +365,13 @@ class ManifestLoader:
|
|||||||
|
|
||||||
def read_parse_results(self) -> Optional[ParseResult]:
|
def read_parse_results(self) -> Optional[ParseResult]:
|
||||||
if not self._partial_parse_enabled():
|
if not self._partial_parse_enabled():
|
||||||
logger.debug('Partial parsing not enabled')
|
logger.debug("Partial parsing not enabled")
|
||||||
return None
|
return None
|
||||||
path = os.path.join(self.root_project.target_path,
|
path = os.path.join(self.root_project.target_path, PARTIAL_PARSE_FILE_NAME)
|
||||||
PARTIAL_PARSE_FILE_NAME)
|
|
||||||
|
|
||||||
if os.path.exists(path):
|
if os.path.exists(path):
|
||||||
try:
|
try:
|
||||||
with open(path, 'rb') as fp:
|
with open(path, "rb") as fp:
|
||||||
result: ParseResult = pickle.load(fp)
|
result: ParseResult = pickle.load(fp)
|
||||||
# keep this check inside the try/except in case something about
|
# keep this check inside the try/except in case something about
|
||||||
# the file has changed in weird ways, perhaps due to being a
|
# the file has changed in weird ways, perhaps due to being a
|
||||||
@@ -367,9 +380,8 @@ class ManifestLoader:
|
|||||||
return result
|
return result
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.debug(
|
logger.debug(
|
||||||
'Failed to load parsed file from disk at {}: {}'
|
"Failed to load parsed file from disk at {}: {}".format(path, exc),
|
||||||
.format(path, exc),
|
exc_info=True,
|
||||||
exc_info=True
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return None
|
return None
|
||||||
@@ -386,9 +398,7 @@ class ManifestLoader:
|
|||||||
# list is created
|
# list is created
|
||||||
start_patch = time.perf_counter()
|
start_patch = time.perf_counter()
|
||||||
sources = patch_sources(self.results, self.root_project)
|
sources = patch_sources(self.results, self.root_project)
|
||||||
self._perf_info.patch_sources_elapsed = (
|
self._perf_info.patch_sources_elapsed = time.perf_counter() - start_patch
|
||||||
time.perf_counter() - start_patch
|
|
||||||
)
|
|
||||||
disabled = []
|
disabled = []
|
||||||
for value in self.results.disabled.values():
|
for value in self.results.disabled.values():
|
||||||
disabled.extend(value)
|
disabled.extend(value)
|
||||||
@@ -413,9 +423,7 @@ class ManifestLoader:
|
|||||||
start_process = time.perf_counter()
|
start_process = time.perf_counter()
|
||||||
self.process_manifest(manifest)
|
self.process_manifest(manifest)
|
||||||
|
|
||||||
self._perf_info.process_manifest_elapsed = (
|
self._perf_info.process_manifest_elapsed = time.perf_counter() - start_process
|
||||||
time.perf_counter() - start_process
|
|
||||||
)
|
|
||||||
|
|
||||||
return manifest
|
return manifest
|
||||||
|
|
||||||
@@ -423,8 +431,8 @@ class ManifestLoader:
|
|||||||
def load_all(
|
def load_all(
|
||||||
cls,
|
cls,
|
||||||
root_config: RuntimeConfig,
|
root_config: RuntimeConfig,
|
||||||
macro_manifest: Manifest,
|
macro_manifest: MacroManifest,
|
||||||
macro_hook: Callable[[Manifest], Any],
|
macro_hook: Callable[[AnyManifest], Any],
|
||||||
) -> Manifest:
|
) -> Manifest:
|
||||||
with PARSING_STATE:
|
with PARSING_STATE:
|
||||||
start_load_all = time.perf_counter()
|
start_load_all = time.perf_counter()
|
||||||
@@ -437,9 +445,7 @@ class ManifestLoader:
|
|||||||
_check_manifest(manifest, root_config)
|
_check_manifest(manifest, root_config)
|
||||||
manifest.build_flat_graph()
|
manifest.build_flat_graph()
|
||||||
|
|
||||||
loader._perf_info.load_all_elapsed = (
|
loader._perf_info.load_all_elapsed = time.perf_counter() - start_load_all
|
||||||
time.perf_counter() - start_load_all
|
|
||||||
)
|
|
||||||
|
|
||||||
loader.track_project_load()
|
loader.track_project_load()
|
||||||
|
|
||||||
@@ -449,16 +455,17 @@ class ManifestLoader:
|
|||||||
def load_macros(
|
def load_macros(
|
||||||
cls,
|
cls,
|
||||||
root_config: RuntimeConfig,
|
root_config: RuntimeConfig,
|
||||||
macro_hook: Callable[[Manifest], Any],
|
macro_hook: Callable[[AnyManifest], Any],
|
||||||
) -> Manifest:
|
) -> MacroManifest:
|
||||||
with PARSING_STATE:
|
with PARSING_STATE:
|
||||||
projects = root_config.load_dependencies()
|
projects = root_config.load_dependencies()
|
||||||
loader = cls(root_config, projects, macro_hook)
|
loader = cls(root_config, projects, macro_hook)
|
||||||
return loader.load_only_macros()
|
return loader.load_only_macros()
|
||||||
|
|
||||||
|
|
||||||
def invalid_ref_fail_unless_test(node, target_model_name,
|
def invalid_ref_fail_unless_test(
|
||||||
target_model_package, disabled):
|
node, target_model_name, target_model_package, disabled
|
||||||
|
):
|
||||||
|
|
||||||
if node.resource_type == NodeType.Test:
|
if node.resource_type == NodeType.Test:
|
||||||
msg = get_target_not_found_or_disabled_msg(
|
msg = get_target_not_found_or_disabled_msg(
|
||||||
@@ -467,10 +474,7 @@ def invalid_ref_fail_unless_test(node, target_model_name,
|
|||||||
if disabled:
|
if disabled:
|
||||||
logger.debug(warning_tag(msg))
|
logger.debug(warning_tag(msg))
|
||||||
else:
|
else:
|
||||||
warn_or_error(
|
warn_or_error(msg, log_fmt=warning_tag("{}"))
|
||||||
msg,
|
|
||||||
log_fmt=warning_tag('{}')
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
ref_target_not_found(
|
ref_target_not_found(
|
||||||
node,
|
node,
|
||||||
@@ -480,9 +484,7 @@ def invalid_ref_fail_unless_test(node, target_model_name,
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def invalid_source_fail_unless_test(
|
def invalid_source_fail_unless_test(node, target_name, target_table_name, disabled):
|
||||||
node, target_name, target_table_name, disabled
|
|
||||||
):
|
|
||||||
if node.resource_type == NodeType.Test:
|
if node.resource_type == NodeType.Test:
|
||||||
msg = get_source_not_found_or_disabled_msg(
|
msg = get_source_not_found_or_disabled_msg(
|
||||||
node, target_name, target_table_name, disabled
|
node, target_name, target_table_name, disabled
|
||||||
@@ -490,17 +492,9 @@ def invalid_source_fail_unless_test(
|
|||||||
if disabled:
|
if disabled:
|
||||||
logger.debug(warning_tag(msg))
|
logger.debug(warning_tag(msg))
|
||||||
else:
|
else:
|
||||||
warn_or_error(
|
warn_or_error(msg, log_fmt=warning_tag("{}"))
|
||||||
msg,
|
|
||||||
log_fmt=warning_tag('{}')
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
source_target_not_found(
|
source_target_not_found(node, target_name, target_table_name, disabled=disabled)
|
||||||
node,
|
|
||||||
target_name,
|
|
||||||
target_table_name,
|
|
||||||
disabled=disabled
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _check_resource_uniqueness(
|
def _check_resource_uniqueness(
|
||||||
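
`invalid_ref_fail_unless_test` and `invalid_source_fail_unless_test` encode one rule: when a `ref()` or `source()` target can't be resolved, test nodes only warn (noting whether the target was merely disabled), while every other node type raises. As a tiny illustrative decision function:

def on_missing_target(resource_type: str, target: str, disabled: bool) -> str:
    # tests degrade to a warning so schema tests on disabled models don't
    # break the run; anything else is a hard compilation error
    if resource_type == "test":
        reason = "disabled" if disabled else "not found"
        return f"WARNING: {target} was {reason}"
    raise RuntimeError(f"{target} was not found or is disabled")


print(on_missing_target("test", "ref('my_model')", disabled=True))
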
@@ -524,15 +518,11 @@ def _check_resource_uniqueness(
|
|||||||
|
|
||||||
existing_node = names_resources.get(name)
|
existing_node = names_resources.get(name)
|
||||||
if existing_node is not None:
|
if existing_node is not None:
|
||||||
dbt.exceptions.raise_duplicate_resource_name(
|
dbt.exceptions.raise_duplicate_resource_name(existing_node, node)
|
||||||
existing_node, node
|
|
||||||
)
|
|
||||||
|
|
||||||
existing_alias = alias_resources.get(full_node_name)
|
existing_alias = alias_resources.get(full_node_name)
|
||||||
if existing_alias is not None:
|
if existing_alias is not None:
|
||||||
dbt.exceptions.raise_ambiguous_alias(
|
dbt.exceptions.raise_ambiguous_alias(existing_alias, node, full_node_name)
|
||||||
existing_alias, node, full_node_name
|
|
||||||
)
|
|
||||||
|
|
||||||
names_resources[name] = node
|
names_resources[name] = node
|
||||||
alias_resources[full_node_name] = node
|
alias_resources[full_node_name] = node
|
||||||
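
`_check_resource_uniqueness` amounts to two dictionary lookups per node, one keyed on the resource name and one on the database/schema/alias combination, raising on the first collision. A compact sketch with simplified node data:

from typing import Dict, Tuple


def check_uniqueness(nodes: Dict[str, Tuple[str, str]]) -> None:
    """nodes maps unique_id -> (name, relation_name)."""
    by_name: Dict[str, str] = {}
    by_relation: Dict[str, str] = {}
    for unique_id, (name, relation) in nodes.items():
        if name in by_name:
            raise ValueError(
                f"duplicate resource name {name}: {by_name[name]} and {unique_id}"
            )
        if relation in by_relation:
            raise ValueError(
                f"ambiguous alias {relation}: {by_relation[relation]} and {unique_id}"
            )
        by_name[name] = unique_id
        by_relation[relation] = unique_id


check_uniqueness({
    "model.proj.a": ("a", "db.schema.a"),
    "model.proj.b": ("b", "db.schema.b"),
})
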
@@ -557,8 +547,7 @@ def _load_projects(config, paths):
|
|||||||
project = config.new_project(path)
|
project = config.new_project(path)
|
||||||
except dbt.exceptions.DbtProjectError as e:
|
except dbt.exceptions.DbtProjectError as e:
|
||||||
raise dbt.exceptions.DbtProjectError(
|
raise dbt.exceptions.DbtProjectError(
|
||||||
'Failed to read package at {}: {}'
|
"Failed to read package at {}: {}".format(path, e)
|
||||||
.format(path, e)
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
yield project.project_name, project
|
yield project.project_name, project
|
||||||
@@ -579,8 +568,7 @@ def _get_node_column(node, column_name):
|
|||||||
|
|
||||||
|
|
||||||
DocsContextCallback = Callable[
|
DocsContextCallback = Callable[
|
||||||
[Union[ParsedNode, ParsedSourceDefinition]],
|
[Union[ParsedNode, ParsedSourceDefinition]], Dict[str, Any]
|
||||||
Dict[str, Any]
|
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@@ -610,9 +598,7 @@ def _process_docs_for_source(
|
|||||||
column.description = column_desc
|
column.description = column_desc
|
||||||
|
|
||||||
|
|
||||||
def _process_docs_for_macro(
|
def _process_docs_for_macro(context: Dict[str, Any], macro: ParsedMacro) -> None:
|
||||||
context: Dict[str, Any], macro: ParsedMacro
|
|
||||||
) -> None:
|
|
||||||
macro.description = get_rendered(macro.description, context)
|
macro.description = get_rendered(macro.description, context)
|
||||||
for arg in macro.arguments:
|
for arg in macro.arguments:
|
||||||
arg.description = get_rendered(arg.description, context)
|
arg.description = get_rendered(arg.description, context)
|
||||||
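
`_process_docs_for_macro` pushes the macro's description, and each argument's description, through the Jinja renderer so that `{{ doc(...) }}` references resolve against parsed `{% docs %}` blocks. A small standalone illustration with a hypothetical `doc` lookup table:

import jinja2

DOC_BLOCKS = {"orders": "One row per order."}


def render_description(description: str) -> str:
    # the real context exposes doc() backed by parsed {% docs %} blocks
    env = jinja2.Environment()
    return env.from_string(description).render(doc=lambda name: DOC_BLOCKS[name])


print(render_description("{{ doc('orders') }} Loaded nightly."))
# One row per order. Loaded nightly.
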
@@ -674,7 +660,7 @@ def _process_refs_for_exposure(
|
|||||||
target_model_package, target_model_name = ref
|
target_model_package, target_model_name = ref
|
||||||
else:
|
else:
|
||||||
raise dbt.exceptions.InternalException(
|
raise dbt.exceptions.InternalException(
|
||||||
f'Refs should always be 1 or 2 arguments - got {len(ref)}'
|
f"Refs should always be 1 or 2 arguments - got {len(ref)}"
|
||||||
)
|
)
|
||||||
|
|
||||||
target_model = manifest.resolve_ref(
|
target_model = manifest.resolve_ref(
|
||||||
@@ -688,8 +674,10 @@ def _process_refs_for_exposure(
|
|||||||
# This may raise. Even if it doesn't, we don't want to add
|
# This may raise. Even if it doesn't, we don't want to add
|
||||||
# this exposure to the graph b/c there is no destination exposure
|
# this exposure to the graph b/c there is no destination exposure
|
||||||
invalid_ref_fail_unless_test(
|
invalid_ref_fail_unless_test(
|
||||||
exposure, target_model_name, target_model_package,
|
exposure,
|
||||||
disabled=(isinstance(target_model, Disabled))
|
target_model_name,
|
||||||
|
target_model_package,
|
||||||
|
disabled=(isinstance(target_model, Disabled)),
|
||||||
)
|
)
|
||||||
|
|
||||||
continue
|
continue
|
||||||
@@ -715,7 +703,7 @@ def _process_refs_for_node(
|
|||||||
target_model_package, target_model_name = ref
|
target_model_package, target_model_name = ref
|
||||||
else:
|
else:
|
||||||
raise dbt.exceptions.InternalException(
|
raise dbt.exceptions.InternalException(
|
||||||
f'Refs should always be 1 or 2 arguments - got {len(ref)}'
|
f"Refs should always be 1 or 2 arguments - got {len(ref)}"
|
||||||
)
|
)
|
||||||
|
|
||||||
target_model = manifest.resolve_ref(
|
target_model = manifest.resolve_ref(
|
||||||
@@ -730,8 +718,10 @@ def _process_refs_for_node(
|
|||||||
# this node to the graph b/c there is no destination node
|
# this node to the graph b/c there is no destination node
|
||||||
node.config.enabled = False
|
node.config.enabled = False
|
||||||
invalid_ref_fail_unless_test(
|
invalid_ref_fail_unless_test(
|
||||||
node, target_model_name, target_model_package,
|
node,
|
||||||
disabled=(isinstance(target_model, Disabled))
|
target_model_name,
|
||||||
|
target_model_package,
|
||||||
|
disabled=(isinstance(target_model, Disabled)),
|
||||||
)
|
)
|
||||||
|
|
||||||
continue
|
continue
|
||||||
@@ -769,7 +759,7 @@ def _process_sources_for_exposure(
|
|||||||
exposure,
|
exposure,
|
||||||
source_name,
|
source_name,
|
||||||
table_name,
|
table_name,
|
||||||
disabled=(isinstance(target_source, Disabled))
|
disabled=(isinstance(target_source, Disabled)),
|
||||||
)
|
)
|
||||||
continue
|
continue
|
||||||
target_source_id = target_source.unique_id
|
target_source_id = target_source.unique_id
|
||||||
@@ -796,7 +786,7 @@ def _process_sources_for_node(
|
|||||||
node,
|
node,
|
||||||
source_name,
|
source_name,
|
||||||
table_name,
|
table_name,
|
||||||
disabled=(isinstance(target_source, Disabled))
|
disabled=(isinstance(target_source, Disabled)),
|
||||||
)
|
)
|
||||||
continue
|
continue
|
||||||
target_source_id = target_source.unique_id
|
target_source_id = target_source.unique_id
|
||||||
@@ -827,13 +817,9 @@ def process_macro(
|
|||||||
_process_docs_for_macro(ctx, macro)
|
_process_docs_for_macro(ctx, macro)
|
||||||
|
|
||||||
|
|
||||||
def process_node(
|
def process_node(config: RuntimeConfig, manifest: Manifest, node: ManifestNode):
|
||||||
config: RuntimeConfig, manifest: Manifest, node: ManifestNode
|
|
||||||
):
|
|
||||||
|
|
||||||
_process_sources_for_node(
|
_process_sources_for_node(manifest, config.project_name, node)
|
||||||
manifest, config.project_name, node
|
|
||||||
)
|
|
||||||
_process_refs_for_node(manifest, config.project_name, node)
|
_process_refs_for_node(manifest, config.project_name, node)
|
||||||
ctx = generate_runtime_docs(config, node, manifest, config.project_name)
|
ctx = generate_runtime_docs(config, node, manifest, config.project_name)
|
||||||
_process_docs_for_node(ctx, node)
|
_process_docs_for_node(ctx, node)
|
||||||
@@ -841,14 +827,14 @@ def process_node(
|
|||||||
|
|
||||||
def load_macro_manifest(
|
def load_macro_manifest(
|
||||||
config: RuntimeConfig,
|
config: RuntimeConfig,
|
||||||
macro_hook: Callable[[Manifest], Any],
|
macro_hook: Callable[[AnyManifest], Any],
|
||||||
) -> Manifest:
|
) -> MacroManifest:
|
||||||
return ManifestLoader.load_macros(config, macro_hook)
|
return ManifestLoader.load_macros(config, macro_hook)
|
||||||
|
|
||||||
|
|
||||||
def load_manifest(
|
def load_manifest(
|
||||||
config: RuntimeConfig,
|
config: RuntimeConfig,
|
||||||
macro_manifest: Manifest,
|
macro_manifest: MacroManifest,
|
||||||
macro_hook: Callable[[Manifest], Any],
|
macro_hook: Callable[[AnyManifest], Any],
|
||||||
) -> Manifest:
|
) -> Manifest:
|
||||||
return ManifestLoader.load_all(config, macro_manifest, macro_hook)
|
return ManifestLoader.load_all(config, macro_manifest, macro_hook)
|
||||||
|
|||||||
@@ -6,12 +6,12 @@ from dbt.parser.search import FilesystemSearcher, FileBlock
|
|||||||
|
|
||||||
class ModelParser(SimpleSQLParser[ParsedModelNode]):
|
class ModelParser(SimpleSQLParser[ParsedModelNode]):
|
||||||
def get_paths(self):
|
def get_paths(self):
|
||||||
return FilesystemSearcher(
|
return FilesystemSearcher(self.project, self.project.source_paths, ".sql")
|
||||||
self.project, self.project.source_paths, '.sql'
|
|
||||||
)
|
|
||||||
|
|
||||||
def parse_from_dict(self, dct, validate=True) -> ParsedModelNode:
|
def parse_from_dict(self, dct, validate=True) -> ParsedModelNode:
|
||||||
return ParsedModelNode.from_dict(dct, validate=validate)
|
if validate:
|
||||||
|
ParsedModelNode.validate(dct)
|
||||||
|
return ParsedModelNode.from_dict(dct)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def resource_type(self) -> NodeType:
|
def resource_type(self) -> NodeType:
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import TypeVar, MutableMapping, Mapping, Union, List
|
from typing import TypeVar, MutableMapping, Mapping, Union, List
|
||||||
|
|
||||||
from hologram import JsonSchemaMixin
|
from dbt.dataclass_schema import dbtClassMixin
|
||||||
|
|
||||||
from dbt.contracts.files import RemoteFile, FileHash, SourceFile
|
from dbt.contracts.files import RemoteFile, FileHash, SourceFile
|
||||||
from dbt.contracts.graph.compiled import CompileResultNode
|
from dbt.contracts.graph.compiled import CompileResultNode
|
||||||
@@ -25,9 +25,13 @@ from dbt.contracts.graph.parsed import (
|
|||||||
from dbt.contracts.graph.unparsed import SourcePatch
|
from dbt.contracts.graph.unparsed import SourcePatch
|
||||||
from dbt.contracts.util import Writable, Replaceable, MacroKey, SourceKey
|
from dbt.contracts.util import Writable, Replaceable, MacroKey, SourceKey
|
||||||
from dbt.exceptions import (
|
from dbt.exceptions import (
|
||||||
raise_duplicate_resource_name, raise_duplicate_patch_name,
|
raise_duplicate_resource_name,
|
||||||
raise_duplicate_macro_patch_name, CompilationException, InternalException,
|
raise_duplicate_patch_name,
|
||||||
raise_compiler_error, raise_duplicate_source_patch_name
|
raise_duplicate_macro_patch_name,
|
||||||
|
CompilationException,
|
||||||
|
InternalException,
|
||||||
|
raise_compiler_error,
|
||||||
|
raise_duplicate_source_patch_name,
|
||||||
)
|
)
|
||||||
from dbt.node_types import NodeType
|
from dbt.node_types import NodeType
|
||||||
from dbt.ui import line_wrap_message
|
from dbt.ui import line_wrap_message
|
||||||
@@ -35,12 +39,10 @@ from dbt.version import __version__
|
|||||||
|
|
||||||
|
|
||||||
# Parsers can return anything as long as it's a unique ID
|
# Parsers can return anything as long as it's a unique ID
|
||||||
ParsedValueType = TypeVar('ParsedValueType', bound=HasUniqueID)
|
ParsedValueType = TypeVar("ParsedValueType", bound=HasUniqueID)
|
||||||
|
|
||||||
|
|
||||||
def _check_duplicates(
|
def _check_duplicates(value: HasUniqueID, src: Mapping[str, HasUniqueID]):
|
||||||
value: HasUniqueID, src: Mapping[str, HasUniqueID]
|
|
||||||
):
|
|
||||||
if value.unique_id in src:
|
if value.unique_id in src:
|
||||||
raise_duplicate_resource_name(value, src[value.unique_id])
|
raise_duplicate_resource_name(value, src[value.unique_id])
|
||||||
|
|
||||||
@@ -62,7 +64,7 @@ def dict_field():
|
|||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ParseResult(JsonSchemaMixin, Writable, Replaceable):
|
class ParseResult(dbtClassMixin, Writable, Replaceable):
|
||||||
vars_hash: FileHash
|
vars_hash: FileHash
|
||||||
profile_hash: FileHash
|
profile_hash: FileHash
|
||||||
project_hashes: MutableMapping[str, FileHash]
|
project_hashes: MutableMapping[str, FileHash]
|
||||||
@@ -86,9 +88,7 @@ class ParseResult(JsonSchemaMixin, Writable, Replaceable):
|
|||||||
self.files[key] = source_file
|
self.files[key] = source_file
|
||||||
return self.files[key]
|
return self.files[key]
|
||||||
|
|
||||||
def add_source(
|
def add_source(self, source_file: SourceFile, source: UnpatchedSourceDefinition):
|
||||||
self, source_file: SourceFile, source: UnpatchedSourceDefinition
|
|
||||||
):
|
|
||||||
# sources can't be overwritten!
|
# sources can't be overwritten!
|
||||||
_check_duplicates(source, self.sources)
|
_check_duplicates(source, self.sources)
|
||||||
self.sources[source.unique_id] = source
|
self.sources[source.unique_id] = source
|
||||||
@@ -126,7 +126,7 @@ class ParseResult(JsonSchemaMixin, Writable, Replaceable):
|
|||||||
# note that the line wrap eats newlines, so if you want newlines,
|
# note that the line wrap eats newlines, so if you want newlines,
|
||||||
# this is the result :(
|
# this is the result :(
|
||||||
msg = line_wrap_message(
|
msg = line_wrap_message(
|
||||||
f'''\
|
f"""\
|
||||||
dbt found two macros named "{macro.name}" in the project
|
dbt found two macros named "{macro.name}" in the project
|
||||||
"{macro.package_name}".
|
"{macro.package_name}".
|
||||||
|
|
||||||
@@ -137,8 +137,8 @@ class ParseResult(JsonSchemaMixin, Writable, Replaceable):
|
|||||||
- {macro.original_file_path}
|
- {macro.original_file_path}
|
||||||
|
|
||||||
- {other_path}
|
- {other_path}
|
||||||
''',
|
""",
|
||||||
subtract=2
|
subtract=2,
|
||||||
)
|
)
|
||||||
raise_compiler_error(msg)
|
raise_compiler_error(msg)
|
||||||
|
|
||||||
@@ -150,18 +150,14 @@ class ParseResult(JsonSchemaMixin, Writable, Replaceable):
|
|||||||
self.docs[doc.unique_id] = doc
|
self.docs[doc.unique_id] = doc
|
||||||
self.get_file(source_file).docs.append(doc.unique_id)
|
self.get_file(source_file).docs.append(doc.unique_id)
|
||||||
|
|
||||||
def add_patch(
|
def add_patch(self, source_file: SourceFile, patch: ParsedNodePatch) -> None:
|
||||||
self, source_file: SourceFile, patch: ParsedNodePatch
|
|
||||||
) -> None:
|
|
||||||
# patches can't be overwritten
|
# patches can't be overwritten
|
||||||
if patch.name in self.patches:
|
if patch.name in self.patches:
|
||||||
raise_duplicate_patch_name(patch, self.patches[patch.name])
|
raise_duplicate_patch_name(patch, self.patches[patch.name])
|
||||||
self.patches[patch.name] = patch
|
self.patches[patch.name] = patch
|
||||||
self.get_file(source_file).patches.append(patch.name)
|
self.get_file(source_file).patches.append(patch.name)
|
||||||
|
|
||||||
def add_macro_patch(
|
def add_macro_patch(self, source_file: SourceFile, patch: ParsedMacroPatch) -> None:
|
||||||
self, source_file: SourceFile, patch: ParsedMacroPatch
|
|
||||||
) -> None:
|
|
||||||
# macros are fully namespaced
|
# macros are fully namespaced
|
||||||
key = (patch.package_name, patch.name)
|
key = (patch.package_name, patch.name)
|
||||||
if key in self.macro_patches:
|
if key in self.macro_patches:
|
||||||
@@ -169,9 +165,7 @@ class ParseResult(JsonSchemaMixin, Writable, Replaceable):
|
|||||||
self.macro_patches[key] = patch
|
self.macro_patches[key] = patch
|
||||||
self.get_file(source_file).macro_patches.append(key)
|
self.get_file(source_file).macro_patches.append(key)
|
||||||
|
|
||||||
def add_source_patch(
|
def add_source_patch(self, source_file: SourceFile, patch: SourcePatch) -> None:
|
||||||
self, source_file: SourceFile, patch: SourcePatch
|
|
||||||
) -> None:
|
|
||||||
# source patches must be unique
|
# source patches must be unique
|
||||||
key = (patch.overrides, patch.name)
|
key = (patch.overrides, patch.name)
|
||||||
if key in self.source_patches:
|
if key in self.source_patches:
|
||||||
@@ -186,11 +180,13 @@ class ParseResult(JsonSchemaMixin, Writable, Replaceable):
|
|||||||
) -> List[CompileResultNode]:
|
) -> List[CompileResultNode]:
|
||||||
if unique_id not in self.disabled:
|
if unique_id not in self.disabled:
|
||||||
raise InternalException(
|
raise InternalException(
|
||||||
'called _get_disabled with id={}, but it does not exist'
|
"called _get_disabled with id={}, but it does not exist".format(
|
||||||
.format(unique_id)
|
unique_id
|
||||||
|
)
|
||||||
)
|
)
|
||||||
return [
|
return [
|
||||||
n for n in self.disabled[unique_id]
|
n
|
||||||
|
for n in self.disabled[unique_id]
|
||||||
if n.original_file_path == match_file.path.original_file_path
|
if n.original_file_path == match_file.path.original_file_path
|
||||||
]
|
]
|
||||||
|
|
||||||
@@ -199,7 +195,7 @@ class ParseResult(JsonSchemaMixin, Writable, Replaceable):
|
|||||||
node_id: str,
|
node_id: str,
|
||||||
source_file: SourceFile,
|
source_file: SourceFile,
|
||||||
old_file: SourceFile,
|
old_file: SourceFile,
|
||||||
old_result: 'ParseResult',
|
old_result: "ParseResult",
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Nodes are a special kind of complicated - there can be multiple
|
"""Nodes are a special kind of complicated - there can be multiple
|
||||||
with the same name, as long as all but one are disabled.
|
with the same name, as long as all but one are disabled.
|
||||||
@@ -224,14 +220,15 @@ class ParseResult(JsonSchemaMixin, Writable, Replaceable):
|
|||||||
if not found:
|
if not found:
|
||||||
raise CompilationException(
|
raise CompilationException(
|
||||||
'Expected to find "{}" in cached "manifest.nodes" or '
|
'Expected to find "{}" in cached "manifest.nodes" or '
|
||||||
'"manifest.disabled" based on cached file information: {}!'
|
'"manifest.disabled" based on cached file information: {}!'.format(
|
||||||
.format(node_id, old_file)
|
node_id, old_file
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
def sanitized_update(
|
def sanitized_update(
|
||||||
self,
|
self,
|
||||||
source_file: SourceFile,
|
source_file: SourceFile,
|
||||||
old_result: 'ParseResult',
|
old_result: "ParseResult",
|
||||||
resource_type: NodeType,
|
resource_type: NodeType,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""Perform a santized update. If the file can't be updated, invalidate
|
"""Perform a santized update. If the file can't be updated, invalidate
|
||||||
@@ -246,15 +243,11 @@ class ParseResult(JsonSchemaMixin, Writable, Replaceable):
|
|||||||
self.add_doc(source_file, doc)
|
self.add_doc(source_file, doc)
|
||||||
|
|
||||||
for macro_id in old_file.macros:
|
for macro_id in old_file.macros:
|
||||||
macro = _expect_value(
|
macro = _expect_value(macro_id, old_result.macros, old_file, "macros")
|
||||||
macro_id, old_result.macros, old_file, "macros"
|
|
||||||
)
|
|
||||||
self.add_macro(source_file, macro)
|
self.add_macro(source_file, macro)
|
||||||
|
|
||||||
for source_id in old_file.sources:
|
for source_id in old_file.sources:
|
||||||
source = _expect_value(
|
source = _expect_value(source_id, old_result.sources, old_file, "sources")
|
||||||
source_id, old_result.sources, old_file, "sources"
|
|
||||||
)
|
|
||||||
self.add_source(source_file, source)
|
self.add_source(source_file, source)
|
||||||
|
|
||||||
# because we know this is how we _parsed_ the node, we can safely
|
# because we know this is how we _parsed_ the node, we can safely
|
||||||
@@ -265,7 +258,7 @@ class ParseResult(JsonSchemaMixin, Writable, Replaceable):
|
|||||||
for node_id in old_file.nodes:
|
for node_id in old_file.nodes:
|
||||||
# cheat: look at the first part of the node ID and compare it to
|
# cheat: look at the first part of the node ID and compare it to
|
||||||
# the parser resource type. On a mismatch, bail out.
|
# the parser resource type. On a mismatch, bail out.
|
||||||
if resource_type != node_id.split('.')[0]:
|
if resource_type != node_id.split(".")[0]:
|
||||||
continue
|
continue
|
||||||
self._process_node(node_id, source_file, old_file, old_result)
|
self._process_node(node_id, source_file, old_file, old_result)
|
||||||
|
|
||||||
@@ -277,9 +270,7 @@ class ParseResult(JsonSchemaMixin, Writable, Replaceable):
|
|||||||
|
|
||||||
patched = False
|
patched = False
|
||||||
for name in old_file.patches:
|
for name in old_file.patches:
|
||||||
patch = _expect_value(
|
patch = _expect_value(name, old_result.patches, old_file, "patches")
|
||||||
name, old_result.patches, old_file, "patches"
|
|
||||||
)
|
|
||||||
self.add_patch(source_file, patch)
|
self.add_patch(source_file, patch)
|
||||||
patched = True
|
patched = True
|
||||||
if patched:
|
if patched:
|
||||||
@@ -312,8 +303,8 @@ class ParseResult(JsonSchemaMixin, Writable, Replaceable):
|
|||||||
return cls(FileHash.empty(), FileHash.empty(), {})
|
return cls(FileHash.empty(), FileHash.empty(), {})
|
||||||
|
|
||||||
|
|
||||||
K_T = TypeVar('K_T')
|
K_T = TypeVar("K_T")
|
||||||
V_T = TypeVar('V_T')
|
V_T = TypeVar("V_T")
|
||||||
|
|
||||||
|
|
||||||
def _expect_value(
|
def _expect_value(
|
||||||
@@ -322,7 +313,6 @@ def _expect_value(
|
|||||||
if key not in src:
|
if key not in src:
|
||||||
raise CompilationException(
|
raise CompilationException(
|
||||||
'Expected to find "{}" in cached "result.{}" based '
|
'Expected to find "{}" in cached "result.{}" based '
|
||||||
'on cached file information: {}!'
|
"on cached file information: {}!".format(key, name, old_file)
|
||||||
.format(key, name, old_file)
|
|
||||||
)
|
)
|
||||||
return src[key]
|
return src[key]
|
||||||
|
|||||||
@@ -26,7 +26,9 @@ class RPCCallParser(SimpleSQLParser[ParsedRPCNode]):
- return ParsedRPCNode.from_dict(dct, validate=validate)
+ if validate:
+ ParsedRPCNode.validate(dct)
+ return ParsedRPCNode.from_dict(dct)
@@ -36,11 +38,11 @@ class RPCCallParser(SimpleSQLParser[ParsedRPCNode]):
- 'While parsing RPC calls, got an actual file block instead of '
- 'an RPC block: {}'.format(block)
+ "While parsing RPC calls, got an actual file block instead of "
+ "an RPC block: {}".format(block)
- return os.path.join('rpc', block.name)
+ return os.path.join("rpc", block.name)
@@ -51,8 +53,8 @@ class RPCCallParser(SimpleSQLParser[ParsedRPCNode]):
- path='from remote system',
- original_file_path='from remote system',
+ path="from remote system",
+ original_file_path="from remote system",
@@ -3,7 +3,13 @@ import re
- Generic, TypeVar, Dict, Any, Tuple, Optional, List,
+ Generic,
+ TypeVar,
+ Dict,
+ Any,
+ Tuple,
+ Optional,
+ List,
@@ -25,7 +31,7 @@ def get_nice_schema_test_name(
- if arg_name == 'model':
+ if arg_name == "model":
@@ -38,17 +44,17 @@ def get_nice_schema_test_name(
- clean_flat_args = [re.sub('[^0-9a-zA-Z_]+', '_', arg) for arg in flat_args]
+ clean_flat_args = [re.sub("[^0-9a-zA-Z_]+", "_", arg) for arg in flat_args]
- label = hashlib.md5(unique.encode('utf-8')).hexdigest()
+ label = hashlib.md5(unique.encode("utf-8")).hexdigest()
- filename = '{}_{}_{}'.format(test_type, test_name, label)
- name = '{}_{}_{}'.format(test_type, test_name, unique)
+ filename = "{}_{}_{}".format(test_type, test_name, label)
+ name = "{}_{}_{}".format(test_type, test_name, unique)
@@ -65,19 +71,17 @@ class YamlBlock(FileBlock):
- Testable = TypeVar(
- 'Testable', UnparsedNodeUpdate, UnpatchedSourceDefinition
- )
+ Testable = TypeVar("Testable", UnparsedNodeUpdate, UnpatchedSourceDefinition)
- 'ColumnTarget',
+ "ColumnTarget",
- 'Target',
+ "Target",
@@ -103,9 +107,7 @@ class TargetBlock(YamlBlock, Generic[Target]):
- def from_yaml_block(
- cls, src: YamlBlock, target: Target
- ) -> 'TargetBlock[Target]':
+ def from_yaml_block(cls, src: YamlBlock, target: Target) -> "TargetBlock[Target]":
@@ -137,9 +139,7 @@ class TestBlock(TargetColumnsBlock[Testable], Generic[Testable]):
- def from_yaml_block(
- cls, src: YamlBlock, target: Testable
- ) -> 'TestBlock[Testable]':
+ def from_yaml_block(cls, src: YamlBlock, target: Testable) -> "TestBlock[Testable]":
@@ -160,7 +160,7 @@ class SchemaTestBlock(TestBlock[Testable], Generic[Testable]):
- ) -> 'SchemaTestBlock':
+ ) -> "SchemaTestBlock":
@@ -179,12 +179,14 @@ class TestBuilder(Generic[Testable]):
+ # The 'test_name' is used to find the 'macro' that implements the test
- r'((?P<test_namespace>([a-zA-Z_][0-9a-zA-Z_]*))\.)?'
- r'(?P<test_name>([a-zA-Z_][0-9a-zA-Z_]*))'
+ r"((?P<test_namespace>([a-zA-Z_][0-9a-zA-Z_]*))\.)?"
+ r"(?P<test_name>([a-zA-Z_][0-9a-zA-Z_]*))"
- MODIFIER_ARGS = {'severity': 'ERROR', 'tags': []}
+ MODIFIER_ARGS = {"severity": "ERROR", "tags": []}
@@ -196,25 +198,24 @@ class TestBuilder(Generic[Testable]):
- if 'model' in self.args:
+ if "model" in self.args:
- self.args['model'] = self.build_model_str()
+ self.args["model"] = self.build_model_str()
- 'Test name string did not match expected pattern: {}'
- .format(test_name)
+ "Test name string did not match expected pattern: {}".format(test_name)
- self.name: str = groups['test_name']
- self.namespace: str = groups['test_namespace']
+ self.name: str = groups["test_name"]
+ self.namespace: str = groups["test_namespace"]
@@ -236,57 +237,52 @@ class TestBuilder(Generic[Testable]):
- 'test must be dict or str, got {} (value {})'.format(
- type(test), test
- )
+ "test must be dict or str, got {} (value {})".format(type(test), test)
- 'test definition dictionary must have exactly one key, got'
- ' {} instead ({} keys)'.format(test, len(test))
+ "test definition dictionary must have exactly one key, got"
+ " {} instead ({} keys)".format(test, len(test))
- 'test arguments must be dict, got {} (value {})'.format(
+ "test arguments must be dict, got {} (value {})".format(
- 'test name must be a str, got {} (value {})'.format(
+ "test name must be a str, got {} (value {})".format(
- test_args['column_name'] = name
+ test_args["column_name"] = name
- return self.modifiers.get('severity', 'ERROR').upper()
+ return self.modifiers.get("severity", "ERROR").upper()
- tags = self.modifiers.get('tags', [])
+ tags = self.modifiers.get("tags", [])
- f'got {tags} ({type(tags)}) for tags, expected a list of '
- f'strings'
+ f"got {tags} ({type(tags)}) for tags, expected a list of " f"strings"
- raise_compiler_error(
- f'got {tag} ({type(tag)}) for tag, expected a str'
- )
+ raise_compiler_error(f"got {tag} ({type(tag)}) for tag, expected a str")
- macro_name = 'test_{}'.format(self.name)
+ macro_name = "test_{}".format(self.name)
@@ -295,13 +291,15 @@ class TestBuilder(Generic[Testable]):
- name = 'source_' + self.name
+ name = "source_" + self.name
- name = '{}_{}'.format(self.namespace, name)
+ name = "{}_{}".format(self.namespace, name)
+ # this is the 'raw_sql' that's used in 'render_update' and execution
+ # of the test macro
@@ -2,13 +2,11 @@ import itertools
- from typing import (
- Iterable, Dict, Any, Union, List, Optional, Generic, TypeVar, Type
- )
+ from typing import Iterable, Dict, Any, Union, List, Optional, Generic, TypeVar, Type
- from hologram import ValidationError, JsonSchemaMixin
+ from dbt.dataclass_schema import ValidationError, dbtClassMixin
- from dbt.adapters.factory import get_adapter
+ from dbt.adapters.factory import get_adapter, get_adapter_package_names
@@ -20,7 +18,8 @@ from dbt.context.context_config import (
- from dbt.context.providers import generate_parse_exposure
+ from dbt.context.providers import generate_parse_exposure, generate_test_context
+ from dbt.context.macro_resolver import MacroResolver
@@ -47,20 +46,26 @@ from dbt.contracts.graph.unparsed import (
- validator_error_message, JSONValidationException,
- raise_invalid_schema_yml_version, ValidationException,
- CompilationException, warn_or_error, InternalException
+ validator_error_message,
+ JSONValidationException,
+ raise_invalid_schema_yml_version,
+ ValidationException,
+ CompilationException,
+ warn_or_error,
+ InternalException,
- TestBuilder, SchemaTestBlock, TargetBlock, YamlBlock,
- TestBlock, Testable
- )
- from dbt.utils import (
- get_pseudo_test_path, coerce_dict_str
- )
+ TestBuilder,
+ SchemaTestBlock,
+ TargetBlock,
+ YamlBlock,
+ TestBlock,
+ Testable,
+ )
+ from dbt.utils import get_pseudo_test_path, coerce_dict_str
@@ -77,19 +82,17 @@ def error_context(
- cause: Union[str, ValidationException, JSONValidationException]
+ cause: Union[str, ValidationException, JSONValidationException],
- """Provide contextual information about an error while parsing
- """
+ """Provide contextual information about an error while parsing"""
- return (
- 'Invalid {key} config given in {path} @ {key}: {data} - {reason}'
- .format(key=key, path=path, data=data, reason=reason)
+ return "Invalid {key} config given in {path} @ {key}: {data} - {reason}".format(
+ key=key, path=path, data=data, reason=reason
@@ -107,7 +110,7 @@ class ParserRef:
- tags.extend(getattr(column, 'tags', ()))
+ tags.extend(getattr(column, "tags", ()))
@@ -120,13 +123,11 @@ class ParserRef:
- _extra=column.extra
+ _extra=column.extra,
- def from_target(
- cls, target: Union[HasColumnDocs, HasColumnTests]
- ) -> 'ParserRef':
+ def from_target(cls, target: Union[HasColumnDocs, HasColumnTests]) -> "ParserRef":
@@ -139,7 +140,7 @@ class ParserRef:
- return inp[:44] + '...' + inp[-3:]
+ return inp[:44] + "..." + inp[-3:]
@@ -155,24 +156,32 @@ def merge_freshness(
- self, results, project, root_project, macro_manifest,
+ self,
+ results,
+ project,
+ root_project,
+ macro_manifest,
- self.root_project.config_version == 2 and
- self.project.config_version == 2
+ self.root_project.config_version == 2 and self.project.config_version == 2
- ctx = generate_schema_yml(
- self.root_project, self.project.project_name
- )
+ ctx = generate_schema_yml(self.root_project, self.project.project_name)
- ctx = generate_target_context(
- self.root_project, self.root_project.cli_vars
- )
+ ctx = generate_target_context(self.root_project, self.root_project.cli_vars)
+ internal_package_names = get_adapter_package_names(
+ self.root_project.credentials.type
+ )
+ self.macro_resolver = MacroResolver(
+ self.macro_manifest.macros,
+ self.root_project.project_name,
+ internal_package_names,
+ )
@@ -185,63 +194,55 @@ class SchemaParser(SimpleParser[SchemaTestBlock, ParsedSchemaTestNode]):
- yaml_files = list(FilesystemSearcher(
- self.project, self.project.all_source_paths, '.yaml'
- ))
+ yaml_files = list(
+ FilesystemSearcher(self.project, self.project.all_source_paths, ".yaml")
+ )
- 'A future version of dbt will parse files with both'
- ' .yml and .yaml file extensions. dbt found'
- f' {len(yaml_files)} files with .yaml extensions in'
- ' your dbt project. To avoid errors when upgrading'
- ' to a future release, either remove these files from'
- ' your dbt project, or change their extensions.'
+ "A future version of dbt will parse files with both"
+ " .yml and .yaml file extensions. dbt found"
+ f" {len(yaml_files)} files with .yaml extensions in"
+ " your dbt project. To avoid errors when upgrading"
+ " to a future release, either remove these files from"
+ " your dbt project, or change their extensions."
- return FilesystemSearcher(
- self.project, self.project.all_source_paths, '.yml'
- )
+ return FilesystemSearcher(self.project, self.project.all_source_paths, ".yml")
- return ParsedSchemaTestNode.from_dict(dct, validate=validate)
+ if validate:
+ ParsedSchemaTestNode.validate(dct)
+ return ParsedSchemaTestNode.from_dict(dct)
- def _parse_format_version(
- self, yaml: YamlBlock
- ) -> None:
+ def _check_format_version(self, yaml: YamlBlock) -> None:
- if 'version' not in yaml.data:
+ if "version" not in yaml.data:
- raise_invalid_schema_yml_version(path, 'no version is specified')
+ raise_invalid_schema_yml_version(path, "no version is specified")
- version = yaml.data['version']
+ version = yaml.data["version"]
- raise_invalid_schema_yml_version(
- path, 'the version is not an integer'
- )
+ raise_invalid_schema_yml_version(path, "the version is not an integer")
- path, 'version {} is not supported'.format(version)
+ path, "version {} is not supported".format(version)
- def _yaml_from_file(
- self, source_file: SourceFile
- ) -> Optional[Dict[str, Any]]:
- """If loading the yaml fails, raise an exception.
- """
+ def _yaml_from_file(self, source_file: SourceFile) -> Optional[Dict[str, Any]]:
+ """If loading the yaml fails, raise an exception."""
- 'Error reading {}: {} - {}'
- .format(self.project.project_name, path, reason)
+ "Error reading {}: {} - {}".format(
+ self.project.project_name, path, reason
+ )
- def parse_column_tests(
- self, block: TestBlock, column: UnparsedColumn
- ) -> None:
+ def parse_column_tests(self, block: TestBlock, column: UnparsedColumn) -> None:
@@ -253,9 +254,7 @@ class SchemaParser(SimpleParser[SchemaTestBlock, ParsedSchemaTestNode]):
- generator = UnrenderedConfigGenerator(
- self.root_project
- )
+ generator = UnrenderedConfigGenerator(self.root_project)
@@ -270,16 +269,14 @@ class SchemaParser(SimpleParser[SchemaTestBlock, ParsedSchemaTestNode]):
- def parse_source(
- self, target: UnpatchedSourceDefinition
- ) -> ParsedSourceDefinition:
+ def parse_source(self, target: UnpatchedSourceDefinition) -> ParsedSourceDefinition:
- description = table.description or ''
- source_description = source.description or ''
+ description = table.description or ""
+ source_description = source.description or ""
@@ -302,8 +299,8 @@ class SchemaParser(SimpleParser[SchemaTestBlock, ParsedSchemaTestNode]):
- f'Calculated a {type(config)} for a source, but expected '
- f'a SourceConfig'
+ f"Calculated a {type(config)} for a source, but expected "
+ f"a SourceConfig"
@@ -355,26 +352,27 @@ class SchemaParser(SimpleParser[SchemaTestBlock, ParsedSchemaTestNode]):
- 'alias': name,
- 'schema': self.default_schema,
- 'database': self.default_database,
- 'fqn': fqn,
- 'name': name,
- 'root_path': self.project.project_root,
- 'resource_type': self.resource_type,
- 'tags': tags,
- 'path': path,
- 'original_file_path': target.original_file_path,
- 'package_name': self.project.project_name,
- 'raw_sql': raw_sql,
- 'unique_id': self.generate_unique_id(name),
- 'config': self.config_dict(config),
- 'test_metadata': test_metadata,
- 'column_name': column_name,
- 'checksum': FileHash.empty().to_dict(),
+ "alias": name,
+ "schema": self.default_schema,
+ "database": self.default_database,
+ "fqn": fqn,
+ "name": name,
+ "root_path": self.project.project_root,
+ "resource_type": self.resource_type,
+ "tags": tags,
+ "path": path,
+ "original_file_path": target.original_file_path,
+ "package_name": self.project.project_name,
+ "raw_sql": raw_sql,
+ "unique_id": self.generate_unique_id(name),
+ "config": self.config_dict(config),
+ "test_metadata": test_metadata,
+ "column_name": column_name,
+ "checksum": FileHash.empty().to_dict(omit_none=True),
- return self.parse_from_dict(dct)
+ ParsedSchemaTestNode.validate(dct)
+ return ParsedSchemaTestNode.from_dict(dct)
@@ -387,6 +385,7 @@ class SchemaParser(SimpleParser[SchemaTestBlock, ParsedSchemaTestNode]):
+ # lots of time spent in this method
@@ -408,33 +407,36 @@ class SchemaParser(SimpleParser[SchemaTestBlock, ParsedSchemaTestNode]):
- msg = (
- 'Invalid test config given in {}:'
- '\n\t{}\n\t@: {}'
- .format(target.original_file_path, exc.msg, context)
+ msg = "Invalid test config given in {}:" "\n\t{}\n\t@: {}".format(
+ target.original_file_path, exc.msg, context
- builder.compiled_name, original_name, 'schema_test',
+ builder.compiled_name,
+ original_name,
+ "schema_test",
- builder.fqn_name, original_name, 'schema_test',
+ builder.fqn_name,
+ original_name,
+ "schema_test",
+ # this is the config that is used in render_update
- 'namespace': builder.namespace,
- 'name': builder.name,
- 'kwargs': builder.args,
+ "namespace": builder.namespace,
+ "name": builder.name,
+ "kwargs": builder.args,
- if 'schema' not in tags:
+ if "schema" not in tags:
- tags.append('schema')
+ tags.append("schema")
@@ -447,9 +449,54 @@ class SchemaParser(SimpleParser[SchemaTestBlock, ParsedSchemaTestNode]):
- self.render_update(node, config)
+ self.render_test_update(node, config, builder)
+ # This does special shortcut processing for the two
+ # most common internal macros, not_null and unique,
+ # which avoids the jinja rendering to resolve config
+ # and variables, etc, which might be in the macro.
+ # In the future we will look at generalizing this
+ # more to handle additional macros or to use static
+ # parsing to avoid jinja overhead.
+ def render_test_update(self, node, config, builder):
+ macro_unique_id = self.macro_resolver.get_macro_id(
+ node.package_name, "test_" + builder.name
+ )
+ # Add the depends_on here so we can limit the macros added
+ # to the context in rendering processing
+ node.depends_on.add_macro(macro_unique_id)
+ if macro_unique_id in ["macro.dbt.test_not_null", "macro.dbt.test_unique"]:
+ self.update_parsed_node(node, config)
+ node.unrendered_config["severity"] = builder.severity()
+ node.config["severity"] = builder.severity()
+ # source node tests are processed at patch_source time
+ if isinstance(builder.target, UnpatchedSourceDefinition):
+ sources = [builder.target.fqn[-2], builder.target.fqn[-1]]
+ node.sources.append(sources)
+ else: # all other nodes
+ node.refs.append([builder.target.name])
+ else:
+ try:
+ # make a base context that doesn't have the magic kwargs field
+ context = generate_test_context(
+ node,
+ self.root_project,
+ self.macro_manifest,
+ config,
+ self.macro_resolver,
+ )
+ # update with rendered test kwargs (which collects any refs)
+ add_rendered_test_kwargs(context, node, capture_macros=True)
+ # the parsed node is not rendered in the native context.
+ get_rendered(node.raw_sql, context, node, capture_macros=True)
+ self.update_parsed_node(node, config)
+ except ValidationError as exc:
+ # we got a ValidationError - probably bad types in config()
+ msg = validator_error_message(exc)
+ raise CompilationException(msg, node=node) from exc
@@ -461,9 +508,8 @@ class SchemaParser(SimpleParser[SchemaTestBlock, ParsedSchemaTestNode]):
- should_quote = (
- column.quote or
- (column.quote is None and target.quote_columns)
+ should_quote = column.quote or (
+ column.quote is None and target.quote_columns
@@ -474,10 +520,7 @@ class SchemaParser(SimpleParser[SchemaTestBlock, ParsedSchemaTestNode]):
- target=target,
- test=test,
- tags=tags,
- column_name=column_name
+ target=target, test=test, tags=tags, column_name=column_name
@@ -501,7 +544,9 @@ class SchemaParser(SimpleParser[SchemaTestBlock, ParsedSchemaTestNode]):
- self, node: ParsedSchemaTestNode, config: ContextConfig,
+ self,
+ node: ParsedSchemaTestNode,
+ config: ContextConfig,
@@ -513,9 +558,7 @@ class SchemaParser(SimpleParser[SchemaTestBlock, ParsedSchemaTestNode]):
- get_rendered(
- node.raw_sql, context, node, capture_macros=True
- )
+ get_rendered(node.raw_sql, context, node, capture_macros=True)
@@ -531,9 +574,8 @@ class SchemaParser(SimpleParser[SchemaTestBlock, ParsedSchemaTestNode]):
- should_quote = (
- column.quote or
- (column.quote is None and target_block.quote_columns)
+ should_quote = column.quote or (
+ column.quote is None and target_block.quote_columns
@@ -561,60 +603,87 @@ class SchemaParser(SimpleParser[SchemaTestBlock, ParsedSchemaTestNode]):
- # mark the file as seen, even if there are no macros in it
+ # mark the file as seen, in ParseResult.files
+ # This does a deep_map to check for circular references
- f'Failed to render {block.path.original_file_path} from '
- f'project {self.project.project_name}: {exc}'
+ f"Failed to render {block.path.original_file_path} from "
+ f"project {self.project.project_name}: {exc}"
+ # contains the FileBlock and the data (dictionary)
- self._parse_format_version(yaml_block)
+ # checks version
+ self._check_format_version(yaml_block)
- for key in NodeType.documentable():
- plural = key.pluralize()
- if key == NodeType.Source:
- parser = SourceParser(self, yaml_block, plural)
- elif key == NodeType.Macro:
- parser = MacroPatchParser(self, yaml_block, plural)
- elif key == NodeType.Analysis:
- parser = AnalysisPatchParser(self, yaml_block, plural)
- elif key == NodeType.Exposure:
- # handle exposures separately, but they are
- # technically still "documentable"
- continue
- else:
- parser = TestablePatchParser(self, yaml_block, plural)
- for test_block in parser.parse():
- self.parse_tests(test_block)
- self.parse_exposures(yaml_block)
+ # There are 7 kinds of parsers:
+ # Model, Seed, Snapshot, Source, Macro, Analysis, Exposures
+ # NonSourceParser.parse(), TestablePatchParser is a variety of
+ # NodePatchParser
+ if "models" in dct:
+ parser = TestablePatchParser(self, yaml_block, "models")
+ for test_block in parser.parse():
+ self.parse_tests(test_block)
+ # NonSourceParser.parse()
+ if "seeds" in dct:
+ parser = TestablePatchParser(self, yaml_block, "seeds")
+ for test_block in parser.parse():
+ self.parse_tests(test_block)
+ # NonSourceParser.parse()
+ if "snapshots" in dct:
+ parser = TestablePatchParser(self, yaml_block, "snapshots")
+ for test_block in parser.parse():
+ self.parse_tests(test_block)
+ # This parser uses SourceParser.parse() which doesn't return
+ # any test blocks. Source tests are handled at a later point
+ # in the process.
+ if "sources" in dct:
+ parser = SourceParser(self, yaml_block, "sources")
+ parser.parse()
+ # NonSourceParser.parse()
+ if "macros" in dct:
+ parser = MacroPatchParser(self, yaml_block, "macros")
+ for test_block in parser.parse():
+ self.parse_tests(test_block)
+ # NonSourceParser.parse()
+ if "analyses" in dct:
+ parser = AnalysisPatchParser(self, yaml_block, "analyses")
+ for test_block in parser.parse():
+ self.parse_tests(test_block)
+ # parse exposures
+ if "exposures" in dct:
+ self.parse_exposures(yaml_block)
- Parsed = TypeVar(
- 'Parsed',
- UnpatchedSourceDefinition, ParsedNodePatch, ParsedMacroPatch
- )
- NodeTarget = TypeVar(
- 'NodeTarget',
- UnparsedNodeUpdate, UnparsedAnalysisUpdate
- )
+ Parsed = TypeVar("Parsed", UnpatchedSourceDefinition, ParsedNodePatch, ParsedMacroPatch)
+ NodeTarget = TypeVar("NodeTarget", UnparsedNodeUpdate, UnparsedAnalysisUpdate)
- 'NonSourceTarget',
- UnparsedNodeUpdate, UnparsedAnalysisUpdate, UnparsedMacroUpdate
+ "NonSourceTarget", UnparsedNodeUpdate, UnparsedAnalysisUpdate, UnparsedMacroUpdate
+ # abstract base class (ABCMeta)
- def __init__(
- self, schema_parser: SchemaParser, yaml: YamlBlock, key: str
- ) -> None:
+ def __init__(self, schema_parser: SchemaParser, yaml: YamlBlock, key: str) -> None:
+ # key: models, seeds, snapshots, sources, macros,
+ # analyses, exposures
@@ -634,21 +703,28 @@ class YamlReader(metaclass=ABCMeta):
+ # for the different schema subparsers ('models', 'source', etc)
+ # get the list of dicts pointed to by the key in the yaml config,
+ # ensure that the dicts have string keys
- '{} must be a list, got {} instead: ({})'
- .format(self.key, type(data), _trimmed(str(data)))
+ "{} must be a list, got {} instead: ({})".format(
+ self.key, type(data), _trimmed(str(data))
+ )
+ # for each dict in the data (which is a list of dicts)
+ # check that entry is a dict and that all dict values
+ # are strings
- path, self.key, data, 'expected a dict with string keys'
+ path, self.key, data, "expected a dict with string keys"
@@ -656,30 +732,31 @@
- raise NotImplementedError('parse is abstract')
+ raise NotImplementedError("parse is abstract")
- T = TypeVar('T', bound=JsonSchemaMixin)
+ T = TypeVar("T", bound=dbtClassMixin)
+ cls.validate(data)
+ # the other parse method returns TestBlocks. This one doesn't.
+ # get a verified list of dicts for the key handled by this parser
- data = self.project.credentials.translate_aliases(
- data, recurse=True
- )
+ data = self.project.credentials.translate_aliases(data, recurse=True)
- is_override = 'overrides' in data
+ is_override = "overrides" in data
- data['path'] = self.yaml.path.original_file_path
+ data["path"] = self.yaml.path.original_file_path
@@ -691,10 +768,9 @@ class SourceParser(YamlDocsReader):
- unique_id = '.'.join([
- NodeType.Source, self.project.project_name,
- source.name, table.name
- ])
+ unique_id = ".".join(
+ [NodeType.Source, self.project.project_name, source.name, table.name]
+ )
@@ -714,60 +790,81 @@ class SourceParser(YamlDocsReader):
+ # This class has three main subclasses: TestablePatchParser (models,
+ # seeds, snapshots), MacroPatchParser, and AnalysisPatchParser
- raise NotImplementedError('_unsafe_from_dict not implemented')
+ raise NotImplementedError("_target_type not implemented")
- raise NotImplementedError('get_block is abstract')
+ raise NotImplementedError("get_block is abstract")
- def parse_patch(
- self, block: TargetBlock[NonSourceTarget], refs: ParserRef
- ) -> None:
- raise NotImplementedError('parse_patch is abstract')
+ def parse_patch(self, block: TargetBlock[NonSourceTarget], refs: ParserRef) -> None:
+ raise NotImplementedError("parse_patch is abstract")
+ # get list of 'node' objects
+ # UnparsedNodeUpdate (TestablePatchParser, models, seeds, snapshots)
+ # = HasColumnTests, HasTests
+ # UnparsedAnalysisUpdate (UnparsedAnalysisParser, analyses)
+ # = HasColumnDocs, HasDocs
+ # UnparsedMacroUpdate (MacroPatchParser, 'macros')
+ # = HasDocs
+ # correspond to this parser's 'key'
+ # node_block is a TargetBlock (Macro or Analysis)
+ # or a TestBlock (all of the others)
+ # TestablePatchParser = models, seeds, snapshots
+ # UnparsedNodeUpdate and UnparsedAnalysisUpdate
+ # This adds the node_block to self.results (a ParseResult
+ # object) as a ParsedNodePatch or ParsedMacroPatch
- for data in self.get_key_dicts():
- data.update({
- 'original_file_path': path,
- 'yaml_key': self.key,
- 'package_name': self.project.project_name,
- })
+ # get verified list of dicts for the 'key' that this
+ # parser handles
+ key_dicts = self.get_key_dicts()
+ for data in key_dicts:
+ # add extra data to each dict. This updates the dicts
+ # in the parser yaml
+ data.update(
+ {
+ "original_file_path": path,
+ "yaml_key": self.key,
+ "package_name": self.project.project_name,
+ }
+ )
- model = self._target_type().from_dict(data)
+ # target_type: UnparsedNodeUpdate, UnparsedAnalysisUpdate,
+ # or UnparsedMacroUpdate
+ self._target_type().validate(data)
+ node = self._target_type().from_dict(data)
- yield model
+ yield node
- NonSourceParser[NodeTarget, ParsedNodePatch],
- Generic[NodeTarget]
+ NonSourceParser[NodeTarget, ParsedNodePatch], Generic[NodeTarget]
- def parse_patch(
- self, block: TargetBlock[NodeTarget], refs: ParserRef
- ) -> None:
+ def parse_patch(self, block: TargetBlock[NodeTarget], refs: ParserRef) -> None:
@@ -828,7 +925,7 @@ class ExposureParser(YamlReader):
- unique_id = f'{NodeType.Exposure}.{package_name}.{unparsed.name}'
+ unique_id = f"{NodeType.Exposure}.{package_name}.{unparsed.name}"
@@ -854,18 +951,17 @@
- depends_on_jinja = '\n'.join(
- '{{ ' + line + '}}' for line in unparsed.depends_on
- )
- get_rendered(
- depends_on_jinja, ctx, parsed, capture_macros=True
- )
+ depends_on_jinja = "\n".join(
+ "{{ " + line + "}}" for line in unparsed.depends_on
+ )
+ get_rendered(depends_on_jinja, ctx, parsed, capture_macros=True)
+ UnparsedExposure.validate(data)