Compare commits

...

2 Commits

Author SHA1 Message Date
Jeremy Cohen
f7705a3e4b Unit test for yaml loading 2022-06-14 11:39:01 +02:00
jeremyyeo
92847ce90f fix yaml parsing 2022-06-08 22:39:04 +12:00
2 changed files with 50 additions and 4 deletions

View File

@@ -31,6 +31,7 @@ class UniqueKeyLoader(SafeLoader):
def construct_mapping(self, node, deep=False):
mapping = set()
self.flatten_mapping(node) # This processes yaml anchors / merge keys (<<).
for key_node, value_node in node.value:
key = self.construct_object(key_node, deep=deep)
if key in mapping:
@@ -68,13 +69,16 @@ def contextualized_yaml_error(raw_contents, error):
)
def safe_load(contents) -> Optional[Dict[str, Any]]:
def safe_load(contents, unique=False) -> Optional[Dict[str, Any]]:
if unique:
return yaml.load(contents, Loader=UniqueKeyLoader)
else:
return yaml.load(contents, Loader=SafeLoader)
def load_yaml_text(contents, path=None):
try:
return safe_load(contents)
return safe_load(contents, unique=True)
except (yaml.scanner.ScannerError, yaml.YAMLError) as e:
if hasattr(e, "problem_mark"):
error = contextualized_yaml_error(contents, e)
@@ -84,5 +88,7 @@ def load_yaml_text(contents, path=None):
raise dbt.exceptions.ValidationException(error)
except dbt.exceptions.DuplicateYamlKeyException as e:
# TODO: We may want to raise an exception instead of a warning in the future.
if path:
e.msg = f"{e} {path.searched_path}/{path.relative_path}."
dbt.exceptions.warn_or_raise(e, log_fmt=warning_tag("{}"))
return safe_load(contents)

View File

@@ -0,0 +1,40 @@
from dbt.clients.yaml_helper import load_yaml_text
import unittest
profile_with_anchor = """
# profiles.yml
postgres:
outputs:
dev: &profile
type: postgres
host: localhost
user: root
password: password
schema: public
database: postgres
port: 5432
threads: 8
prod:
<<: *profile
uat: *profile
target: dev
"""
project_with_duped_var = """
# dbt_project.yml
vars:
foo: bar
foo: bar
"""
class YamlLoadingUnitTest(unittest.TestCase):
def test_load_yaml_anchors(self):
profile_yml = load_yaml_text(profile_with_anchor)
assert(profile_yml)
def test_load_duped_var(self):
dbt_project_yml = load_yaml_text(project_with_duped_var)
assert(dbt_project_yml)