Compare commits

...

15 Commits

Author SHA1 Message Date
Github Build Bot
275232f15d Add generated CLI API docs 2023-02-15 03:44:23 +00:00
Michelle Ark
bb1d22ee3a get_with_fallback 2023-02-14 22:42:51 -05:00
Michelle Ark
6926e7071b Merge branch 'main' into arky/default-flags 2023-02-14 11:33:32 -05:00
Github Build Bot
e2e38da1f4 Add generated CLI API docs 2023-02-14 15:57:01 +00:00
Michelle Ark
7ff090330d Merge branch 'main' into arky/default-flags 2023-02-14 10:55:34 -05:00
Github Build Bot
78ceb55cd9 Add generated CLI API docs 2023-02-08 17:53:59 +00:00
Michelle Ark
c176287d17 Merge branch 'feature/click-cli' into arky/default-flags 2023-02-08 12:52:01 -05:00
Michelle Ark
a156aabb10 support empty Flags constructor 2023-02-08 12:46:14 -05:00
Github Build Bot
177b244d82 Add generated CLI API docs 2023-02-08 03:12:26 +00:00
Michelle Ark
140a3e8007 profit! 2023-02-07 22:10:56 -05:00
Michelle Ark
1454dfdc97 Merge branch 'feature/click-cli' into arky/default-flags 2023-02-07 22:05:01 -05:00
Michelle Ark
b83786f31e Merge branch 'feature/click-cli' into arky/default-flags 2023-02-07 22:02:08 -05:00
Github Build Bot
f0b026ae72 Add generated CLI API docs 2023-02-04 01:33:03 +00:00
Michelle Ark
19f47d16f7 default flags + test 2023-02-03 20:31:32 -05:00
Michelle Ark
414e98198d load default click param value 2023-02-03 19:51:34 -05:00
8 changed files with 141 additions and 97 deletions

View File

