Source code for astro.resolver.resolver

"""Canonical ID resolver with vectorized Polars operations."""

from __future__ import annotations

import uuid
from datetime import date
from pathlib import Path

import polars as pl

from astro.resolver.hashing import compute_hash_columns
from astro.resolver.models import (
    EntryStatus,
    HashGroupsConfig,
    ResolverConfig,
    changed_column_name,
    hash_column_name,
)
from astro.resolver.store import (
    CANONICAL_ID_COLUMN,
    LAST_CHANGED_DATE_COLUMN,
    SOURCE_KEY_COLUMN,
    UPDATE_DATES_COLUMN,
    ResolverStore,
)

# Temporary column holding ``"{namespace}:{source_key}"`` for each batch row; it is
# the key joined against the store and written back as the store's source key.
_NAMESPACED_SOURCE_KEY_COLUMN = "_astro_source_key"
# Suffix given to store columns after joining so they don't clash with the batch's
# freshly computed columns of the same name (e.g. the per-group hash columns).
_STORED_SUFFIX = "_stored"


class CanonicalIdResolver:
    """Map source keys to canonical UUIDs and detect grouped field changes.

    Stores persistent state at ``{pipeline_dir}/.persistent/{name}.parquet``.
    """

    # Output column carrying the per-row ``EntryStatus`` value.
    _STATUS_COLUMN = "status"
    # Scratch column holding UUIDs generated for keys that are not yet in the store.
    _NEW_CANONICAL_ID_COLUMN = "_astro_new_canonical_id"
    # Scratch column recording input row positions so output order can be restored.
    _ROW_INDEX_COLUMN = "_astro_row_index"

    def __init__(
        self,
        pipeline_dir: Path,
        name: str,
        hash_groups: HashGroupsConfig,
    ) -> None:
        """Create a resolver scoped to a named persistent store.

        Args:
            pipeline_dir: Pipeline working directory.
            name: Store name (for example ``establishments``).
            hash_groups: Mapping of group names to ``"*all"`` or field name lists.
        """
        # NOTE(review): ``_config`` is not read anywhere in this class; presumably it
        # is kept so that ``ResolverConfig`` validates the arguments — confirm.
        self._config = ResolverConfig(name=name, hash_groups=hash_groups)
        self.hash_groups = hash_groups
        self._store = ResolverStore(pipeline_dir, name, hash_groups)

    @property
    def store_path(self) -> Path:
        """Location of the persistent store backing this resolver."""
        return self._store.path

    def resolve(
        self,
        data: pl.DataFrame,
        *,
        source_key_column: str,
        namespace: str,
        run_date: date,
        exclude_columns: frozenset[str] = frozenset(),
    ) -> pl.DataFrame:
        """Resolve canonical IDs and change flags for each row in ``data``.

        Keys already in the store keep their stored canonical ID. Unseen keys get
        a fresh UUID4 string, shared by every row of the same key in the batch.
        The store is updated with the batch's hashes and change dates only after
        the output frame has been built successfully.

        Args:
            data: Input DataFrame containing ``source_key_column``.
            source_key_column: Column with pipeline-provided identifiers.
            namespace: Prefix for stored keys as ``{namespace}:{source_key}``.
            run_date: Date used for change tracking.
            exclude_columns: Columns excluded from ``"*all"`` hash groups.

        Returns:
            DataFrame with the columns of ``data``, in the input row order, followed
            by ``canonical_id``, ``status``, and one ``{group}_changed`` column per
            hash group.

        Raises:
            ValueError: If ``namespace`` is empty, ``source_key_column`` is missing
                or contains nulls, a hash group references a missing column, or
                ``data`` already contains one of the output column names.
        """
        self._validate_input(data, source_key_column=source_key_column, namespace=namespace)
        stored = self._stored_name
        hash_columns = {
            group_name: hash_column_name(group_name) for group_name in self.hash_groups
        }

        working = compute_hash_columns(
            data,
            self.hash_groups,
            source_key_column=source_key_column,
            exclude_columns=exclude_columns,
        )
        working = working.with_columns(
            pl.concat_str(
                [pl.lit(f"{namespace}:"), pl.col(source_key_column).cast(pl.String)]
            ).alias(_NAMESPACED_SOURCE_KEY_COLUMN)
        )
        # Added only after hashing so it can never leak into an ``"*all"`` group.
        # Polars does not promise that joins keep row order, so this index is used
        # to hand rows back in the same order as ``data``.
        working = working.with_row_index(name=self._ROW_INDEX_COLUMN)

        # Suffix the store columns so they sit beside the batch's freshly computed
        # hash columns (which share their names) after the join.
        store = self._store.load()
        store_for_join = store.select(
            pl.col(SOURCE_KEY_COLUMN),
            *[
                pl.col(column).alias(stored(column))
                for column in (
                    CANONICAL_ID_COLUMN,
                    LAST_CHANGED_DATE_COLUMN,
                    UPDATE_DATES_COLUMN,
                    *hash_columns.values(),
                )
            ],
        )
        joined = working.join(
            store_for_join,
            left_on=_NAMESPACED_SOURCE_KEY_COLUMN,
            right_on=SOURCE_KEY_COLUMN,
            how="left",
        )

        # A key is new when the store has no canonical ID for it.
        is_new = pl.col(stored(CANONICAL_ID_COLUMN)).is_null()
        # A group counts as changed for new keys, or when its hash differs from the
        # stored hash. A missing stored hash is replaced by a sentinel so the
        # comparison does not evaluate to null.
        changed_by_column: dict[str, pl.Expr] = {
            changed_column_name(group_name): is_new
            | (pl.col(hash_column) != pl.col(stored(hash_column)).fill_null("__missing__"))
            for group_name, hash_column in hash_columns.items()
        }
        changed_flags = list(changed_by_column.values())
        any_hash_changed = pl.any_horizontal(changed_flags) if changed_flags else is_new
        all_hashes_match = ~is_new & ~any_hash_changed

        joined = joined.with_columns(
            [flag.alias(column) for column, flag in changed_by_column.items()]
        ).with_columns(
            pl.when(is_new)
            .then(pl.lit(EntryStatus.NEW.value))
            .when(all_hashes_match)
            .then(pl.lit(EntryStatus.UNCHANGED.value))
            .otherwise(pl.lit(EntryStatus.CHANGED.value))
            .alias(self._STATUS_COLUMN)
        )

        # Existing keys keep their stored ID; new keys take the generated one.
        joined = self._assign_new_canonical_ids(joined, is_new=is_new)
        joined = joined.with_columns(
            pl.coalesce(
                pl.col(stored(CANONICAL_ID_COLUMN)),
                pl.col(self._NEW_CANONICAL_ID_COLUMN),
            ).alias(CANONICAL_ID_COLUMN)
        )

        # Change tracking: new or changed rows record ``run_date``. Changed rows
        # append it to their stored history, and new rows start a one-item history.
        run_date_expr = pl.lit(run_date, dtype=pl.Date)
        stored_update_dates = pl.col(stored(UPDATE_DATES_COLUMN))
        joined = joined.with_columns(
            pl.when(is_new | any_hash_changed)
            .then(run_date_expr)
            .otherwise(pl.col(stored(LAST_CHANGED_DATE_COLUMN)))
            .alias(LAST_CHANGED_DATE_COLUMN),
            pl.when(is_new)
            .then(pl.concat_list([run_date_expr]))
            .when(any_hash_changed & ~is_new)
            .then(
                stored_update_dates.fill_null(pl.lit([], dtype=pl.List(pl.Date))).list.concat(
                    run_date_expr
                )
            )
            .otherwise(stored_update_dates)
            .alias(UPDATE_DATES_COLUMN),
        )

        output_columns = [
            *data.columns,
            CANONICAL_ID_COLUMN,
            self._STATUS_COLUMN,
            *changed_by_column,
        ]
        # Build the caller's result before persisting anything, so a failure here
        # cannot leave the store updated for a batch the caller never received.
        result = joined.sort(self._ROW_INDEX_COLUMN).select(output_columns)

        store_updates = joined.select(
            pl.col(_NAMESPACED_SOURCE_KEY_COLUMN).alias(SOURCE_KEY_COLUMN),
            pl.col(CANONICAL_ID_COLUMN),
            pl.col(LAST_CHANGED_DATE_COLUMN),
            pl.col(UPDATE_DATES_COLUMN),
            *[pl.col(hash_column) for hash_column in hash_columns.values()],
        )
        self._upsert_store(store, store_updates)
        return result

    @staticmethod
    def _stored_name(column: str) -> str:
        """Return the name a store column takes after being joined onto the batch."""
        return f"{column}{_STORED_SUFFIX}"

    def _assign_new_canonical_ids(self, joined: pl.DataFrame, *, is_new: pl.Expr) -> pl.DataFrame:
        """Attach one freshly generated UUID4 string per distinct new namespaced key.

        Args:
            joined: Batch joined with the store; must contain the namespaced key column.
            is_new: Expression selecting rows whose key is absent from the store.

        Returns:
            ``joined`` plus ``_astro_new_canonical_id``. This holds the generated ID
            for rows matched by ``is_new`` and null for every other row.
        """
        new_keys = (
            joined.filter(is_new)
            .get_column(_NAMESPACED_SOURCE_KEY_COLUMN)
            .unique()
            .to_list()
        )
        if not new_keys:
            return joined.with_columns(
                pl.lit(None, dtype=pl.String).alias(self._NEW_CANONICAL_ID_COLUMN)
            )
        uuid_assignments = pl.DataFrame(
            {
                _NAMESPACED_SOURCE_KEY_COLUMN: new_keys,
                self._NEW_CANONICAL_ID_COLUMN: [str(uuid.uuid4()) for _ in new_keys],
            }
        )
        return joined.join(uuid_assignments, on=_NAMESPACED_SOURCE_KEY_COLUMN, how="left")

    def _upsert_store(self, store: pl.DataFrame, store_updates: pl.DataFrame) -> None:
        """Replace the store rows for the batch's keys and save the merged store.

        Store rows whose key is not in the batch are kept unchanged. If the batch
        holds several rows for one key, the last one wins. An empty batch leaves
        the store untouched and does not save it.
        """
        if store_updates.is_empty():
            return
        batch_keys = store_updates.get_column(SOURCE_KEY_COLUMN)
        untouched_rows = store.filter(~pl.col(SOURCE_KEY_COLUMN).is_in(batch_keys))
        merged_store = pl.concat([untouched_rows, store_updates], how="vertical_relaxed").unique(
            SOURCE_KEY_COLUMN,
            keep="last",
        )
        self._store.save(merged_store)

    def _validate_input(
        self,
        data: pl.DataFrame,
        *,
        source_key_column: str,
        namespace: str,
    ) -> None:
        """Reject input that ``resolve`` cannot process safely.

        Raises:
            ValueError: If ``namespace`` is empty, ``source_key_column`` is missing
                or contains nulls, a hash group names a column absent from ``data``,
                or ``data`` already has a column that ``resolve`` emits.
        """
        if not namespace:
            raise ValueError("namespace must not be empty.")
        if source_key_column not in data.columns:
            raise ValueError(f"source_key_column {source_key_column!r} not found in data.")

        available_columns = set(data.columns)
        missing_fields: set[str] = set()
        for fields in self.hash_groups.values():
            if fields == "*all":
                continue
            missing_fields.update(set(fields) - available_columns)
        if missing_fields:
            missing_columns = ", ".join(sorted(missing_fields))
            raise ValueError(f"Hash groups reference missing columns: {missing_columns}")

        # A null key cannot match anything, so it would get a null canonical ID and
        # write a null-key row into the store.
        null_key_count = data.get_column(source_key_column).null_count()
        if null_key_count:
            raise ValueError(
                f"source_key_column {source_key_column!r} contains "
                f"{null_key_count} null value(s)."
            )

        # These names are appended to the output; an existing input column with the
        # same name would make the final selection ambiguous.
        output_names = {
            CANONICAL_ID_COLUMN,
            self._STATUS_COLUMN,
            *(changed_column_name(group_name) for group_name in self.hash_groups),
        }
        clashing = output_names & available_columns
        if clashing:
            clashing_columns = ", ".join(sorted(clashing))
            raise ValueError(f"data already contains resolver output columns: {clashing_columns}")