Bring multiple jinja variants through to the parser. (#5794)

Co-authored-by: Danny Jones <51742311+WittierDinosaur@users.noreply.github.com>
Alan Cruickshank
2024-04-25 22:42:43 +01:00
committed by GitHub
parent 7ccbcd2d8f
commit 4830bbc45a
13 changed files with 278 additions and 110 deletions

View File

@@ -171,6 +171,11 @@ def parse(
Returns:
:obj:`Dict[str, Any]` JSON containing the parsed structure.
Note:
In the case of multiple potential variants from the raw source file,
only the first variant is returned by the simple API. For access to
the other variants, use the underlying main API directly.
"""
cfg = config or get_simple_config(
dialect=dialect,
@@ -180,11 +185,14 @@ def parse(
parsed = linter.parse_string(sql)
# If we encounter any parsing errors, raise them in a combined issue.
if parsed.violations:
raise APIParsingError(parsed.violations)
violations = parsed.violations
if violations:
raise APIParsingError(violations)
# Return a JSON representation of the parse tree.
if parsed.tree is None: # pragma: no cover
return {}
record = parsed.tree.as_record(show_raw=True)
# NOTE: For the simple API - only a single variant is returned.
root_variant = parsed.root_variant()
assert root_variant, "Files parsed without violations must have a valid variant"
assert root_variant.tree, "Files parsed without violations must have a valid tree"
record = root_variant.tree.as_record(show_raw=True)
assert record
return record
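
As a minimal sketch of the behaviour described in the note above (queries are illustrative): the simple API returns a record for the root variant only and raises on violations, while the main API keeps every variant.

import sqlfluff
from sqlfluff.api import APIParsingError
from sqlfluff.core import Linter

# Simple API: only the first (root) variant comes back as a dict record.
record = sqlfluff.parse("SELECT 1", dialect="ansi")
print(record)

# Unparsable input raises APIParsingError (see the new test below).
try:
    sqlfluff.parse("THIS IS NOT SQL", dialect="ansi")
except APIParsingError as err:
    print(err)

# Main API: every variant stays accessible on the returned ParsedString.
parsed = Linter(dialect="ansi").parse_string("SELECT 1")
for variant in parsed.parsed_variants:
    print(variant.tree.stringify() if variant.tree else "<no tree>")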

View File

@@ -1364,19 +1364,24 @@ def parse(
output_stream, bench, code_only, total_time, verbose, parsed_strings
)
else:
parsed_strings_dict = [
dict(
filepath=linted_result.fname,
segments=(
linted_result.tree.as_record(
code_only=code_only, show_raw=True, include_meta=include_meta
)
if linted_result.tree
else None
),
parsed_strings_dict = []
for parsed_string in parsed_strings:
# TODO: Multiple variants aren't yet supported here in the non-human
# output of the parse command.
root_variant = parsed_string.root_variant()
# Updating violation count ensures the correct return code below.
violations_count += len(parsed_string.violations)
if root_variant:
assert root_variant.tree
segments = root_variant.tree.as_record(
code_only=code_only, show_raw=True, include_meta=include_meta
)
else:
# Parsing failed - return null for segments.
segments = None
parsed_strings_dict.append(
{"filepath": parsed_string.fname, "segments": segments}
)
for linted_result in parsed_strings
]
if format == FormatType.yaml.value:
# For yaml dumping always dump double quoted strings if they contain

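As a rough sketch of consuming the machine-readable records assembled above, assuming the output of something like `sqlfluff parse my_model.sql --format json` was saved to a file (file names here are illustrative):

import json

# Illustrative: records produced by the yaml/json branch of the parse command.
with open("parsed.json") as f:
    records = json.load(f)

for record in records:
    if record["segments"] is None:
        # A null segments entry means templating or parsing failed for that file.
        print(f"{record['filepath']}: failed to parse")
    else:
        print(f"{record['filepath']}: parsed OK")
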
View File

@@ -635,25 +635,53 @@ class OutputStreamFormatter:
verbose: int,
parsed_strings: List[ParsedString],
) -> int:
"""Used by human formatting during the parse."""
"""Used by human formatting during the `sqlfluff parse` command."""
violations_count = 0
timing = TimingSummary()
for parsed_string in parsed_strings:
timing.add(parsed_string.time_dict)
if parsed_string.tree:
output_stream.write(parsed_string.tree.stringify(code_only=code_only))
else:
num_variants = len(parsed_string.parsed_variants)
root_variant = parsed_string.root_variant()
if not root_variant:
# TODO: Make this prettier
output_stream.write("...Failed to Parse...") # pragma: no cover
output_stream.write(
self.colorize("...Failed to Parse...", Color.red)
) # pragma: no cover
elif num_variants == 1:
# Backward compatible single parse
assert root_variant.tree
output_stream.write(root_variant.tree.stringify(code_only=code_only))
else:
# Multi variant parse setup.
output_stream.write(
self.colorize(
f"SQLFluff parsed {num_variants} variants of this file",
Color.blue,
)
)
for idx, variant in enumerate(parsed_string.parsed_variants):
output_stream.write(
self.colorize(
f"Variant {idx + 1}:",
Color.blue,
)
)
if variant.tree:
output_stream.write(variant.tree.stringify(code_only=code_only))
else: # pragma: no cover
output_stream.write(
self.colorize("...Failed to Parse...", Color.red)
)
violations_count += len(parsed_string.violations)
if parsed_string.violations:
violations = parsed_string.violations
violations_count += len(violations)
if violations:
output_stream.write("==== parsing violations ====") # pragma: no cover
for v in parsed_string.violations:
for v in violations:
output_stream.write(self.format_violation(v)) # pragma: no cover
if parsed_string.violations:
if violations:
output_stream.write(
self.format_dialect_warning(parsed_string.config.get("dialect"))
)
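
For a file which renders to two Jinja variants, the human-readable output added above looks roughly like this (parse trees elided, colours omitted):

SQLFluff parsed 2 variants of this file
Variant 1:
<stringified parse tree for variant 1>
Variant 2:
<stringified parse tree for variant 2>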

View File

@@ -1,9 +1,14 @@
"""Defines small container classes to hold intermediate results during linting."""
from typing import Any, Dict, List, NamedTuple, Optional, Tuple
from typing import Any, Dict, List, NamedTuple, Optional, Tuple, Union
from sqlfluff.core.config import FluffConfig
from sqlfluff.core.errors import SQLBaseError, SQLTemplaterError
from sqlfluff.core.errors import (
SQLBaseError,
SQLLexError,
SQLParseError,
SQLTemplaterError,
)
from sqlfluff.core.parser.segments.base import BaseSegment
from sqlfluff.core.templaters import TemplatedFile
@@ -34,25 +39,91 @@ class RenderedFile(NamedTuple):
source_str: str
class ParsedVariant(NamedTuple):
"""An object to store the result of parsing a single TemplatedFile.
Args:
templated_file (:obj:`TemplatedFile`): The templated file, containing
the details of this rendered variant.
tree (:obj:`BaseSegment`): The segment structure representing the
parsed file. If parsing fails due to an unrecoverable
violation then this will be `None`.
lexing_violations (:obj:`list` of :obj:`SQLLexError`): Any violations
raised during the lexing phase.
parsing_violations (:obj:`list` of :obj:`SQLParseError`): Any violations
raised during the parsing phase.
"""
templated_file: TemplatedFile
tree: Optional[BaseSegment]
lexing_violations: List[SQLLexError]
parsing_violations: List[SQLParseError]
def violations(self) -> List[Union[SQLLexError, SQLParseError]]:
"""Returns the combined lexing and parsing violations for this variant."""
return [*self.lexing_violations, *self.parsing_violations]
class ParsedString(NamedTuple):
"""An object to store the result of parsing a string.
Args:
`parsed` is a segment structure representing the parsed file. If
parsing fails due to an unrecoverable violation then we will
return None.
`violations` is a :obj:`list` of violations so far, which will either be
templating, lexing or parsing violations at this stage.
`time_dict` is a :obj:`dict` containing timings for how long each step
parsed_variants (:obj:`list` of :obj:`ParsedVariant`): The parsed
variants of this file. Empty if parsing or templating failed.
templating_violations (:obj:`list` of :obj:`SQLTemplaterError`):
Any violations raised during the templating phase. Any violations
raised during lexing or parsing can be found in the
`parsed_variants`, or accessed using the `.violations()` method
which combines all the violations.
time_dict (:obj:`dict`): Contains timings for how long each step
took in the process.
`templated_file` is a :obj:`TemplatedFile` containing the details
of the templated file. If templating fails, this will return None.
config (:obj:`FluffConfig`): The active config for this file,
including any parsed in-file directives.
fname (str): The name of the file. Used mostly for user feedback.
source_str (str): The raw content of the source file.
"""
tree: Optional[BaseSegment]
violations: List[SQLBaseError]
parsed_variants: List[ParsedVariant]
templating_violations: List[SQLTemplaterError]
time_dict: Dict[str, Any]
templated_file: Optional[TemplatedFile]
config: FluffConfig
fname: str
source_str: str
@property
def violations(self) -> List[SQLBaseError]:
"""Returns the combined violations for this parsed file.
NOTE: This is implemented as a property for backward compatibility.
"""
return [
*self.templating_violations,
*(v for variant in self.parsed_variants for v in variant.violations()),
]
def root_variant(self) -> Optional[ParsedVariant]:
"""Returns the root variant if successfully parsed, otherwise None."""
if not self.parsed_variants:
# In the case of a fatal templating error, there will be no valid
# variants. Return None.
return None
root_variant = self.parsed_variants[0]
if not root_variant.tree:
# In the case of a parsing fail, there will be a variant, but it will
# have failed to parse and so will have a null tree. Count this as
# an inappropriate variant to return, so return None.
return None
return root_variant
@property
def tree(self) -> BaseSegment:
"""Return the main variant tree.
NOTE: This property is primarily for testing convenience and therefore
asserts that parsing has been successful. If this isn't appropriate
for the given use case, then don't use this property.
"""
assert self.parsed_variants, "No successfully parsed variants."
root_variant = self.parsed_variants[0]
assert root_variant.tree, "Root variant not successfully parsed."
return root_variant.tree
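
A short sketch of how the reworked containers are intended to be consumed downstream (dialect and query are illustrative):

from sqlfluff.core import Linter

parsed = Linter(dialect="ansi").parse_string("SELECT 1")

# Backward-compatible: templating, lexing and parsing violations combined.
print(len(parsed.violations))

# The root variant is the first successfully parsed variant, or None.
root = parsed.root_variant()
if root:
    print(root.tree.stringify())

# Each variant carries its own templated file and phase-specific violations.
for variant in parsed.parsed_variants:
    print(variant.templated_file.templated_str, variant.violations())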

View File

@@ -32,7 +32,12 @@ from sqlfluff.core.errors import (
SQLTemplaterError,
)
from sqlfluff.core.helpers.file import get_encoding
from sqlfluff.core.linter.common import ParsedString, RenderedFile, RuleTuple
from sqlfluff.core.linter.common import (
ParsedString,
ParsedVariant,
RenderedFile,
RuleTuple,
)
from sqlfluff.core.linter.fix import apply_fixes, compute_anchor_edit_info
from sqlfluff.core.linter.linted_dir import LintedDir
from sqlfluff.core.linter.linted_file import (
@@ -304,49 +309,55 @@ class Linter:
parse_statistics: bool = False,
) -> ParsedString:
"""Parse a rendered file."""
t0 = time.monotonic()
violations = cast(List[SQLBaseError], rendered.templater_violations)
tokens: Optional[Sequence[BaseSegment]]
# TODO: We're limiting ourselves to only the first variant for now.
# We'll eventually parse more variants here.
if rendered.templated_variants:
_root_variant = rendered.templated_variants[0]
tokens, lex_errors = cls._lex_templated_file(_root_variant, rendered.config)
violations += lex_errors
else:
# Having no TemplatedFile to parse implies that templating failed.
# There will be no file or tokens to parse, but we'll still return
# a ParsedFile object and associated timings.
_root_variant = None
tokens = None
parsed_variants: List[ParsedVariant] = []
_lexing_time = 0.0
_parsing_time = 0.0
t1 = time.monotonic()
linter_logger.info("PARSING (%s)", rendered.fname)
if tokens:
parsed, parse_errors = cls._parse_tokens(
tokens,
rendered.config,
fname=rendered.fname,
parse_statistics=parse_statistics,
for idx, variant in enumerate(rendered.templated_variants):
t0 = time.monotonic()
linter_logger.info("Parse Rendered. Lexing Variant %s", idx)
tokens, lex_errors = cls._lex_templated_file(variant, rendered.config)
t1 = time.monotonic()
linter_logger.info("Parse Rendered. Parsing Variant %s", idx)
if tokens:
parsed, parse_errors = cls._parse_tokens(
tokens,
rendered.config,
fname=rendered.fname,
parse_statistics=parse_statistics,
)
else: # pragma: no cover
parsed = None
parse_errors = []
_lt = t1 - t0
_pt = time.monotonic() - t1
linter_logger.info(
"Parse Rendered. Variant %s. Lex in %s. Parse in %s.", idx, _lt, _pt
)
violations += parse_errors
else:
parsed = None
parsed_variants.append(
ParsedVariant(
variant,
parsed,
lex_errors,
parse_errors,
)
)
_lexing_time += _lt
_parsing_time += _pt
time_dict = {
**rendered.time_dict,
"lexing": t1 - t0,
"parsing": time.monotonic() - t1,
"lexing": _lexing_time,
"parsing": _parsing_time,
}
return ParsedString(
parsed,
violations,
time_dict,
_root_variant,
rendered.config,
rendered.fname,
rendered.source_str,
parsed_variants=parsed_variants,
templating_violations=rendered.templater_violations,
time_dict=time_dict,
config=rendered.config,
fname=rendered.fname,
source_str=rendered.source_str,
)
@classmethod
@@ -592,8 +603,15 @@ class Linter:
"""Lint a ParsedString and return a LintedFile."""
violations = parsed.violations
time_dict = parsed.time_dict
tree: Optional[BaseSegment]
if parsed.tree:
tree: Optional[BaseSegment] = None
# TODO: Eventually enable linting of more than just the first variant.
if parsed.parsed_variants:
tree = parsed.parsed_variants[0].tree
variant = parsed.parsed_variants[0].templated_file
else:
variant = None
if tree:
t0 = time.monotonic()
linter_logger.info("LINTING (%s)", parsed.fname)
(
@@ -602,12 +620,12 @@ class Linter:
ignore_mask,
rule_timings,
) = cls.lint_fix_parsed(
parsed.tree,
tree,
config=parsed.config,
rule_pack=rule_pack,
fix=fix,
fname=parsed.fname,
templated_file=parsed.templated_file,
templated_file=variant,
formatter=formatter,
)
# Update the timing dict
@@ -617,8 +635,6 @@ class Linter:
# than any generated during the fixing cycle.
violations += initial_linting_errors
else:
# If no parsed tree, set to None
tree = None
ignore_mask = None
rule_timings = []
if not parsed.config.get("disable_noqa"):
@@ -649,7 +665,7 @@ class Linter:
FileTimings(time_dict, rule_timings),
tree,
ignore_mask=ignore_mask,
templated_file=parsed.templated_file,
templated_file=variant,
encoding=encoding,
)
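
To make the reworked flow concrete, a minimal sketch of the render-then-parse path changed above (query and config are illustrative; linting still only consumes the first variant, per the TODO):

from sqlfluff.core import FluffConfig, Linter

cfg = FluffConfig(overrides={"dialect": "ansi"})
linter = Linter(config=cfg)

# Render first, then lex and parse every templated variant of the rendered file.
rendered = linter.render_string("SELECT 1", fname="<STR>", config=cfg, encoding="utf-8")
parsed = linter.parse_rendered(rendered)

# One ParsedVariant per templated variant; timings are now summed across variants.
print(len(parsed.parsed_variants))
print(parsed.time_dict["lexing"], parsed.time_dict["parsing"])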

View File

@@ -117,13 +117,13 @@ def assert_rule_pass_in_sql(code, sql, configs=None, msg=None):
# This section is mainly for aid in debugging.
rendered = linter.render_string(sql, fname="<STR>", config=cfg, encoding="utf-8")
parsed = linter.parse_rendered(rendered)
if parsed.violations:
tree = parsed.tree # Delegate assertions to the `.tree` property
violations = parsed.violations
if violations:
if msg:
print(msg) # pragma: no cover
assert parsed.tree
pytest.fail(parsed.violations[0].desc() + "\n" + parsed.tree.stringify())
assert parsed.tree
print(f"Parsed:\n {parsed.tree.stringify()}")
pytest.fail(violations[0].desc() + "\n" + tree.stringify())
print(f"Parsed:\n {tree.stringify()}")
# Note that lint_string() runs the templater and parser again, in order to
# test the whole linting pipeline in the same way that users do. In other

View File

@@ -5,6 +5,7 @@ import json
import pytest
import sqlfluff
from sqlfluff.api import APIParsingError
from sqlfluff.core.errors import SQLFluffUserError
my_bad_query = "SeLEct *, 1, blah as fOO from myTable"
@@ -562,3 +563,16 @@ def test__api__invalid_dialect():
)
assert str(err.value) == "Error: Unknown dialect 'not_a_real_dialect'"
def test__api__parse_exceptions():
"""Test parse behaviour with errors."""
# Parsable content
result = sqlfluff.parse("SELECT 1")
assert result
# Templater fail
with pytest.raises(APIParsingError):
sqlfluff.parse('SELECT {{ 1 > "a"}}')
# Templater success but parsing fail
with pytest.raises(APIParsingError):
sqlfluff.parse("THIS IS NOT SQL")

View File

@@ -470,6 +470,13 @@ def test__cli__command_render_stdin():
(parse, ["-n", "test/fixtures/cli/passing_b.sql", "--format", "yaml"]),
# Check parsing with no output (used mostly for testing)
(parse, ["-n", "test/fixtures/cli/passing_b.sql", "--format", "none"]),
# Parsing with variants
(
parse,
[
"test/fixtures/cli/jinja_variants.sql",
],
),
# Check the benching commands
(parse, ["-n", "test/fixtures/cli/passing_timing.sql", "--bench"]),
(lint, ["-n", "test/fixtures/cli/passing_timing.sql", "--bench"]),
@@ -719,6 +726,22 @@ def test__cli__command_lint_parse(command):
),
2,
),
# Test machine format parse command with an unparsable file.
(
(
parse,
["test/fixtures/linter/parse_lex_error.sql", "-f", "yaml"],
),
1,
),
# Test machine format parse command with a fatal templating error.
(
(
parse,
["test/fixtures/cli/jinja_fatal_fail.sql", "-f", "yaml"],
),
1,
),
],
)
def test__cli__command_lint_parse_with_retcode(command, ret_code):

View File

@@ -157,7 +157,7 @@ def test_segments_recursive_crawl():
linter = Linter(dialect="ansi")
parsed = linter.parse_string(sql)
functional_tree = segments.Segments(parsed.tree)
functional_tree = segments.Segments(parsed.root_variant().tree)
assert len(functional_tree.recursive_crawl("common_table_expression")) == 1
assert len(functional_tree.recursive_crawl("table_reference")) == 3

View File

@@ -20,6 +20,7 @@ from jinja2.parser import Parser
from sqlfluff.core import FluffConfig, Linter
from sqlfluff.core.errors import SQLFluffSkipFile, SQLFluffUserError, SQLTemplaterError
from sqlfluff.core.parser import BaseSegment
from sqlfluff.core.templaters import JinjaTemplater
from sqlfluff.core.templaters.base import RawFileSlice, TemplatedFile
from sqlfluff.core.templaters.jinja import DummyUndefined
@@ -42,6 +43,15 @@ JINJA_MACRO_CALL_SQL = (
)
def get_parsed(path: str) -> BaseSegment:
"""Testing helper to parse paths."""
linter = Linter()
# Get the first file matching the path string
first_path = next(linter.parse_path(path))
# Delegate parse assertions to the `.tree` property
return first_path.tree
@pytest.mark.parametrize(
"instr, expected_outstr",
[
@@ -657,29 +667,25 @@ def test__templater_jinja_lint_empty():
"""
lntr = Linter(dialect="ansi")
parsed = lntr.parse_string(in_str='{{ "" }}')
assert parsed.templated_file.source_str == '{{ "" }}'
assert parsed.templated_file.templated_str == ""
parsed_variant = parsed.parsed_variants[0]
assert parsed_variant.templated_file.source_str == '{{ "" }}'
assert parsed_variant.templated_file.templated_str == ""
# Get the types of the segments
print(f"Segments: {parsed.tree.raw_segments}")
seg_types = [seg.get_type() for seg in parsed.tree.raw_segments]
print(f"Segments: {parsed_variant.tree.raw_segments}")
seg_types = [seg.get_type() for seg in parsed_variant.tree.raw_segments]
assert seg_types == ["placeholder", "end_of_file"]
def assert_structure(yaml_loader, path, code_only=True, include_meta=False):
"""Check that a parsed sql file matches the yaml file with the same name."""
lntr = Linter()
p = list(lntr.parse_path(path + ".sql"))
parsed = p[0][0]
if parsed is None:
print(p)
raise RuntimeError(p[0][1])
parsed = get_parsed(path + ".sql")
# Whitespace is important here to test how that's treated
tpl = parsed.to_tuple(code_only=code_only, show_raw=True, include_meta=include_meta)
# Check nothing unparsable
if "unparsable" in parsed.type_set():
print(parsed.stringify())
raise ValueError("Input file is unparsable.")
_hash, expected = yaml_loader(path + ".yml")
_, expected = yaml_loader(path + ".yml")
assert tpl == expected
@@ -755,9 +761,7 @@ def test__templater_jinja_block_matching(caplog):
caplog.set_level(logging.DEBUG, logger="sqlfluff.lexer")
path = "test/fixtures/templater/jinja_l_metas/002.sql"
# Parse the file.
p = list(Linter().parse_path(path))
parsed = p[0][0]
assert parsed
parsed = get_parsed(path)
# We only care about the template elements
template_segments = [
seg

View File

@@ -220,13 +220,12 @@ def test__dialect__ansi_parse_indented_joins(sql_string, indented_joins, meta_lo
)
)
parsed = lnt.parse_string(sql_string)
tree = parsed.tree
# Check that there's nothing unparsable
assert "unparsable" not in parsed.tree.type_set()
assert "unparsable" not in tree.type_set()
# Check all the segments that *should* be metas, ARE.
# NOTE: This includes the end of file marker.
res_meta_locs = tuple(
idx
for idx, raw_seg in enumerate(parsed.tree.get_raw_segments())
if raw_seg.is_meta
idx for idx, raw_seg in enumerate(tree.get_raw_segments()) if raw_seg.is_meta
)
assert res_meta_locs == meta_loc

View File

@@ -72,8 +72,8 @@ def test__dialect__base_file_parse(dialect, file):
# Check we're all there.
assert parsed.tree.raw == raw
# Check that there's nothing unparsable
typs = parsed.tree.type_set()
assert "unparsable" not in typs
types = parsed.tree.type_set()
assert "unparsable" not in types
# When testing the validity of fixes we re-parse sections of the file.
# To ensure this is safe - here we re-parse the unfixed file to ensure
# it's still valid even in the case that no fixes have been applied.
@@ -107,8 +107,7 @@ def test__dialect__base_broad_fix(
parsed: Optional[ParsedString] = lex_and_parse(config_overrides, raw)
if not parsed: # Empty file case
return
else:
print(parsed.tree.stringify())
print(parsed.tree.stringify())
config = FluffConfig(overrides=config_overrides)
linter = Linter(config=config)

View File

@@ -0,0 +1 @@
select {{ 1 > "foo"}}