mirror of
https://github.com/dbt-labs/dbt-core
synced 2025-12-20 00:11:28 +00:00
Compare commits
1 Commits
enable-pos
...
stu-k/retr
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9d692aae5a |
@@ -4,7 +4,7 @@ from dataclasses import dataclass
|
|||||||
from importlib import import_module
|
from importlib import import_module
|
||||||
from multiprocessing import get_context
|
from multiprocessing import get_context
|
||||||
from pprint import pformat as pf
|
from pprint import pformat as pf
|
||||||
from typing import Callable, Dict, List, Set, Union
|
from typing import Any, Callable, Dict, List, Set, Union
|
||||||
|
|
||||||
from click import Context, get_current_context
|
from click import Context, get_current_context
|
||||||
from click.core import Command, Group, ParameterSource
|
from click.core import Command, Group, ParameterSource
|
||||||
@@ -12,6 +12,7 @@ from dbt.cli.exceptions import DbtUsageException
|
|||||||
from dbt.cli.resolvers import default_log_path, default_project_dir
|
from dbt.cli.resolvers import default_log_path, default_project_dir
|
||||||
from dbt.config.profile import read_user_config
|
from dbt.config.profile import read_user_config
|
||||||
from dbt.contracts.project import UserConfig
|
from dbt.contracts.project import UserConfig
|
||||||
|
from dbt.exceptions import DbtInternalError
|
||||||
from dbt.deprecations import renamed_env_var
|
from dbt.deprecations import renamed_env_var
|
||||||
from dbt.helper_types import WarnErrorOptions
|
from dbt.helper_types import WarnErrorOptions
|
||||||
|
|
||||||
@@ -277,3 +278,85 @@ class Flags:
|
|||||||
# It is necessary to remove this attr from the class so it does
|
# It is necessary to remove this attr from the class so it does
|
||||||
# not get pickled when written to disk as json.
|
# not get pickled when written to disk as json.
|
||||||
object.__delattr__(self, "deprecated_env_var_warnings")
|
object.__delattr__(self, "deprecated_env_var_warnings")
|
||||||
|
|
||||||
|
@classmethod
def from_dict_for_cmd(cls, cmd: str, d: Dict[str, Any]) -> "Flags":
    """Build a Flags instance for command `cmd` from a dict of flag values.

    The dict is first turned into a list of CLI arguments, that list is turned
    into a click context, and the context is used to instantiate the class.
    Deprecation warnings are fired on the new instance before it is returned.
    """
    cli_args = get_args_for_cmd_from_dict(cmd, d)
    instance = cls(ctx=args_to_context(cli_args))
    instance.fire_deprecations()
    return instance
|
||||||
|
|
||||||
|
|
||||||
|
def get_args_for_cmd_from_dict(cmd: str, d: Dict[str, Any]) -> List[str]:
    """Given a command name and a dict, return the list of strings representing
    the CLI arguments for that command. The order of the list is consistent with
    which flags are expected at the parent level vs the command level.

    e.g. fn("run", {"defer": True, "print": False}) -> ["--no-print", "run", "--defer"]

    The result of this function can be passed in to the args_to_context function
    to produce a click context to instantiate Flags with.

    Args:
        cmd: name of the command, as understood by get_args_for_cmd.
        d: mapping of flag name (any case, underscore separated) to its value.

    Returns:
        The command name surrounded by its flags: parent-level flags before the
        command name, every other flag after it.

    Raises:
        DbtInternalError: if `d` holds a "which" entry that does not match `cmd`
            (or if get_args_for_cmd does not know the command).
    """
    cmd_args = get_args_for_cmd(cmd)
    parent_args = set(get_args_for_cmd("cli"))
    # Hoisted out of the loop: every name accepted by the command or its parent.
    known_args = parent_args.union(cmd_args)
    default_args = {name.lower() for name in FLAGS_DEFAULTS.keys()}

    res = [cmd]

    for key, value in d.items():
        key = key.lower()

        # if a "which" value exists in the args dict, it should match the cmd arg
        if key == "which":
            if value != cmd.lower():
                raise DbtInternalError(f"cmd '{cmd}' does not match value of which '{value}'")
            continue

        # param was assigned from defaults and should not be included
        if key not in known_args and key in default_args:
            continue

        spinal_cased = key.replace("_", "-")

        # Identity checks on purpose: `value in (None, False)` compares with ==,
        # and 0 == False, so a numeric zero would be rendered as "--no-<flag>".
        if value is None or value is False:
            arg = f"--no-{spinal_cased}"
        elif value is True:
            arg = f"--{spinal_cased}"
        else:
            arg = f"--{spinal_cased}={value}"

        # if the param is in parent args, it should come before the command name
        # e.g. ["--print", "run"] vs ["run", "--print"]
        if key in parent_args:
            res.insert(0, arg)
        else:
            res.append(arg)

    return res