@@ -7,31 +7,20 @@ from multiprocessing import get_context
from pprint import pformat as pf
from typing import Set, List
from click import Context, get_current_context, BadOptionUsage
from click import Context, get_current_context, BadOptionUsage, Command, BadParameter
from click.core import ParameterSource
from dbt.config.profile import read_user_config
from dbt.contracts.project import UserConfig
import dbt.cli.params as p
from dbt.helper_types import WarnErrorOptions
from dbt.cli.resolvers import default_project_dir, default_log_path
from dbt.cli.resolvers import default_log_path
if os.name != "nt":
# https://bugs.python.org/issue41567
import multiprocessing.popen_spawn_posix # type: ignore # noqa: F401
# TODO anything that has a default in params should be removed here?
# Or maybe only the ones that's in the root click group
FLAGS_DEFAULTS = {
"INDIRECT_SELECTION": "eager",
"TARGET_PATH": None,
# cli args without user_config or env var option
"FULL_REFRESH": False,
"STRICT_MODE": False,
"STORE_FAILURES": False,
}
# For backwards compatability, some params are defined across multiple levels,
# Top-level value should take precedence.
# e.g. dbt --target-path test2 run --target-path test2
@@ -60,89 +49,92 @@ def convert_config(config_name, config_value):
class Flags:
def __init__(self, ctx: Context = None, user_config: UserConfig = None) -> None:
# set the default flags
for key, value in FLAGS_DEFAULTS.items():
object.__setattr__(self, key, value)
if ctx is None:
ctx = get_current_context()
try:
ctx = get_current_context(silent=True)
except RuntimeError:
ctx = None
def assign_params(ctx, params_assigned_from_default):
def assign_params(ctx, params_assigned_from_default, params_assigned_from_user):
"""Recursively adds all click params to flag object"""
for param_name, param_value in ctx.params.items():
# TODO: this is to avoid duplicate params being defined in two places (version_check in run and cli)
# However this is a bit of a hack and we should find a better way to do this
# N.B. You have to use the base MRO method (object.__setattr__) to set attributes
# when using frozen dataclasses.
# https://docs.python.org/3/library/dataclasses.html#frozen-instances
if hasattr(self, param_name.upper()):
if param_name not in EXPECTED_DUPLICATE_PARAMS:
raise Exception(
f"Duplicate flag names found in click command: {param_name}"
)
else:
# Expected duplicate param from multi-level click command (ex: dbt --full_refresh run --full_refresh)
# Overwrite user-configured param with value from parent context
if ctx.get_parameter_source(param_name) != ParameterSource.DEFAULT:
object.__setattr__(self, param_name.upper(), param_value)
if param_name in EXPECTED_DUPLICATE_PARAMS:
# Expected duplicate param from multi-level click command (ex: dbt --full_refresh run --full_refresh)
# Overwrite user-configured param with value from parent context
if ctx.get_parameter_source(param_name) != ParameterSource.DEFAULT:
object.__setattr__(self, param_name.upper(), param_value)
params_assigned_from_user.add(param_name)
else:
object.__setattr__(self, param_name.upper(), param_value)
params_assigned_from_user.add(param_name)
if ctx.get_parameter_source(param_name) == ParameterSource.DEFAULT:
params_assigned_from_default.add(param_name)
params_assigned_from_user.remove(param_name)
if ctx.parent:
assign_params(ctx.parent, params_assigned_from_default)
assign_params(ctx.parent, params_assigned_from_default, params_assigned_from_user)
params_assigned_from_default = set() # type: Set[str]
assign_params(ctx, params_assigned_from_default)
params_assigned_from_user = set() # type: Set[str]
which = None
# Get the invoked command flags
invoked_subcommand_name = (
ctx.invoked_subcommand if hasattr(ctx, "invoked_subcommand") else None
)
if invoked_subcommand_name is not None:
invoked_subcommand = getattr(import_module("dbt.cli.main"), invoked_subcommand_name)
invoked_subcommand.allow_extra_args = True
invoked_subcommand.ignore_unknown_options = True
invoked_subcommand_ctx = invoked_subcommand.make_context(None, sys.argv)
assign_params(invoked_subcommand_ctx, params_assigned_from_default)
if ctx:
# Assign params from ctx
assign_params(ctx, params_assigned_from_default, params_assigned_from_user)
# Get the invoked command flags
invoked_subcommand_name = (
ctx.invoked_subcommand if hasattr(ctx, "invoked_subcommand") else None
)
if invoked_subcommand_name is not None:
invoked_subcommand = getattr(
import_module("dbt.cli.main"), invoked_subcommand_name
)
invoked_subcommand.allow_extra_args = True
invoked_subcommand.ignore_unknown_options = True
invoked_subcommand_ctx = invoked_subcommand.make_context(None, sys.argv)
assign_params(
invoked_subcommand_ctx, params_assigned_from_default, params_assigned_from_user
)
which = invoked_subcommand_name or ctx.info_name
# Load user config if not provided, if available from profiles dir
if not user_config:
profiles_dir = getattr(self, "PROFILES_DIR", None)
user_config = read_user_config(profiles_dir) if profiles_dir else None
user_config = read_user_config(self.get_with_fallback("PROFILES_DIR"))
# Overwrite default assignments with user config if available
if user_config:
param_assigned_from_default_copy = params_assigned_from_default.copy()
for param_assigned_from_default in params_assigned_from_default:
user_config_param_value = getattr(user_config, param_assigned_from_default, None)
if user_config_param_value is not None:
object.__setattr__(
self,
param_assigned_from_default.upper(),
convert_config(param_assigned_from_default, user_config_param_value),
)
param_assigned_from_default_copy.remove(param_assigned_from_default)
params_assigned_from_default = param_assigned_from_default_copy
# Overwrite default assignments with user config
for param_assigned_from_default in params_assigned_from_default:
user_config_param_value = getattr(user_config, param_assigned_from_default, None)
if user_config_param_value is not None:
object.__setattr__(
self,
param_assigned_from_default.upper(),
convert_config(param_assigned_from_default, user_config_param_value),
)
params_assigned_from_user.add(param_assigned_from_default)
# Hard coded flags
object.__setattr__(self, "WHICH", invoked_subcommand_name or ctx.info_name)
object.__setattr__(self, "WHICH", which)
object.__setattr__(self, "MP_CONTEXT", get_context("spawn"))
# Default LOG_PATH from PROJECT_DIR, if available.
if getattr(self, "LOG_PATH", None) is None:
project_dir = getattr(self, "PROJECT_DIR", default_project_dir())
version_check = getattr(self, "VERSION_CHECK", True)
project_dir = self.get_with_fallback("PROJECT_DIR")
version_check = self.get_with_fallback("PROJECT_DIR")
object.__setattr__(self, "LOG_PATH", default_log_path(project_dir, version_check))
# Support console DO NOT TRACK initiave
if os.getenv("DO_NOT_TRACK", "").lower() in ("1", "t", "true", "y", "yes"):
object.__setattr__(self, "SEND_ANONYMOUS_USAGE_STATS", False)
# Check mutual exclusivity once all flags are set
self._assert_mutually_exclusive(
params_assigned_from_default, ["WARN_ERROR", "WARN_ERROR_OPTIONS"]
params_assigned_from_user, ["WARN_ERROR", "WARN_ERROR_OPTIONS"]
)
# Support lower cased access for legacy code
@@ -155,19 +147,50 @@ class Flags:
def __str__(self) -> str:
return str(pf(self.__dict__))
def get_with_fallback(self, name):
return getattr(self, name, self.get_default(name))
def _assert_mutually_exclusive(
self, params_assigned_from_default: Set[str], group: List[str]
self, params_assigned_from_user: Set[str], group: List[str]
) -> None:
"""
Ensure no elements from group are simultaneously provided by a user, as inferred from params_assigned_from_default.
Ensure no elements from group are simultaneously provided by a user, as inferred from params_assigned_from_user.
Raises click.UsageError if any two elements from group are simultaneously provided by a user.
"""
set_flag = None
for flag in group:
flag_set_by_user = flag.lower() not in params_assigned_from_default
flag_set_by_user = flag.lower() in params_assigned_from_user
if flag_set_by_user and set_flag:
raise BadOptionUsage(
flag.lower(), f"{flag.lower()}: not allowed with argument {set_flag.lower()}"
)
elif flag_set_by_user:
set_flag = flag
@classmethod
def get_default(cls, param_name: str):
param_decorator_name = param_name.lower()
# TODO: move log path out of dbt-profile
if param_decorator_name == "log_path":
# Not possible to get project_dir or version_check on uninstantiated class.
return "logs"
try:
param_decorator = getattr(p, param_decorator_name)
except AttributeError:
raise AttributeError(f"'{cls.__name__}' object has no attribute '{param_name}'")
command = param_decorator(Command(None))
param = command.params[0]
default = param.default
# TODO: make ticket for getting out of returning lambdas from params from defaults
if callable(default):
return default()
else:
if param.type:
try:
return param.type.convert(default, param, None)
except (BadParameter, TypeError):
return default
return default

