Uncalled source in pipeline.run() (#3369)

anuunchin
2025-11-24 13:45:12 +01:00
committed by GitHub
parent 033312d373
commit 81ebbcca43
2 changed files with 44 additions and 7 deletions
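
Summary: before this change, passing an uncalled `@dlt.source` function to `pipeline.run()` evaluated it like any other callable, so the returned resource was iterated as a bare iterator and resource-level hints and names were lost. `data_to_sources` now recognizes the uncalled factory (`SourceFactory`), instantiates it, and applies any explicit schema. A minimal sketch of the new behavior, adapted from the updated tests (pipeline name and destination are illustrative):

import itertools

import dlt

@dlt.source
def my_source():
    return dlt.resource(itertools.count(start=1), name="infinity").add_limit(5)

pipeline = dlt.pipeline(pipeline_name="uncalled_demo", destination="duckdb")
pipeline.run(my_source)  # uncalled factory: now equivalent to pipeline.run(my_source())

# rows are counted under the inner resource, not the factory function's name
assert pipeline.last_trace.last_normalize_info.row_counts["infinity"] == 5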


@@ -52,7 +52,7 @@ from dlt.extract.items_transform import ItemTransform
 from dlt.common.metrics import DataWriterAndCustomMetrics
 from dlt.extract.pipe_iterator import PipeIterator
 from dlt.extract.source import DltSource
-from dlt.extract.reference import SourceReference
+from dlt.extract.reference import SourceReference, SourceFactory
 from dlt.extract.resource import DltResource
 from dlt.extract.storage import ExtractStorage
 from dlt.extract.extractors import ObjectExtractor, ArrowExtractor, Extractor, ModelExtractor
@@ -125,6 +125,11 @@ def data_to_sources(
             # many resources with the same name may be present
             r_ = resources.setdefault(data_item.name, [])
             r_.append(data_item)
+        elif isinstance(data_item, SourceFactory):
+            source = data_item()
+            if schema:
+                source.schema = schema
+            sources.append(source)
         else:
             # iterator/iterable/generator
             # create resource first without table template
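
The new branch mirrors the existing `DltSource` handling one step earlier: it calls the factory to obtain a concrete source, then applies the explicit schema override before appending. A sketch of the distinction the `isinstance` check relies on (assuming, as the check implies, that a `@dlt.source`-decorated function remains a `SourceFactory` until called; `numbers_source` is a hypothetical name):

import dlt
from dlt.extract.reference import SourceFactory
from dlt.extract.source import DltSource

@dlt.source
def numbers_source():
    yield dlt.resource([1, 2, 3], name="numbers")

assert isinstance(numbers_source, SourceFactory)  # uncalled: still a factory
assert isinstance(numbers_source(), DltSource)    # called: a concrete source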


@@ -42,7 +42,7 @@ from dlt.common.runtime.collector import DictCollector, LogCollector
 from dlt.common.schema.exceptions import TableIdentifiersFrozen
 from dlt.common.schema.typing import TColumnSchema
 from dlt.common.schema.utils import get_first_column_name_with_prop, new_column, new_table
-from dlt.common.typing import DictStrAny
+from dlt.common.typing import DictStrAny, TDataItems
 from dlt.common.utils import uniq_id
 from dlt.common.schema import Schema
@@ -1848,12 +1848,9 @@ def test_invalid_data_edge_cases() -> None:
     def my_source():
         return dlt.resource(itertools.count(start=1), name="infinity").add_limit(5)
 
-    # this function will be evaluated like any other. it returns resource which in the pipe
-    # is just an iterator and it will be iterated
-    # TODO: we should probably block that behavior
     pipeline.run(my_source)
-    assert pipeline.last_trace.last_normalize_info.row_counts["my_source"] == 5
+    assert pipeline.last_trace.last_normalize_info.row_counts["infinity"] == 5
 
     def res_return():
         return dlt.resource(itertools.count(start=1), name="infinity").add_limit(5)
@@ -1876,7 +1873,7 @@ def test_invalid_data_edge_cases() -> None:
         yield dlt.resource(itertools.count(start=1), name="infinity").add_limit(5)
 
     pipeline.run(my_source_yield)
-    assert pipeline.last_trace.last_normalize_info.row_counts["my_source_yield"] == 5
+    assert pipeline.last_trace.last_normalize_info.row_counts["infinity"] == 5
 
     # pipeline = dlt.pipeline(pipeline_name="invalid", destination=DUMMY_COMPLETE)
     # with pytest.raises(PipelineStepFailed) as pip_ex:
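
Both asserts change for the same reason: once the uncalled factory is instantiated, rows flow through the inner `infinity` resource and are counted under its table rather than under the decorated function's name, which resolves the removed TODO instead of blocking the behavior. Side by side (names from the tests above):

pipeline.run(my_source)        # return-style factory -> row_counts["infinity"]
pipeline.run(my_source_yield)  # yield-style factory  -> row_counts["infinity"]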
@@ -4347,6 +4344,41 @@ def test_signal_extract_step_shutdown(sig: int) -> None:
     _done = True
 
 
+def test_uninitialized_source_factory() -> None:
+    """Test that passing an uncalled source factory preserves resource-level hints such as write_disposition"""
+
+    data = [
+        {"id": 1, "name": "bulbasaur", "size": {"weight": 6.9, "height": 0.7}},
+        {"id": 4, "name": "charmander", "size": {"weight": 8.5, "height": 0.6}},
+        {"id": 25, "name": "pikachu", "size": {"weight": 6, "height": 0.4}},
+    ]
+
+    @dlt.resource(
+        name="pokemon_resource",
+        write_disposition="merge",
+        primary_key="id",
+    )
+    def pokemon() -> TDataItems:
+        yield data
+
+    @dlt.source(name="pokemon_source")
+    def pokemon_source():
+        yield pokemon
+
+    pipeline = dlt.pipeline(
+        pipeline_name=uniq_id(),
+        destination="duckdb",
+        dataset_name="pokemon_data",
+    )
+
+    pipeline.run(pokemon_source)
+    pipeline.run(pokemon_source)
+
+    with pipeline.sql_client() as client:
+        rows = client.execute_sql("SELECT * FROM pokemon_resource")
+        assert len(rows) == 3
+
+
 def test_cleanup() -> None:
     # this must happen after all forked tests (problems with tests teardowns in other tests)
     pass
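
Note: the doubled `pipeline.run(pokemon_source)` is the point of the new test. Because `write_disposition="merge"` and `primary_key="id"` survive the factory path, the second run merges on `id` instead of appending, leaving exactly the 3 seeded rows. A sketch of the equivalence the new branch establishes (the explicit-call form follows from the pre-existing `DltSource` handling):

pipeline.run(pokemon_source)    # uncalled SourceFactory, instantiated by data_to_sources
pipeline.run(pokemon_source())  # explicitly created DltSource; loads identically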