|
||||||
|
|
||||||
|
|
||||||
|
def get_args_for_cmd(name: str) -> List[str]:
    """Return the names of the params declared on the click command called `name`.

    Only the params of the command itself are listed; when given the name of a
    child click command, params assigned to its parent click group are not
    returned. Params whose name starts with "deprecated_" are left out.

    e.g. fn("run") -> ["defer", "favor_state", "exclude", ...]

    Raises DbtInternalError when `name` is not one of the known commands.
    """
    # Imported inside the function, as in the original, rather than at module level.
    import dbt.cli.main as cli

    commands = dict(
        build=cli.build,
        cli=cli.cli,
        compile=cli.compile,
        freshness=cli.freshness,
        generate=cli.docs_generate,
        run=cli.run,
        seed=cli.seed,
        show=cli.show,
        snapshot=cli.snapshot,
        test=cli.test,
    )

    command = commands.get(name)
    if command is None:
        raise DbtInternalError(f"No command found for name '{name}'")

    param_names = []
    for param in command.params:
        if param.name.lower().startswith("deprecated_"):
            continue
        param_names.append(param.name)
    return param_names
|
||||||
|
|||||||
@@ -31,6 +31,7 @@ from dbt.task.freshness import FreshnessTask
|
|||||||
from dbt.task.generate import GenerateTask
|
from dbt.task.generate import GenerateTask
|
||||||
from dbt.task.init import InitTask
|
from dbt.task.init import InitTask
|
||||||
from dbt.task.list import ListTask
|
from dbt.task.list import ListTask
|
||||||
|
from dbt.task.retry import RetryTask
|
||||||
from dbt.task.run import RunTask
|
from dbt.task.run import RunTask
|
||||||
from dbt.task.run_operation import RunOperationTask
|
from dbt.task.run_operation import RunOperationTask
|
||||||
from dbt.task.seed import SeedTask
|
from dbt.task.seed import SeedTask
|
||||||
@@ -527,6 +528,45 @@ def parse(ctx, **kwargs):
|
|||||||
return ctx.obj["manifest"], True
|
return ctx.obj["manifest"], True
|
||||||
|
|
||||||
|
|
||||||
|
# chenyu: it is actually kind of confusing what arguments I need to include here, some of
|
||||||
|
# them are for project profile, feels like they should be defined alongside those decorators
|
||||||
|
# dbt retry
|
||||||
|
@cli.command("retry")
@click.pass_context
@p.defer
@p.deprecated_defer
@p.fail_fast
@p.favor_state
@p.deprecated_favor_state
@p.indirect_selection  # TODO: confirm this option is needed for retry
@p.profile
@p.profiles_dir
@p.project_dir
@p.state_required
@p.deprecated_state
@p.store_failures  # TODO: confirm this option is needed for retry
@p.target
@p.target_path
@p.threads
@p.vars  # TODO: confirm this option is needed for retry
@requires.preflight
@requires.profile
@requires.project
@requires.runtime_config
@requires.manifest
def retry(ctx, **kwargs):
    """Rerun a previous failed command given a previous run result artifact"""
    # NOTE: the docstring above is the help text click shows for this command.
    shared = ctx.obj
    retry_task = RetryTask(
        shared["flags"],
        shared["runtime_config"],
        shared["manifest"],
    )

    outcome = retry_task.run()
    return outcome, retry_task.interpret_results(outcome)