View File

@@ -414,7 +414,6 @@ warn_error = click.option(
"--warn-error",
envvar="DBT_WARN_ERROR",
help="If dbt would normally warn, instead raise an exception. Examples include --select that selects nothing, deprecations, configurations with no associated models, invalid test configurations, and missing sources/refs in tests.",
default=None,
is_flag=True,
)

View File

@@ -83,11 +83,12 @@ def profile(func):
if ctx.obj.get("profile") is None:
flags = ctx.obj["flags"]
# TODO: Generalize safe access to flags.THREADS:
# https://github.com/dbt-labs/dbt-core/issues/6259
threads = getattr(flags, "THREADS", None)
profile = load_profile(
flags.PROJECT_DIR, flags.VARS, flags.PROFILE, flags.TARGET, threads
flags.PROJECT_DIR,
flags.VARS,
flags.PROFILE,
flags.TARGET,
flags.get_with_fallback("THREADS"),
)
ctx.obj["profile"] = profile

View File

@@ -181,15 +181,7 @@ class Profile(HasCredentials):
args_profile_name: Optional[str],
project_profile_name: Optional[str] = None,
) -> str:
# TODO: Duplicating this method as direct copy of the implementation in dbt.cli.resolvers
# dbt.cli.resolvers implementation can't be used because it causes a circular dependency.
# This should be removed and use a safe default access on the Flags module when
# https://github.com/dbt-labs/dbt-core/issues/6259 is closed.
def default_profiles_dir():
from pathlib import Path
return Path.cwd() if (Path.cwd() / "profiles.yml").exists() else Path.home() / ".dbt"
flags = get_flags()
profile_name = project_profile_name
if args_profile_name is not None:
profile_name = args_profile_name
@@ -206,7 +198,7 @@ defined in your profiles.yml file. You can find profiles.yml here:
{profiles_file}/profiles.yml
""".format(
profiles_file=default_profiles_dir()
profiles_file=flags.get_default("PROFILES_DIR")
)
raise DbtProjectError(NO_SUPPLIED_PROFILE_ERROR)
return profile_name

Binary file not shown.

View File

@@ -256,9 +256,7 @@ class DebugTask(BaseTask):
profile_name,
self.args.profile,
self.args.target,
# TODO: Generalize safe access to flags.THREADS:
# https://github.com/dbt-labs/dbt-core/issues/6259
getattr(self.args, "threads", None),
self.args.get_with_fallback("threads"),
)
except dbt.exceptions.DbtConfigError as exc:
profile_errors.append(str(exc))

View File

@@ -95,12 +95,7 @@ class GraphRunnableTask(ConfiguredTask):
def get_selection_spec(self) -> SelectionSpec:
default_selector_name = self.config.get_default_selector_name()
# TODO: The "eager" string below needs to be replaced with programatic access
# to the default value for the indirect selection parameter in
# dbt.cli.params.indirect_selection
#
# Doing that is actually a little tricky, so I'm punting it to a new ticket GH #6397
indirect_selection = getattr(self.args, "INDIRECT_SELECTION", "eager")
indirect_selection = self.args.get_with_fallback("INDIRECT_SELECTION")
if self.args.selector:
# use pre-defined selector (--selector)

View File

@@ -9,6 +9,7 @@ from dbt.cli.main import cli
from dbt.contracts.project import UserConfig
from dbt.cli.flags import Flags
from dbt.helper_types import WarnErrorOptions
from dbt.cli.resolvers import default_project_dir
class TestFlags:
@@ -35,15 +36,21 @@ class TestFlags:
@pytest.mark.parametrize("param", cli.params)
def test_cli_group_flags_from_params(self, run_context, param):
flags = Flags(run_context)
if param.name.upper() in ("VERSION", "LOG_PATH"):
if param.name.upper() == "VERSION":
return
assert hasattr(flags, param.name.upper())
assert getattr(flags, param.name.upper()) == run_context.params[param.name.lower()]
def test_log_path_default(self, run_context):
flags = Flags(run_context)
assert hasattr(flags, "LOG_PATH")
assert getattr(flags, "LOG_PATH") == Path("logs")
# LOG_PATH is set from dbt_project.yml, independently from its param definition
if param.name.upper() == "LOG_PATH":
assert getattr(flags, param.name.upper()) == Path("logs")
else:
assert getattr(flags, param.name.upper()) == run_context.params[param.name.lower()]
def test_expected_duplicate_params_precedence(self):
# default flags
context = self.make_dbt_context("run", ["--version-check", "deps", "--no-version_check"])
flags = Flags(context)
assert flags.VERSION_CHECK
@pytest.mark.parametrize(
"set_stats_param,do_not_track,expected_anonymous_usage_stats",
@@ -194,3 +201,32 @@ class TestFlags:
with pytest.raises(click.BadOptionUsage):
Flags(context, user_config)
@pytest.mark.parametrize(
"param_name,expected_default",
[
("port", 8080),
("warn_error", False),
("warn_error_options", WarnErrorOptions(include=[], exclude=[])),
("vars", {}),
("threads", None),
("state", None),
("project_dir", default_project_dir()),
],
)
def test_get_default(self, param_name, expected_default):
assert Flags.get_default(param_name) == expected_default
assert Flags.get_default(param_name.upper()) == expected_default
@pytest.mark.parametrize("param", cli.params)
def test_get_with_fallback(self, param):
# default flags
flags = Flags()
flags.get_with_fallback(param.name)
flags.get_with_fallback(param.name.upper())
def test_access_invalid_param(self):
# default flags
flags = Flags()
with pytest.raises(AttributeError):
flags.INVALID_PARAM