|
||||||
|
|
||||||
|
|
||||||
# dbt run
|
# dbt run
|
||||||
@cli.command("run")
|
@cli.command("run")
|
||||||
@click.pass_context
|
@click.pass_context
|
||||||
|
|||||||
@@ -434,6 +434,20 @@ state = click.option(
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# "--state" option declared with required=True, so the command it decorates
# cannot be invoked without a state directory (flag or DBT_STATE env var).
state_required = click.option(
    "--state",
    required=True,
    envvar="DBT_STATE",
    type=click.Path(
        path_type=Path,
        resolve_path=True,
        readable=True,
        file_okay=False,
        dir_okay=True,
    ),
    help="If set, use the given directory as the source for JSON files to compare with this project.",
)
|
||||||
|
|
||||||
deprecated_state = click.option(
|
deprecated_state = click.option(
|
||||||
"--deprecated-state",
|
"--deprecated-state",
|
||||||
envvar="DBT_ARTIFACT_STATE_PATH",
|
envvar="DBT_ARTIFACT_STATE_PATH",
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ from dbt.cli.exceptions import (
|
|||||||
ResultExit,
|
ResultExit,
|
||||||
)
|
)
|
||||||
from dbt.cli.flags import Flags
|
from dbt.cli.flags import Flags
|
||||||
|
from dbt.cli.utils import get_profile, get_project
|
||||||
from dbt.config import RuntimeConfig
|
from dbt.config import RuntimeConfig
|
||||||
from dbt.config.runtime import load_project, load_profile, UnsetProfile
|
from dbt.config.runtime import load_project, load_profile, UnsetProfile
|
||||||
from dbt.events.functions import fire_event, LOG_VERSION, set_invocation_id, setup_event_logger
|
from dbt.events.functions import fire_event, LOG_VERSION, set_invocation_id, setup_event_logger
|
||||||
@@ -132,12 +133,7 @@ def profile(func):
|
|||||||
ctx = args[0]
|
ctx = args[0]
|
||||||
assert isinstance(ctx, Context)
|
assert isinstance(ctx, Context)
|
||||||
|
|
||||||
flags = ctx.obj["flags"]
|
ctx.obj["profile"] = get_profile(ctx.obj["flags"])
|
||||||
# TODO: Generalize safe access to flags.THREADS:
|
|
||||||
# https://github.com/dbt-labs/dbt-core/issues/6259
|
|
||||||
threads = getattr(flags, "THREADS", None)
|
|
||||||
profile = load_profile(flags.PROJECT_DIR, flags.VARS, flags.PROFILE, flags.TARGET, threads)
|
|
||||||
ctx.obj["profile"] = profile
|
|
||||||
|
|
||||||
return func(*args, **kwargs)
|
return func(*args, **kwargs)
|
||||||
|
|
||||||
@@ -155,9 +151,7 @@ def project(func):
|
|||||||
raise DbtProjectError("profile required for project")
|
raise DbtProjectError("profile required for project")
|
||||||
|
|
||||||
flags = ctx.obj["flags"]
|
flags = ctx.obj["flags"]
|
||||||
project = load_project(
|
project = get_project(flags, ctx.obj["profile"])
|
||||||
flags.PROJECT_DIR, flags.VERSION_CHECK, ctx.obj["profile"], flags.VARS
|
|
||||||
)
|
|
||||||
ctx.obj["project"] = project
|
ctx.obj["project"] = project
|
||||||
|
|
||||||
if dbt.tracking.active_user is not None:
|
if dbt.tracking.active_user is not None:
|
||||||
|
|||||||
29
core/dbt/cli/utils.py
Normal file
29
core/dbt/cli/utils.py
Normal file
@@ -0,0 +1,29 @@
|
|||||||
|
from dbt.cli.flags import Flags
|
||||||
|
from dbt.config import RuntimeConfig
|
||||||
|
from dbt.config.runtime import Profile, Project, load_project, load_profile
|
||||||
|
|
||||||
|
|
||||||
|
def get_profile(flags: Flags) -> Profile:
    """Load the profile described by `flags` through load_profile.

    PROJECT_DIR, VARS, PROFILE and TARGET are read directly from `flags`;
    THREADS is optional and falls back to None when the attribute is absent.
    """
    # TODO: Generalize safe access to flags.THREADS:
    # https://github.com/dbt-labs/dbt-core/issues/6259
    thread_count = getattr(flags, "THREADS", None)
    profile_args = (
        flags.PROJECT_DIR,
        flags.VARS,
        flags.PROFILE,
        flags.TARGET,
        thread_count,
    )
    return load_profile(*profile_args)
|
||||||
|
|
||||||
|
|
||||||
|
def get_project(flags: Flags, profile: Profile) -> Project:
    """Load the project for `flags` and an already loaded `profile`.

    Passes PROJECT_DIR, VERSION_CHECK, the profile and VARS, in that order,
    to load_project and returns its result.
    """
    project_dir = flags.PROJECT_DIR
    version_check = flags.VERSION_CHECK
    flag_vars = flags.VARS
    return load_project(project_dir, version_check, profile, flag_vars)
|
||||||
|
|
||||||
|
|
||||||
|
def get_runtime_config(flags: Flags) -> RuntimeConfig:
    """Assemble a RuntimeConfig from `flags`.

    The profile is loaded first because loading the project needs it; both are
    then handed to RuntimeConfig.from_parts together with the flags as `args`.
    """
    loaded_profile = get_profile(flags)
    loaded_project = get_project(flags, loaded_profile)
    parts = {
        "args": flags,
        "profile": loaded_profile,
        "project": loaded_project,
    }
    return RuntimeConfig.from_parts(**parts)
|
||||||
100
core/dbt/task/retry.py
Normal file
100
core/dbt/task/retry.py
Normal file
@@ -0,0 +1,100 @@
|
|||||||
|
from typing import Any, Optional, Type
|
||||||
|
|
||||||
|
from dbt import selected_resources
|
||||||
|
from dbt.cli.flags import Flags
|
||||||
|
from dbt.cli.utils import get_runtime_config
|
||||||
|
from dbt.config.runtime import RuntimeConfig
|
||||||
|
from dbt.contracts.graph.manifest import Manifest
|
||||||
|
from dbt.contracts.results import NodeStatus
|
||||||
|
from dbt.contracts.state import PreviousState
|
||||||
|
from dbt.exceptions import DbtInternalError
|
||||||
|
from dbt.graph import GraphQueue
|
||||||
|
|
||||||
|
from dbt.task.base import BaseTask
|
||||||
|
from dbt.task.build import BuildTask
|
||||||
|
from dbt.task.compile import CompileTask
|
||||||
|
from dbt.task.freshness import FreshnessTask
|
||||||
|
from dbt.task.generate import GenerateTask
|
||||||
|
from dbt.task.run import RunTask
|
||||||
|
from dbt.task.seed import SeedTask
|
||||||
|
from dbt.task.show import ShowTask
|
||||||
|
from dbt.task.snapshot import SnapshotTask
|
||||||
|
from dbt.task.test import TestTask
|
||||||
|
|
||||||
|
|
||||||
|
# Maps a command name (the "which" value recorded with a previous run) to the
# task class that implements that command.
TASK_DICT = dict(
    build=BuildTask,
    compile=CompileTask,
    freshness=FreshnessTask,
    generate=GenerateTask,
    run=RunTask,
    seed=SeedTask,
    show=ShowTask,
    snapshot=SnapshotTask,
    test=TestTask,
)
|
||||||
|
|
||||||
|
|
||||||
|
class RetryTask:
    """Re-run the nodes of a previous invocation that did not finish successfully.

    The previous invocation is read from the directory given by `flags.state`.
    Its recorded "which" argument selects the task class to run again, and its
    recorded args are used to rebuild the flags for that task.
    """

    # The task *class* (a value of TASK_DICT) resolved by run(); None until then.
    task_class: Optional[Type[BaseTask]]

    def __init__(self, flags: Flags, config: RuntimeConfig, manifest: Manifest):
        self.flags = flags
        self.config = config
        self.manifest = manifest
        # Set up front so interpret_results() can tell that run() has not
        # resolved a task yet instead of failing on a missing attribute.
        self.task_class = None

    def run(self) -> Any:  # TODO: type as anything a task can return (look at dbtRunner)
        """Run the previously recorded command again for its unsuccessful nodes.

        Returns:
            Whatever the wrapped task's run() returns.

        Raises:
            DbtInternalError: if no previous results are available, if the
                previous command was itself a retry, or if the previous command
                has no entry in TASK_DICT.
        """
        # NOTE(review): the state directory is passed for both arguments, as in
        # the original code — confirm the second one should not be another path.
        previous_state = PreviousState(
            self.flags.state,
            self.flags.state,
        )

        # NOTE(review): assumes PreviousState leaves `results` as None when no
        # results artifact could be loaded — confirm against PreviousState.
        if previous_state.results is None:
            raise DbtInternalError(
                f"No previous run results found in state '{self.flags.state}'"
            )

        cmd_name = previous_state.results.args.get("which")
        if cmd_name == "retry":
            raise DbtInternalError("Can't retry a retry command")

        self.task_class = TASK_DICT.get(cmd_name, None)
        if self.task_class is None:
            raise DbtInternalError(f'No command mapped to string "{cmd_name}"')

        # should this interact with --warn-error?
        statuses_to_skip = [NodeStatus.Success, NodeStatus.Pass, NodeStatus.Warn]
        unique_ids = {
            result.unique_id
            for result in previous_state.results.results
            if result.status not in statuses_to_skip
        }

        # does #6009 need to be resolved before this ticket?
        selected_resources.set_selected_resources(unique_ids)

        class TaskWrapper(self.task_class):
            """Subclass of the resolved task restricted to the nodes in `unique_ids`."""

            def original_compile_manifest(self):
                # Gives access to the parent's compile_manifest, which is
                # overridden as a no-op just below.
                return super().compile_manifest()

            def compile_manifest(self):
                # Deliberately does nothing; compilation is triggered explicitly
                # through original_compile_manifest() before the task is run.
                pass

            def get_graph_queue(self):
                # Queue only the subset of the graph made of the nodes to retry.
                new_graph = self.graph.get_subset_graph(unique_ids)
                return GraphQueue(
                    new_graph.graph,
                    self.manifest,
                    unique_ids,
                )

        # Rebuild flags and runtime config from the args of the previous
        # invocation rather than from the flags of this retry invocation.
        retry_flags = Flags.from_dict_for_cmd(cmd_name, previous_state.results.args)
        retry_config = get_runtime_config(retry_flags)

        task = TaskWrapper(
            retry_flags,
            retry_config,
            self.manifest,
        )

        task.original_compile_manifest()

        return task.run()

    def interpret_results(self, *args, **kwargs):
        """Delegate to interpret_results of the task class resolved by run().

        Raises:
            DbtInternalError: if run() has not resolved a task class yet.
        """
        if self.task_class is None:
            raise DbtInternalError("Can't interpret results before a task has been run")
        return self.task_class.interpret_results(*args, **kwargs)
|
||||||
48
tests/functional/retry/test_retry.py
Normal file
48
tests/functional/retry/test_retry.py
Normal file
@@ -0,0 +1,48 @@
|
|||||||
|
import pytest
|
||||||
|
|
||||||
|
from dbt.tests.util import run_dbt
|
||||||
|
from dbt.contracts.results import RunStatus
|
||||||
|
|
||||||
|
|
||||||
|
model_one = """
|
||||||
|
select 1 as fun
|
||||||
|
"""
|
||||||
|
|
||||||
|
model_two = """
|
||||||
|
select * from {{ ref("model_one") }}
|
||||||
|
"""
|
||||||
|
|
||||||
|
model_three = """
|
||||||
|
breaking line
|
||||||
|
select * from {{ ref("model_two") }}
|
||||||
|
"""
|
||||||
|
|
||||||
|
model_four = """
|
||||||
|
select * from {{ ref("model_three") }}
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class TestRunRetry:
    """Runs four chained models where the third is broken, then retries the run."""

    @pytest.fixture(scope="class")
    def models(self):
        sources = (
            ("model_one.sql", model_one),
            ("model_two.sql", model_two),
            ("model_three.sql", model_three),
            ("model_four.sql", model_four),
        )
        return dict(sources)

    def test_run(self, project):
        # First run: the third model errors, so the fourth one is skipped.
        run_results = run_dbt(["run", "--target-path", "state"], expect_pass=False)
        expected_statuses = (
            RunStatus.Success,
            RunStatus.Success,
            RunStatus.Error,
            RunStatus.Skipped,
        )
        assert len(run_results) == 4
        for position, status in enumerate(expected_statuses):
            assert run_results[position].status == status

        # Retry: only the errored model and the skipped one are picked up again.
        retry_results = run_dbt(["retry", "--state", "state"], expect_pass=False)
        expected_retries = (
            ("model_three", RunStatus.Error),
            ("model_four", RunStatus.Skipped),
        )
        assert len(retry_results) == 2
        for position, (model_name, status) in enumerate(expected_retries):
            assert retry_results[position].node.name == model_name
            assert retry_results[position].status == status
|
||||||
Reference in New Issue
Block a user