"""Async implementation of Redis checkpoint saver."""

from __future__ import annotations

import asyncio
import json
import logging
from collections import defaultdict
from contextlib import asynccontextmanager
from types import TracebackType
from typing import (
    Any,
    AsyncIterator,
    Dict,
    List,
    Optional,
    Sequence,
    Tuple,
    Type,
    Union,
    cast,
)

import orjson
from langchain_core.runnables import RunnableConfig
from langgraph.checkpoint.base import (
    WRITES_IDX_MAP,
    ChannelVersions,
    Checkpoint,
    CheckpointMetadata,
    CheckpointTuple,
    PendingWrite,
    get_checkpoint_id,
)
from langgraph.constants import TASKS
from redis.asyncio import Redis as AsyncRedis
from redis.asyncio.cluster import RedisCluster as AsyncRedisCluster
from redisvl.index import AsyncSearchIndex
from redisvl.query import FilterQuery
from redisvl.query.filter import Num, Tag
from redisvl.redis.connection import RedisConnectionFactory
from ulid import ULID

from langgraph.checkpoint.redis.base import (
    CHECKPOINT_PREFIX,
    CHECKPOINT_WRITE_PREFIX,
    REDIS_KEY_SEPARATOR,
    BaseRedisSaver,
)
from langgraph.checkpoint.redis.key_registry import (
    AsyncCheckpointKeyRegistry as AsyncKeyRegistry,
)
from langgraph.checkpoint.redis.util import (
    EMPTY_ID_SENTINEL,
    from_storage_safe_id,
    from_storage_safe_str,
    to_storage_safe_id,
    to_storage_safe_str,
)

logger = logging.getLogger(__name__)


class AsyncRedisSaver(
    BaseRedisSaver[Union[AsyncRedis, AsyncRedisCluster], AsyncSearchIndex]
):
    """Async Redis implementation for checkpoint saver.

    Supports standard Redis URLs (redis://), SSL (rediss://), and
    Sentinel URLs (redis+sentinel://host:26379/service_name/db).
    """

    _redis_url: str
    checkpoints_index: AsyncSearchIndex
    checkpoint_writes_index: AsyncSearchIndex

    _redis: Union[
        AsyncRedis, AsyncRedisCluster
    ]  # Support both standalone and cluster clients
    # Whether to assume the Redis server is a cluster; None triggers auto-detection
    cluster_mode: Optional[bool] = None
    _key_registry: Optional[AsyncKeyRegistry] = None  # Track keys to avoid SCAN/KEYS

    # Instance-level cache (will be initialized in __init__)

    def __init__(
        self,
        redis_url: Optional[str] = None,
        *,
        redis_client: Optional[Union[AsyncRedis, AsyncRedisCluster]] = None,
        connection_args: Optional[Dict[str, Any]] = None,
        ttl: Optional[Dict[str, Any]] = None,
        checkpoint_prefix: str = CHECKPOINT_PREFIX,
        checkpoint_write_prefix: str = CHECKPOINT_WRITE_PREFIX,
    ) -> None:
        super().__init__(
            redis_url=redis_url,
            redis_client=redis_client,
            connection_args=connection_args,
            ttl=ttl,
            checkpoint_prefix=checkpoint_prefix,
            checkpoint_write_prefix=checkpoint_write_prefix,
        )
        self.loop = asyncio.get_running_loop()

        # Instance-level cache for frequently used keys (limited size to prevent memory issues)
        self._key_cache: Dict[str, str] = {}
        self._key_cache_max_size = 1000  # Configurable limit

        self._separator = REDIS_KEY_SEPARATOR

    def configure_client(
        self,
        redis_url: Optional[str] = None,
        redis_client: Optional[Union[AsyncRedis, AsyncRedisCluster]] = None,
        connection_args: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Configure the Redis client.

        Supports standard Redis URLs (redis://), SSL (rediss://), and
        Sentinel URLs (redis+sentinel://host:26379/service_name/db).
        """
        self._owns_its_client = redis_client is None

        if redis_client is None:
            self._redis = RedisConnectionFactory.get_async_redis_connection(
                redis_url, **(connection_args or {})
            )
        else:
            self._redis = redis_client

    def create_indexes(self) -> None:
        """Create indexes without connecting to Redis."""
        self.checkpoints_index = AsyncSearchIndex.from_dict(
            self.checkpoints_schema, redis_client=self._redis
        )
        self.checkpoint_writes_index = AsyncSearchIndex.from_dict(
            self.writes_schema, redis_client=self._redis
        )

    def _make_redis_checkpoint_key_cached(
        self, thread_id: str, checkpoint_ns: str, checkpoint_id: str
    ) -> str:
        """Optimized key generation with caching."""
        # Create cache key
        cache_key = f"ckpt:{thread_id}:{checkpoint_ns}:{checkpoint_id}"

        # Check cache first
        if cache_key in self._key_cache:
            return self._key_cache[cache_key]

        # Generate key using optimized string operations
        safe_thread_id = str(to_storage_safe_id(thread_id))
        safe_checkpoint_ns = to_storage_safe_str(checkpoint_ns)
        safe_checkpoint_id = str(to_storage_safe_id(checkpoint_id))

        # Use pre-computed prefix and join
        key = self._separator.join(
            [
                self._checkpoint_prefix,
                safe_thread_id,
                safe_checkpoint_ns,
                safe_checkpoint_id,
            ]
        )

        # Cache for future use (limit cache size to prevent memory issues)
        if len(self._key_cache) < self._key_cache_max_size:
            self._key_cache[cache_key] = key

        return key

    def _make_redis_checkpoint_writes_key_cached(
        self,
        thread_id: str,
        checkpoint_ns: str,
        checkpoint_id: str,
        task_id: str,
        idx: Optional[int],
    ) -> str:
        """Optimized writes key generation with caching."""
        # Create cache key
        cache_key = f"write:{thread_id}:{checkpoint_ns}:{checkpoint_id}:{task_id}:{idx}"

        # Check cache first
        if cache_key in self._key_cache:
            return self._key_cache[cache_key]

        # Generate key using optimized string operations
        safe_thread_id = str(to_storage_safe_id(thread_id))
        safe_checkpoint_ns = to_storage_safe_str(checkpoint_ns)
        safe_checkpoint_id = str(to_storage_safe_id(checkpoint_id))

        # Build key components
        key_parts = [
            self._checkpoint_write_prefix,
            safe_thread_id,
            safe_checkpoint_ns,
            safe_checkpoint_id,
            task_id,
        ]

        if idx is not None:
            key_parts.append(str(idx))

        # Use pre-computed separator
        key = self._separator.join(key_parts)

        # Cache for future use (limit cache size)
        if len(self._key_cache) < 10000:
            self._key_cache[cache_key] = key

        return key

    async def __aenter__(self) -> AsyncRedisSaver:
        """Async context manager enter."""
        await self.asetup()

        # Set client info once Redis is set up
        await self.aset_client_info()

        return self

    async def __aexit__(
        self,
        _exc_type: Optional[Type[BaseException]],
        _exc_val: Optional[BaseException],
        _exc_tb: Optional[TracebackType],
    ) -> None:
        """Async context manager exit."""
        if self._owns_its_client:
            await self._redis.aclose()
            # RedisCluster doesn't have connection_pool attribute
            if getattr(self._redis, "connection_pool", None):
                coro = self._redis.connection_pool.disconnect()
                if coro:
                    await coro

            # Prevent RedisVL from attempting to close the client
            # on an event loop in a separate thread.
            self.checkpoints_index._redis_client = None
            self.checkpoint_writes_index._redis_client = None

    async def asetup(self) -> None:
        """Set up the checkpoint saver."""
        self.create_indexes()
        await self.checkpoints_index.create(overwrite=False)
        await self.checkpoint_writes_index.create(overwrite=False)

        # Detect cluster mode if not explicitly set
        await self._detect_cluster_mode()

        # Initialize key registry
        self._key_registry = AsyncKeyRegistry(self._redis)

    async def setup(self) -> None:  # type: ignore[override]
        """Set up the checkpoint saver asynchronously.

        This method creates the necessary indices in Redis and detects cluster mode.
        It MUST be called before using the checkpointer.

        This async method follows the canonical pattern used by PostgreSQL and SQLite
        checkpointers in the LangGraph ecosystem. The type ignore is necessary because
        the base class defines a sync setup() method, but async checkpointers require
        an async setup() method to properly handle coroutines.

        Usage: await checkpointer.setup()
        """
        await self.asetup()

    async def _detect_cluster_mode(self) -> None:
        """Detect if the Redis client is a cluster client by inspecting its class."""
        if self.cluster_mode is not None:
            logger.info(
                f"Redis cluster_mode explicitly set to {self.cluster_mode}, skipping detection."
            )
            return

        # Determine cluster mode based on client class
        if isinstance(self._redis, AsyncRedisCluster):
            logger.info("Redis client is a cluster client")
            self.cluster_mode = True
        else:
            logger.info("Redis client is a standalone client")
            self.cluster_mode = False

    async def _apply_ttl_to_keys(
        self,
        main_key: str,
        related_keys: Optional[list[str]] = None,
        ttl_minutes: Optional[float] = None,
    ) -> Any:
        """Apply Redis native TTL to keys asynchronously.

        Args:
            main_key: The primary Redis key
            related_keys: Additional Redis keys that should expire at the same time
            ttl_minutes: Time-to-live in minutes, overrides default_ttl if provided
                        Use -1 to remove TTL (make keys persistent)

        Returns:
            Result of the Redis operation
        """
        if ttl_minutes is None:
            # Check if there's a default TTL in config
            if self.ttl_config and "default_ttl" in self.ttl_config:
                ttl_minutes = self.ttl_config.get("default_ttl")

        if ttl_minutes is not None:
            # Special case: -1 means remove TTL (make persistent)
            if ttl_minutes == -1:
                # Apply PERSIST individually per key so that a single failure
                # does not prevent TTL removal on the remaining keys.
                all_keys = [main_key] + (related_keys or [])
                for key in all_keys:
                    try:
                        await self._redis.persist(key)
                    except Exception:
                        logger.warning("Failed to remove TTL from key: %s", key)

                return True

            # Regular TTL setting
            ttl_seconds = int(ttl_minutes * 60)

            # Apply TTL individually per key so that a single EXPIRE failure
            # (e.g. MOVED on Redis Enterprise proxy) does not prevent TTL
            # from being set on the remaining keys.
            all_keys = [main_key] + (related_keys or [])
            for key in all_keys:
                try:
                    await self._redis.expire(key, ttl_seconds)
                except Exception:
                    logger.warning("Failed to apply TTL to key: %s", key)

            return True

    async def aget_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
        """Get a checkpoint tuple from Redis asynchronously."""
        thread_id = config["configurable"]["thread_id"]
        checkpoint_id = get_checkpoint_id(config)
        checkpoint_ns = config["configurable"].get("checkpoint_ns", "")

        # For values we store in Redis, we need to convert empty strings to the
        # sentinel value.
        storage_safe_thread_id = to_storage_safe_id(thread_id)
        storage_safe_checkpoint_ns = to_storage_safe_str(checkpoint_ns)

        if checkpoint_id and checkpoint_id != EMPTY_ID_SENTINEL:
            # Use direct key access instead of FT.SEARCH when checkpoint_id is known
            storage_safe_checkpoint_id = to_storage_safe_id(checkpoint_id)

            # Construct direct key for checkpoint data
            checkpoint_key = self._make_redis_checkpoint_key(
                storage_safe_thread_id,
                storage_safe_checkpoint_ns,
                storage_safe_checkpoint_id,
            )

            # Create pipeline for efficient batch operations
            pipeline = self._redis.pipeline(transaction=False)

            # Add checkpoint data fetch to pipeline
            pipeline.json().get(checkpoint_key)

            # Add TTL check if refresh_on_read is enabled
            if self.ttl_config and self.ttl_config.get("refresh_on_read"):
                pipeline.ttl(checkpoint_key)

            # Execute pipeline to get checkpoint data and TTL
            pipeline_results = await pipeline.execute()

            checkpoint_data = pipeline_results[0]
            if not checkpoint_data:
                return None

            # Extract TTL if we fetched it
            current_ttl = None
            if self.ttl_config and self.ttl_config.get("refresh_on_read"):
                current_ttl = pipeline_results[1]

            # Create doc-like object from direct access
            doc = {
                "thread_id": checkpoint_data.get("thread_id", storage_safe_thread_id),
                "checkpoint_ns": checkpoint_data.get(
                    "checkpoint_ns", storage_safe_checkpoint_ns
                ),
                "checkpoint_id": checkpoint_data.get(
                    "checkpoint_id", storage_safe_checkpoint_id
                ),
                "parent_checkpoint_id": checkpoint_data.get(
                    "parent_checkpoint_id", storage_safe_checkpoint_id
                ),
                "$.checkpoint": json.dumps(checkpoint_data.get("checkpoint", {})),
                "$.metadata": checkpoint_data.get(
                    "metadata", "{}"
                ),  # metadata is already a JSON string
            }
        else:
            # Try to get latest checkpoint using pointer
            latest_pointer_key = f"checkpoint_latest:{storage_safe_thread_id}:{storage_safe_checkpoint_ns}"

            checkpoint_key = await self._redis.get(latest_pointer_key)
            if not checkpoint_key:
                # No pointer means no checkpoints exist
                return None

            # Create pipeline for efficient operations
            pipeline = self._redis.pipeline(transaction=False)

            # Add checkpoint data fetch to pipeline
            pipeline.json().get(checkpoint_key)

            # Add TTL check if refresh_on_read is enabled
            if self.ttl_config and self.ttl_config.get("refresh_on_read"):
                pipeline.ttl(checkpoint_key)

            # Execute pipeline
            pipeline_results = await pipeline.execute()

            checkpoint_data = pipeline_results[0]
            if not checkpoint_data:
                # Pointer exists but checkpoint is missing - data inconsistency
                return None

            # Extract TTL if we fetched it
            current_ttl = None
            if self.ttl_config and self.ttl_config.get("refresh_on_read"):
                current_ttl = pipeline_results[1]

            # Create doc-like object from direct access
            doc = {
                "thread_id": checkpoint_data.get("thread_id", storage_safe_thread_id),
                "checkpoint_ns": checkpoint_data.get(
                    "checkpoint_ns", storage_safe_checkpoint_ns
                ),
                "checkpoint_id": checkpoint_data.get("checkpoint_id"),
                "parent_checkpoint_id": checkpoint_data.get("parent_checkpoint_id"),
                "$.checkpoint": json.dumps(checkpoint_data.get("checkpoint", {})),
                "$.metadata": checkpoint_data.get(
                    "metadata", "{}"
                ),  # metadata is already a JSON string
            }

        doc_thread_id = from_storage_safe_id(doc["thread_id"])
        doc_checkpoint_ns = from_storage_safe_str(doc["checkpoint_ns"])
        doc_checkpoint_id = from_storage_safe_id(doc["checkpoint_id"])
        doc_parent_checkpoint_id = from_storage_safe_id(doc["parent_checkpoint_id"])

        # Lazy TTL refresh - only refresh if TTL is below threshold
        if self.ttl_config and self.ttl_config.get("refresh_on_read"):
            # If we didn't get TTL from pipeline (i.e., came from else branch), fetch it now
            if "current_ttl" not in locals():
                # Get the checkpoint key
                checkpoint_key = self._make_redis_checkpoint_key(
                    to_storage_safe_id(doc_thread_id),
                    to_storage_safe_str(doc_checkpoint_ns),
                    to_storage_safe_id(doc_checkpoint_id),
                )
                current_ttl = await self._redis.ttl(checkpoint_key)

            # Always refresh TTL when refresh_on_read is enabled
            # This ensures all related keys maintain synchronized TTLs

            # Only refresh if key exists and has TTL (skip keys with no expiry)
            # TTL states: -2 = key doesn't exist, -1 = key exists but no TTL, 0 = expired, >0 = seconds remaining
            if current_ttl > 0:
                # Get write keys from registry instead of SCAN
                write_keys = []

                if self._key_registry:
                    write_keys = await self._key_registry.get_write_keys(
                        doc_thread_id, doc_checkpoint_ns, doc_checkpoint_id
                    )

                # Apply TTL to checkpoint and write keys
                await self._apply_ttl_to_keys(
                    checkpoint_key, write_keys if write_keys else None
                )

                # Also refresh TTL on registry keys if they exist
                if self._key_registry and self.ttl_config:
                    ttl_minutes = self.ttl_config.get("default_ttl")
                    if ttl_minutes is not None:
                        ttl_seconds = int(ttl_minutes * 60)
                        # Registry TTL is handled per checkpoint
                        await self._key_registry.apply_ttl(
                            doc_thread_id,
                            doc_checkpoint_ns,
                            doc_checkpoint_id,
                            ttl_seconds,
                        )

        # Fetch channel_values - pass channel_versions if we have them from direct access
        checkpoint_raw = (
            doc.get("$.checkpoint")
            if isinstance(doc, dict)
            else getattr(doc, "$.checkpoint", None)
        )
        if isinstance(checkpoint_raw, str):
            checkpoint_data_dict = json.loads(checkpoint_raw)
        else:
            checkpoint_data_dict = checkpoint_raw

        channel_versions_from_checkpoint = (
            checkpoint_data_dict.get("channel_versions")
            if checkpoint_data_dict
            else None
        )

        # Run channel_values, pending_sends, and pending_writes loads in parallel
        # Create list of coroutines to run
        tasks: List[Any] = []

        # Always load channel values
        tasks.append(
            self.aget_channel_values(
                thread_id=doc_thread_id,
                checkpoint_ns=doc_checkpoint_ns,
                checkpoint_id=doc_checkpoint_id,
                channel_versions=channel_versions_from_checkpoint,
            )
        )

        # Conditionally load pending sends if parent exists
        if doc_parent_checkpoint_id:
            tasks.append(
                self._aload_pending_sends(
                    thread_id=thread_id,
                    checkpoint_ns=doc_checkpoint_ns,
                    parent_checkpoint_id=doc_parent_checkpoint_id,
                )
            )

        # Always load pending writes
        tasks.append(
            self._aload_pending_writes(thread_id, checkpoint_ns, doc_checkpoint_id)
        )

        # Execute all tasks in parallel - pending_sends is optional
        if doc_parent_checkpoint_id:
            results = await asyncio.gather(*tasks)
            channel_values: Dict[str, Any] = self._recursive_deserialize(results[0])
            pending_sends: List[Tuple[str, Union[str, bytes]]] = results[1]
            pending_writes: List[PendingWrite] = results[2]
        else:
            # Only channel_values and pending_writes tasks
            results = await asyncio.gather(*tasks)
            channel_values = self._recursive_deserialize(results[0])
            pending_sends = []
            pending_writes = results[1]

        # Fetch and parse metadata
        raw_metadata = (
            doc.get("$.metadata", "{}")
            if isinstance(doc, dict)
            else getattr(doc, "$.metadata", "{}")
        )
        metadata_dict = (
            json.loads(raw_metadata) if isinstance(raw_metadata, str) else raw_metadata
        )

        # Ensure metadata matches CheckpointMetadata type
        sanitized_metadata = {
            k.replace("\u0000", ""): (
                v.replace("\u0000", "") if isinstance(v, str) else v
            )
            for k, v in metadata_dict.items()
        }
        metadata = cast(CheckpointMetadata, sanitized_metadata)

        config_param: RunnableConfig = {
            "configurable": {
                "thread_id": thread_id,
                "checkpoint_ns": checkpoint_ns,
                "checkpoint_id": doc_checkpoint_id,
            }
        }

        # Handle both direct dict access and FT.SEARCH results
        checkpoint_data = doc["$.checkpoint"]
        if isinstance(checkpoint_data, dict):
            # Direct key access returns dict, convert to JSON string for consistency
            checkpoint_data = json.dumps(checkpoint_data)

        checkpoint_param = self._load_checkpoint(
            checkpoint_data,
            channel_values,
            pending_sends,
        )

        # Build parent config if parent_checkpoint_id exists
        parent_config: RunnableConfig | None = None
        if doc_parent_checkpoint_id:
            parent_config = {
                "configurable": {
                    "thread_id": thread_id,
                    "checkpoint_ns": checkpoint_ns,
                    "checkpoint_id": doc_parent_checkpoint_id,
                }
            }

        return CheckpointTuple(
            config=config_param,
            checkpoint=checkpoint_param,
            metadata=metadata,
            parent_config=parent_config,
            pending_writes=pending_writes,
        )

    async def alist(
        self,
        config: Optional[RunnableConfig],
        *,
        filter: Optional[dict[str, Any]] = None,
        before: Optional[RunnableConfig] = None,  # noqa: ARG002
        limit: Optional[int] = None,
    ) -> AsyncIterator[CheckpointTuple]:
        """List checkpoints from Redis asynchronously."""
        # Construct the filter expression
        filter_expression = []
        if config:
            filter_expression.append(
                Tag("thread_id")
                == to_storage_safe_id(config["configurable"]["thread_id"])
            )
            if run_id := config["configurable"].get("run_id"):
                filter_expression.append(Tag("run_id") == to_storage_safe_id(run_id))

            # Search for checkpoints with any namespace, including an empty
            # string, while `checkpoint_id` has to have a value.
            if checkpoint_ns := config["configurable"].get("checkpoint_ns"):
                filter_expression.append(
                    Tag("checkpoint_ns") == to_storage_safe_str(checkpoint_ns)
                )
            if checkpoint_id := get_checkpoint_id(config):
                filter_expression.append(
                    Tag("checkpoint_id") == to_storage_safe_id(checkpoint_id)
                )

        if filter:
            for k, v in filter.items():
                if k == "source":
                    filter_expression.append(Tag("source") == v)
                elif k == "step":
                    filter_expression.append(Num("step") == v)
                elif k == "thread_id":
                    filter_expression.append(Tag("thread_id") == to_storage_safe_id(v))
                elif k == "run_id":
                    filter_expression.append(Tag("run_id") == to_storage_safe_id(v))
                else:
                    raise ValueError(f"Unsupported filter key: {k}")

        if before:
            before_checkpoint_id = get_checkpoint_id(before)
            if before_checkpoint_id:
                try:
                    before_ulid = ULID.from_str(before_checkpoint_id)
                    before_ts = before_ulid.timestamp
                    # Use numeric range query: checkpoint_ts < before_ts
                    filter_expression.append(Num("checkpoint_ts") < before_ts)
                except Exception:
                    # If not a valid ULID, ignore the before filter
                    pass

        # Combine all filter expressions
        combined_filter = filter_expression[0] if filter_expression else "*"
        for expr in filter_expression[1:]:
            combined_filter &= expr

        # Construct the Redis query
        # Sort by checkpoint_id in descending order to get most recent checkpoints first
        query = FilterQuery(
            filter_expression=combined_filter,
            return_fields=[
                "thread_id",
                "checkpoint_ns",
                "checkpoint_id",
                "parent_checkpoint_id",
                "$.checkpoint",
                "$.metadata",
                "has_writes",  # Include has_writes to optimize pending_writes loading
            ],
            num_results=limit or 10000,
            sort_by=("checkpoint_id", "DESC"),
        )

        # Execute the query asynchronously
        results = await self.checkpoints_index.search(query)

        # Pre-process all docs to collect batch query requirements
        all_docs_data = []
        pending_sends_batch_keys = []
        pending_writes_batch_keys = []

        for doc in results.docs:
            # Extract all attributes once
            doc_dict = doc.__dict__ if hasattr(doc, "__dict__") else {}

            thread_id = from_storage_safe_id(doc["thread_id"])
            checkpoint_ns = from_storage_safe_str(doc["checkpoint_ns"])
            checkpoint_id = from_storage_safe_id(doc["checkpoint_id"])
            parent_checkpoint_id = from_storage_safe_id(doc["parent_checkpoint_id"])

            # Get channel values from inline checkpoint data (already returned by FT.SEARCH)
            checkpoint_data = doc_dict.get("$.checkpoint") or getattr(
                doc, "$.checkpoint", None
            )
            if checkpoint_data:
                # Parse checkpoint to extract inline channel_values
                if isinstance(checkpoint_data, list) and checkpoint_data:
                    checkpoint_data = checkpoint_data[0]

                # Use orjson for faster parsing
                checkpoint_dict = (
                    checkpoint_data
                    if isinstance(checkpoint_data, dict)
                    else orjson.loads(checkpoint_data)
                )
                channel_values = self._recursive_deserialize(
                    checkpoint_dict.get("channel_values", {})
                )
            else:
                # If checkpoint data is missing, the document is corrupted
                # Set empty channel values rather than attempting a fallback
                channel_values = {}

            # Collect batch keys for pending_sends
            if parent_checkpoint_id and parent_checkpoint_id != "None":
                batch_key = (thread_id, checkpoint_ns, parent_checkpoint_id)
                pending_sends_batch_keys.append(batch_key)

            # Collect batch keys for pending_writes
            checkpoint_has_writes = doc_dict.get("has_writes") or getattr(
                doc, "has_writes", False
            )
            # Convert string "False" to boolean false if needed (optimize for common case)
            if checkpoint_has_writes == "true":
                checkpoint_has_writes = True
            elif checkpoint_has_writes == "false" or checkpoint_has_writes == "False":
                checkpoint_has_writes = False

            if checkpoint_has_writes:
                batch_key = (thread_id, checkpoint_ns, checkpoint_id)
                pending_writes_batch_keys.append(batch_key)

            # Store processed doc data for final iteration
            all_docs_data.append(
                {
                    "doc": doc,
                    "doc_dict": doc_dict,
                    "thread_id": thread_id,
                    "checkpoint_ns": checkpoint_ns,
                    "checkpoint_id": checkpoint_id,
                    "parent_checkpoint_id": parent_checkpoint_id,
                    "checkpoint_data": checkpoint_data,
                    "checkpoint_dict": checkpoint_dict if checkpoint_data else None,
                    "channel_values": channel_values,
                    "has_writes": checkpoint_has_writes,
                }
            )

        # Load pending_sends for all parent checkpoints at once
        pending_sends_map = {}
        if pending_sends_batch_keys:
            pending_sends_map = await self._abatch_load_pending_sends(
                pending_sends_batch_keys
            )

        # Load pending_writes for all checkpoints with writes at once
        pending_writes_map = {}
        if pending_writes_batch_keys:
            pending_writes_map = await self._abatch_load_pending_writes(
                pending_writes_batch_keys
            )

        # Process the results using pre-loaded batch data
        for doc_data in all_docs_data:
            thread_id = doc_data["thread_id"]
            checkpoint_ns = doc_data["checkpoint_ns"]
            checkpoint_id = doc_data["checkpoint_id"]
            parent_checkpoint_id = doc_data["parent_checkpoint_id"]

            # Get pending_sends from batch results
            pending_sends: List[Tuple[str, Union[str, bytes]]] = []
            if parent_checkpoint_id:
                batch_key = (thread_id, checkpoint_ns, parent_checkpoint_id)
                pending_sends = pending_sends_map.get(batch_key, [])

            # Fetch and parse metadata
            doc_dict = doc_data["doc_dict"]
            raw_metadata = doc_dict.get("$.metadata") or getattr(
                doc_data["doc"], "$.metadata", "{}"
            )
            # Use orjson for faster parsing
            metadata_dict = (
                orjson.loads(raw_metadata)
                if isinstance(raw_metadata, str)
                else raw_metadata
            )

            # Only sanitize if null bytes detected (rare case)
            if any(
                "\u0000" in str(v) for v in metadata_dict.values() if isinstance(v, str)
            ):
                sanitized_metadata = {
                    k.replace("\u0000", ""): (
                        v.replace("\u0000", "") if isinstance(v, str) else v
                    )
                    for k, v in metadata_dict.items()
                }
                metadata = cast(CheckpointMetadata, sanitized_metadata)
            else:
                metadata = cast(CheckpointMetadata, metadata_dict)

            # Pre-create the config structure more efficiently
            config_param: RunnableConfig = {
                "configurable": {
                    "thread_id": thread_id,
                    "checkpoint_ns": checkpoint_ns,
                    "checkpoint_id": checkpoint_id,
                }
            }

            # Pass already parsed checkpoint_dict to avoid re-parsing
            checkpoint_param = self._load_checkpoint(
                (
                    doc_data["checkpoint_dict"]
                    if doc_data["checkpoint_data"]
                    else doc_data["doc"]["$.checkpoint"]
                ),
                doc_data["channel_values"],
                pending_sends,
            )

            # Get pending_writes from batch results
            pending_writes: List[PendingWrite] = []
            if doc_data["has_writes"]:
                batch_key = (thread_id, checkpoint_ns, checkpoint_id)
                pending_writes = pending_writes_map.get(batch_key, [])

            # Build parent config if parent_checkpoint_id exists
            parent_config: RunnableConfig | None = None
            if parent_checkpoint_id:
                parent_config = {
                    "configurable": {
                        "thread_id": thread_id,
                        "checkpoint_ns": checkpoint_ns,
                        "checkpoint_id": parent_checkpoint_id,
                    }
                }

            yield CheckpointTuple(
                config=config_param,
                checkpoint=checkpoint_param,
                metadata=metadata,
                parent_config=parent_config,
                pending_writes=pending_writes,
            )

    async def aput(
        self,
        config: RunnableConfig,
        checkpoint: Checkpoint,
        metadata: CheckpointMetadata,
        new_versions: ChannelVersions,
        stream_mode: str = "values",
    ) -> RunnableConfig:
        """Store a checkpoint to Redis with proper transaction handling.

        This method ensures that all Redis operations are performed atomically
        using Redis transactions. In case of interruption (asyncio.CancelledError),
        the transaction will be aborted, ensuring consistency.

        Args:
            config: The config to associate with the checkpoint
            checkpoint: The checkpoint data to store
            metadata: Additional metadata to save with the checkpoint
            new_versions: New channel versions as of this write
            stream_mode: The streaming mode being used (values, updates, etc.)

        Returns:
            Updated configuration after storing the checkpoint

        Raises:
            asyncio.CancelledError: If the operation is cancelled/interrupted
        """
        configurable = config["configurable"].copy()

        run_id = configurable.pop("run_id", metadata.get("run_id"))
        thread_id = configurable.pop("thread_id")
        checkpoint_ns = configurable.pop("checkpoint_ns")
        # Get checkpoint_id from config - this will be parent if saving a child
        config_checkpoint_id = configurable.pop("checkpoint_id", None)
        # For backward compatibility with thread_ts
        thread_ts = configurable.pop("thread_ts", "")

        # Determine the checkpoint ID
        # This follows the original logic but with clearer parent handling
        checkpoint_id = config_checkpoint_id or thread_ts or checkpoint.get("id", "")

        # If checkpoint has its own ID that's different from what we'd use,
        # and we have a config checkpoint_id, then config checkpoint_id is the parent
        parent_checkpoint_id = None
        if (
            checkpoint.get("id")
            and config_checkpoint_id
            and checkpoint.get("id") != config_checkpoint_id
        ):
            parent_checkpoint_id = config_checkpoint_id
            checkpoint_id = checkpoint["id"]

        # For values we store in Redis, we need to convert empty strings to the
        # sentinel value.
        storage_safe_thread_id = to_storage_safe_id(thread_id)
        storage_safe_checkpoint_ns = to_storage_safe_str(checkpoint_ns)
        storage_safe_checkpoint_id = to_storage_safe_id(checkpoint_id)

        copy = checkpoint.copy()
        next_config = {
            "configurable": {
                "thread_id": thread_id,
                "checkpoint_ns": checkpoint_ns,
                "checkpoint_id": checkpoint_id,
            }
        }

        # Store checkpoint data with cluster-aware handling
        try:
            # Store checkpoint data WITH inline channel values
            # Extract timestamp from checkpoint_id (ULID)
            checkpoint_ts = None
            if checkpoint_id:
                try:
                    from ulid import ULID

                    ulid_obj = ULID.from_str(checkpoint_id)
                    checkpoint_ts = ulid_obj.timestamp  # milliseconds since epoch
                except Exception:
                    # If not a valid ULID, use current time
                    import time

                    checkpoint_ts = time.time() * 1000

            checkpoint_data = {
                "thread_id": storage_safe_thread_id,
                "run_id": to_storage_safe_id(run_id) if run_id else "",
                "checkpoint_ns": storage_safe_checkpoint_ns,
                "checkpoint_id": storage_safe_checkpoint_id,
                "parent_checkpoint_id": (
                    to_storage_safe_id(parent_checkpoint_id)
                    if parent_checkpoint_id
                    else ""
                ),
                "checkpoint_ts": checkpoint_ts,
                "checkpoint": self._dump_checkpoint(copy),
                "metadata": self._dump_metadata(metadata),
                "has_writes": False,  # Track if this checkpoint has pending writes
            }

            # store at top-level for filters in list()
            if all(key in metadata for key in ["source", "step"]):
                checkpoint_data["source"] = metadata["source"]
                checkpoint_data["step"] = metadata["step"]

            # Prepare checkpoint key
            checkpoint_key = self._make_redis_checkpoint_key_cached(
                thread_id,
                checkpoint_ns,
                checkpoint_id,
            )

            # Calculate TTL in seconds if configured
            ttl_seconds = None
            if self.ttl_config and "default_ttl" in self.ttl_config:
                ttl_seconds = int(self.ttl_config["default_ttl"] * 60)

            # Store checkpoint with TTL in a single operation using SearchIndex
            await self.checkpoints_index.load(
                [checkpoint_data],
                keys=[checkpoint_key],
                ttl=ttl_seconds,  # RedisVL applies TTL in its internal pipeline
            )

            # For test compatibility: ensure TTL operations are visible to mocks
            if (
                self.cluster_mode
                and self.ttl_config
                and "default_ttl" in self.ttl_config
                and ttl_seconds is not None
            ):
                # In cluster mode, also call expire directly so tests can verify
                try:
                    await self._redis.expire(checkpoint_key, ttl_seconds)
                except Exception:
                    logger.warning(
                        "Failed to apply TTL to checkpoint key: %s", checkpoint_key
                    )

            # Update latest checkpoint pointer
            latest_pointer_key = f"checkpoint_latest:{storage_safe_thread_id}:{storage_safe_checkpoint_ns}"
            await self._redis.set(latest_pointer_key, checkpoint_key)

            # Apply TTL to latest pointer key as well (best-effort)
            if ttl_seconds is not None:
                try:
                    await self._redis.expire(latest_pointer_key, ttl_seconds)
                except Exception:
                    logger.warning(
                        "Failed to apply TTL to latest pointer key: %s",
                        latest_pointer_key,
                    )

            return next_config

        except asyncio.CancelledError:
            # Handle cancellation/interruption based on stream mode
            if stream_mode in ("values", "messages"):
                # For these modes, we want to ensure any partial state is committed
                # to allow resuming the stream later
                try:
                    # Store minimal checkpoint data
                    checkpoint_data = {
                        "thread_id": storage_safe_thread_id,
                        "run_id": to_storage_safe_id(run_id) if run_id else "",
                        "checkpoint_ns": storage_safe_checkpoint_ns,
                        "checkpoint_id": storage_safe_checkpoint_id,
                        "parent_checkpoint_id": (
                            to_storage_safe_id(
                                str(checkpoint.get("parent_checkpoint_id", ""))
                            )
                            if checkpoint.get("parent_checkpoint_id")
                            else ""
                        ),
                        "checkpoint": self._dump_checkpoint(copy),
                        "metadata": self._dump_metadata(
                            {
                                **metadata,
                                "interrupted": True,
                                "stream_mode": stream_mode,
                            }
                        ),
                        "has_writes": False,  # Track if this checkpoint has pending writes
                    }

                    # Prepare checkpoint key
                    checkpoint_key = self._make_redis_checkpoint_key(
                        storage_safe_thread_id,
                        storage_safe_checkpoint_ns,
                        storage_safe_checkpoint_id,
                    )

                    if self.cluster_mode:
                        # For cluster mode, execute operation directly
                        await self._redis.json().set(  # type: ignore[misc]
                            checkpoint_key, "$", checkpoint_data
                        )
                    else:
                        # For non-cluster mode, use pipeline
                        pipeline = self._redis.pipeline(transaction=False)
                        pipeline.json().set(checkpoint_key, "$", checkpoint_data)
                        await pipeline.execute()
                except Exception:
                    # If this also fails, we just propagate the original cancellation
                    pass

            # Re-raise the cancellation
            raise

        except Exception as e:
            # Re-raise other exceptions
            raise e

    async def aput_writes(
        self,
        config: RunnableConfig,
        writes: Sequence[Tuple[str, Any]],
        task_id: str,
        task_path: str = "",
    ) -> None:
        """Store intermediate writes linked to a checkpoint using Redis JSON.

        This method uses Redis pipeline without transaction to avoid lock contention
        during parallel test execution.

        Args:
            config (RunnableConfig): Configuration of the related checkpoint.
            writes (List[Tuple[str, Any]]): List of writes to store.
            task_id (str): Identifier for the task creating the writes.
            task_path (str): Path of the task creating the writes.

        Raises:
            asyncio.CancelledError: If the operation is cancelled/interrupted
        """
        if not writes:
            return

        thread_id = config["configurable"]["thread_id"]
        checkpoint_ns = config["configurable"].get("checkpoint_ns", "")
        checkpoint_id = config["configurable"]["checkpoint_id"]

        # Transform writes into appropriate format
        writes_objects = []
        for idx, (channel, value) in enumerate(writes):
            type_, blob = self.serde.dumps_typed(value)
            write_obj = {
                "thread_id": to_storage_safe_id(thread_id),
                "checkpoint_ns": to_storage_safe_str(checkpoint_ns),
                "checkpoint_id": to_storage_safe_id(checkpoint_id),
                "task_id": task_id,
                "task_path": task_path,
                "idx": WRITES_IDX_MAP.get(channel, idx),
                "channel": channel,
                "type": type_,
                "blob": self._encode_blob(
                    blob
                ),  # Encode bytes to base64 string for Redis
            }
            writes_objects.append(write_obj)

        try:
            # Determine if this is an upsert case
            upsert_case = all(w[0] in WRITES_IDX_MAP for w in writes)
            created_keys = []

            if self.cluster_mode:
                # For cluster mode, execute operations individually
                for write_obj in writes_objects:
                    key = self._make_redis_checkpoint_writes_key_cached(
                        thread_id,
                        checkpoint_ns,
                        checkpoint_id,
                        task_id,
                        write_obj["idx"],  # type: ignore[arg-type]
                    )

                    # Redis JSON.SET is an UPSERT by default
                    await self._redis.json().set(key, "$", cast(Any, write_obj))  # type: ignore[misc]
                    created_keys.append(key)

                # Apply TTL to newly created keys
                if (
                    created_keys
                    and self.ttl_config
                    and "default_ttl" in self.ttl_config
                ):
                    await self._apply_ttl_to_keys(
                        created_keys[0],
                        created_keys[1:] if len(created_keys) > 1 else None,
                    )

                # Register write keys in the key registry for cluster mode
                if self._key_registry:
                    write_keys = []
                    for write_obj in writes_objects:
                        key = self._make_redis_checkpoint_writes_key_cached(
                            thread_id,
                            checkpoint_ns,
                            checkpoint_id,
                            task_id,
                            write_obj["idx"],  # type: ignore[arg-type]
                        )
                        write_keys.append(key)

                    if write_keys:
                        # Use per-checkpoint sorted set registry
                        zset_key = self._key_registry.make_write_keys_zset_key(
                            thread_id, checkpoint_ns, checkpoint_id
                        )

                        # Add all write keys with their index as score for ordering
                        zadd_mapping = {key: idx for idx, key in enumerate(write_keys)}
                        await self._redis.zadd(zset_key, zadd_mapping)  # type: ignore[arg-type]

                        # Apply TTL to registry key if configured (best-effort)
                        if self.ttl_config and "default_ttl" in self.ttl_config:
                            ttl_seconds = int(self.ttl_config.get("default_ttl") * 60)
                            try:
                                await self._redis.expire(zset_key, ttl_seconds)
                            except Exception:
                                logger.warning(
                                    "Failed to apply TTL to write registry key: %s",
                                    zset_key,
                                    exc_info=True,
                                )

            else:
                # For non-cluster mode, use pipeline without transaction to avoid lock contention.
                # IMPORTANT: Only critical commands (JSON.SET, JSON.MERGE, ZADD) go in the
                # pipeline. EXPIRE (TTL) commands are applied separately afterward to avoid
                # pipeline failures on Redis Enterprise proxy, where mixed JSON module +
                # native commands in a single pipeline can cause EXPIRE to fail, aborting
                # the entire pipeline and losing interrupt writes.
                pipeline = self._redis.pipeline(transaction=False)

                # Add all write operations to the pipeline (critical)
                for write_obj in writes_objects:
                    key = self._make_redis_checkpoint_writes_key_cached(
                        thread_id,
                        checkpoint_ns,
                        checkpoint_id,
                        task_id,
                        write_obj["idx"],  # type: ignore[arg-type]
                    )

                    pipeline.json().set(key, "$", cast(Any, write_obj))
                    created_keys.append(key)

                # Update checkpoint to indicate it has writes (critical)
                checkpoint_key = ""
                if writes_objects:
                    checkpoint_key = self._make_redis_checkpoint_key(
                        thread_id, checkpoint_ns, checkpoint_id
                    )
                    pipeline.json().merge(checkpoint_key, "$", {"has_writes": True})

                # Registry operation (critical)
                zset_key = ""
                if self._key_registry and created_keys:
                    zset_key = self._key_registry.make_write_keys_zset_key(
                        thread_id, checkpoint_ns, checkpoint_id
                    )
                    zadd_mapping = {key: idx for idx, key in enumerate(created_keys)}
                    pipeline.zadd(zset_key, zadd_mapping)  # type: ignore[arg-type]

                # Execute critical commands with raise_on_error=False to get
                # per-command results instead of aborting on first failure
                results = await pipeline.execute(raise_on_error=False)

                # Check results for critical command failures
                merge_failed = False
                for result in results:
                    if isinstance(result, Exception):
                        err_str = str(result)
                        if "JSON.MERGE" in err_str or "merge" in err_str.lower():
                            merge_failed = True
                        else:
                            raise result

                # Handle JSON.MERGE fallback for older Redis versions
                if merge_failed and checkpoint_key:
                    try:
                        checkpoint_data = await self._redis.json().get(  # type: ignore[misc]
                            checkpoint_key
                        )
                        if isinstance(
                            checkpoint_data, dict
                        ) and not checkpoint_data.get("has_writes"):
                            checkpoint_data["has_writes"] = True
                            await self._redis.json().set(  # type: ignore[misc]
                                checkpoint_key, "$", checkpoint_data
                            )
                    except Exception:
                        pass

                # Apply TTL separately (best-effort — failures here don't lose
                # writes). Individual calls ensure partial success: if one key's
                # EXPIRE fails on RE proxy, the others still get TTL applied.
                if (
                    created_keys
                    and self.ttl_config
                    and "default_ttl" in self.ttl_config
                ):
                    ttl_seconds = int(self.ttl_config["default_ttl"] * 60)
                    for key in created_keys:
                        try:
                            await self._redis.expire(key, ttl_seconds)
                        except Exception:
                            logger.warning(
                                "Failed to apply TTL to checkpoint write key: %s",
                                key,
                            )

                if zset_key and self.ttl_config and "default_ttl" in self.ttl_config:
                    try:
                        ttl_seconds = int(self.ttl_config["default_ttl"] * 60)
                        await self._redis.expire(zset_key, ttl_seconds)
                    except Exception:
                        logger.warning(
                            "Failed to apply TTL to write registry key: %s",
                            zset_key,
                            exc_info=True,
                        )

        except asyncio.CancelledError:
            # Handle cancellation/interruption
            # Pipeline will be automatically discarded
            # Either all operations succeed or none do
            raise

        except Exception as e:
            # Re-raise other exceptions
            raise e

    def put_writes(
        self,
        config: RunnableConfig,
        writes: Sequence[Tuple[str, Any]],
        task_id: str,
        task_path: str = "",
    ) -> None:
        """Synchronous wrapper for aput_writes.

        Args:
            config (RunnableConfig): Configuration of the related checkpoint.
            writes (List[Tuple[str, Any]]): List of writes to store.
            task_id (str): Identifier for the task creating the writes.
            task_path (str): Path of the task creating the writes.
        """
        return asyncio.run_coroutine_threadsafe(
            self.aput_writes(config, writes, task_id), self.loop
        ).result()

    def get_channel_values(
        self, thread_id: str, checkpoint_ns: str = "", checkpoint_id: str = ""
    ) -> Dict[str, Any]:
        """Retrieve channel_values using efficient FT.SEARCH with checkpoint_id (sync wrapper)."""
        try:
            if asyncio.get_running_loop() is self.loop:
                raise asyncio.InvalidStateError(
                    "Synchronous calls to AsyncRedisSaver are only allowed from a "
                    "different thread. From the main thread, use the async interface."
                    "For example, use `await checkpointer.get_channel_values(...)`."
                )
        except RuntimeError:
            pass
        return asyncio.run_coroutine_threadsafe(
            self.aget_channel_values(
                thread_id,
                checkpoint_ns,
                checkpoint_id,
            ),
            self.loop,
        ).result()

    def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
        """Get a checkpoint tuple from Redis.

        Args:
            config (RunnableConfig): The config to use for retrieving the checkpoint.

        Returns:
            Optional[CheckpointTuple]: The retrieved checkpoint tuple, or None if no matching checkpoint was found.

        Raises:
            asyncio.InvalidStateError: If called from the wrong thread/event loop
        """
        try:
            # check if we are in the main thread, only bg threads can block
            if asyncio.get_running_loop() is self.loop:
                raise asyncio.InvalidStateError(
                    "Synchronous calls to AsyncRedisSaver are only allowed from a "
                    "different thread. From the main thread, use the async interface."
                    "For example, use `await checkpointer.aget_tuple(...)` or `await "
                    "graph.ainvoke(...)`."
                )
        except RuntimeError:
            pass
        return asyncio.run_coroutine_threadsafe(
            self.aget_tuple(config), self.loop
        ).result()

    def put(
        self,
        config: RunnableConfig,
        checkpoint: Checkpoint,
        metadata: CheckpointMetadata,
        new_versions: ChannelVersions,
    ) -> RunnableConfig:
        """Store a checkpoint to Redis.

        Args:
            config (RunnableConfig): The config to associate with the checkpoint.
            checkpoint (Checkpoint): The checkpoint to save.
            metadata (CheckpointMetadata): Additional metadata to save with the checkpoint.
            new_versions (ChannelVersions): New channel versions as of this write.

        Returns:
            RunnableConfig: Updated configuration after storing the checkpoint.

        Raises:
            asyncio.InvalidStateError: If called from the wrong thread/event loop
        """
        try:
            # check if we are in the main thread, only bg threads can block
            if asyncio.get_running_loop() is self.loop:
                raise asyncio.InvalidStateError(
                    "Synchronous calls to AsyncRedisSaver are only allowed from a "
                    "different thread. From the main thread, use the async interface."
                    "For example, use `await checkpointer.aput(...)` or `await "
                    "graph.ainvoke(...)`."
                )
        except RuntimeError:
            pass
        return asyncio.run_coroutine_threadsafe(
            self.aput(config, checkpoint, metadata, new_versions), self.loop
        ).result()

    @classmethod
    @asynccontextmanager
    async def from_conn_string(
        cls,
        redis_url: Optional[str] = None,
        *,
        redis_client: Optional[Union[AsyncRedis, AsyncRedisCluster]] = None,
        connection_args: Optional[Dict[str, Any]] = None,
        ttl: Optional[Dict[str, Any]] = None,
        checkpoint_prefix: str = CHECKPOINT_PREFIX,
        checkpoint_write_prefix: str = CHECKPOINT_WRITE_PREFIX,
    ) -> AsyncIterator[AsyncRedisSaver]:
        async with cls(
            redis_url=redis_url,
            redis_client=redis_client,
            connection_args=connection_args,
            ttl=ttl,
            checkpoint_prefix=checkpoint_prefix,
            checkpoint_write_prefix=checkpoint_write_prefix,
        ) as saver:
            yield saver

    async def aget_channel_values(
        self,
        thread_id: str,
        checkpoint_ns: str = "",
        checkpoint_id: str = "",
        channel_versions: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, Any]:
        """Retrieve channel_values using efficient FT.SEARCH with checkpoint_id."""
        storage_safe_thread_id = to_storage_safe_id(thread_id)
        storage_safe_checkpoint_ns = to_storage_safe_str(checkpoint_ns)
        storage_safe_checkpoint_id = to_storage_safe_id(checkpoint_id)

        # Get checkpoint with inline channel_values using single JSON.GET operation (MongoDB approach)
        checkpoint_key = self._make_redis_checkpoint_key_cached(
            thread_id,
            checkpoint_ns,
            checkpoint_id,
        )

        # Single JSON.GET operation to retrieve checkpoint with inline channel_values
        checkpoint_data = await self._redis.json().get(checkpoint_key, "$.checkpoint")  # type: ignore[misc]

        if not checkpoint_data:
            return {}

        # checkpoint_data[0] is already a deserialized dict, not a typed tuple
        checkpoint = checkpoint_data[0]
        return checkpoint.get("channel_values", {})

    async def _aload_pending_sends(
        self, thread_id: str, checkpoint_ns: str = "", parent_checkpoint_id: str = ""
    ) -> List[Tuple[str, Union[str, bytes]]]:
        """Load pending sends for a parent checkpoint.

        Args:
            thread_id: The thread ID
            checkpoint_ns: The checkpoint namespace
            parent_checkpoint_id: The ID of the parent checkpoint

        Returns:
            List of (type, blob) tuples representing pending sends
        """
        if not parent_checkpoint_id:
            return []

        # FAST PATH: Try sorted set registry first
        if self._key_registry:
            try:
                # Check if parent checkpoint has any writes in the sorted set
                write_count = await self._key_registry.get_write_count(
                    thread_id, checkpoint_ns, parent_checkpoint_id
                )

                if write_count == 0:
                    # No writes for parent checkpoint - return immediately
                    return []

                # Get exact write keys from the per-checkpoint registry
                write_keys = await self._key_registry.get_write_keys(
                    thread_id, checkpoint_ns, parent_checkpoint_id
                )

                # Filter for TASKS channel writes
                task_write_keys = []
                for key in write_keys:
                    # Keys contain channel info: checkpoint_write:thread:ns:checkpoint:task:idx
                    # We need to check if it's a TASKS channel write
                    # This is a simple heuristic - we might need to fetch to be sure
                    if TASKS in key or "__pregel_tasks" in key:
                        task_write_keys.append(key)

                if not task_write_keys:
                    return []

                # Fetch task writes using pipeline (safe for cluster mode)
                pipeline = self._redis.pipeline(transaction=False)
                for key in task_write_keys:
                    pipeline.json().get(key)

                results = await pipeline.execute()

                # Extract pending sends and sort them
                pending_sends_with_sort_keys = []
                for write_data in results:
                    if write_data and write_data.get("channel") == TASKS:
                        pending_sends_with_sort_keys.append(
                            (
                                write_data.get("task_path", ""),
                                write_data.get("task_id", ""),
                                write_data.get("idx", 0),
                                write_data.get("type", ""),
                                write_data.get("blob", b""),
                            )
                        )

                # Sort by task_path, task_id, idx
                pending_sends_with_sort_keys.sort(key=lambda x: (x[0], x[1], x[2]))

                # Return just the (type, blob) tuples
                return [(item[3], item[4]) for item in pending_sends_with_sort_keys]

            except Exception:
                # If sorted set approach fails, fall back to FT.SEARCH
                pass

        # Fallback to FT.SEARCH logic
        parent_writes_query = FilterQuery(
            filter_expression=(
                (Tag("thread_id") == to_storage_safe_id(thread_id))
                & (Tag("checkpoint_ns") == to_storage_safe_str(checkpoint_ns))
                & (Tag("checkpoint_id") == to_storage_safe_id(parent_checkpoint_id))
                & (Tag("channel") == TASKS)
            ),
            return_fields=["type", "$.blob", "task_path", "task_id", "idx"],
            num_results=100,
        )
        res = await self.checkpoint_writes_index.search(parent_writes_query)

        # Sort results for deterministic order
        docs = sorted(
            res.docs,
            key=lambda d: (
                getattr(d, "task_path", ""),
                getattr(d, "task_id", ""),
                getattr(d, "idx", 0),
            ),
        )

        # Convert to expected format
        return [
            (d.type.encode(), blob)
            for d in docs
            if (blob := getattr(d, "$.blob", getattr(d, "blob", None))) is not None
        ]

    async def _aload_pending_writes(
        self,
        thread_id: str,
        checkpoint_ns: str = "",
        checkpoint_id: str = "",
    ) -> List[PendingWrite]:
        if checkpoint_id is None:
            return []  # Early return if no checkpoint_id

        # FAST PATH: Try sorted set registry first
        if self._key_registry:
            try:
                # Check if this checkpoint has any writes in the sorted set
                write_count = await self._key_registry.get_write_count(
                    thread_id, checkpoint_ns, checkpoint_id
                )

                if write_count == 0:
                    # No writes for this checkpoint - return immediately
                    return []

                # Get exact write keys from the per-checkpoint registry
                write_keys = await self._key_registry.get_write_keys(
                    thread_id, checkpoint_ns, checkpoint_id
                )

                # Fetch all writes efficiently using pipeline
                pipeline = self._redis.pipeline(transaction=False)
                for key in write_keys:
                    pipeline.json().get(key)

                results = await pipeline.execute()

                # Build the writes dictionary
                writes_dict: Dict[Tuple[str, str], Dict[str, Any]] = {}

                for write_data in results:
                    if write_data:
                        task_id = write_data.get("task_id", "")
                        idx = str(write_data.get("idx", 0))
                        writes_dict[(task_id, idx)] = {
                            "task_id": task_id,
                            "idx": idx,
                            "channel": write_data.get("channel", ""),
                            "type": write_data.get("type", ""),
                            "blob": write_data.get("blob", b""),
                        }

                # Deserialize and return
                pending_writes = BaseRedisSaver._load_writes(self.serde, writes_dict)
                return pending_writes

            except Exception:
                # If sorted set approach fails, fall back to FT.SEARCH
                pass

        # FALLBACK: Use search index instead of keys() to avoid CrossSlot errors
        # Note: All tag fields use sentinel values for consistency
        writes_query = FilterQuery(
            filter_expression=(Tag("thread_id") == to_storage_safe_id(thread_id))
            & (Tag("checkpoint_ns") == to_storage_safe_str(checkpoint_ns))
            & (Tag("checkpoint_id") == to_storage_safe_id(checkpoint_id)),
            return_fields=["task_id", "idx", "channel", "type", "$.blob"],
            num_results=1000,  # Adjust as needed
        )

        writes_results = await self.checkpoint_writes_index.search(writes_query)

        # Sort results by idx to maintain order
        sorted_writes = sorted(writes_results.docs, key=lambda x: getattr(x, "idx", 0))

        # Build the writes dictionary from search results
        search_writes_dict: Dict[Tuple[str, str], Dict[str, Any]] = {}
        for doc in sorted_writes:
            task_id = str(getattr(doc, "task_id", ""))
            idx = str(getattr(doc, "idx", 0))
            blob_data = getattr(doc, "$.blob", "")
            # Ensure blob is bytes for deserialization
            if isinstance(blob_data, str):
                blob_data = blob_data.encode("utf-8")
            search_writes_dict[(task_id, idx)] = {
                "task_id": task_id,
                "idx": idx,
                "channel": str(getattr(doc, "channel", "")),
                "type": str(getattr(doc, "type", "")),
                "blob": blob_data,
            }

        pending_writes = BaseRedisSaver._load_writes(self.serde, search_writes_dict)
        return pending_writes

    async def _abatch_load_pending_sends(
        self, batch_keys: List[Tuple[str, str, str]]
    ) -> Dict[Tuple[str, str, str], List[Tuple[str, Union[str, bytes]]]]:
        """Batch load pending sends for multiple parent checkpoints.

        Args:
            batch_keys: List of (thread_id, checkpoint_ns, parent_checkpoint_id) tuples

        Returns:
            Dict mapping batch_key -> list of (type, blob) tuples
        """
        if not batch_keys:
            return {}

        results_map = {}

        # Group by thread_id and checkpoint_ns for efficient querying
        grouped_keys: Dict[Tuple[str, str], List[str]] = {}
        for thread_id, checkpoint_ns, parent_checkpoint_id in batch_keys:
            group_key = (thread_id, checkpoint_ns)
            if group_key not in grouped_keys:
                grouped_keys[group_key] = []
            grouped_keys[group_key].append(parent_checkpoint_id)

        # Batch query for each group
        for (thread_id, checkpoint_ns), parent_checkpoint_ids in grouped_keys.items():
            storage_safe_thread_id = to_storage_safe_id(thread_id)
            storage_safe_checkpoint_ns = to_storage_safe_str(checkpoint_ns)
            storage_safe_parent_checkpoint_ids = [
                to_storage_safe_id(pid) for pid in parent_checkpoint_ids
            ]

            # Build filter for multiple parent checkpoint IDs
            thread_filter = Tag("thread_id") == storage_safe_thread_id
            ns_filter = Tag("checkpoint_ns") == storage_safe_checkpoint_ns
            channel_filter = Tag("channel") == TASKS

            # Create filter for multiple parent checkpoint IDs (Tag supports lists)
            checkpoint_filter = (
                Tag("checkpoint_id") == storage_safe_parent_checkpoint_ids
            )

            batch_query = FilterQuery(
                filter_expression=thread_filter
                & ns_filter
                & checkpoint_filter
                & channel_filter,
                return_fields=[
                    "checkpoint_id",
                    "type",
                    "$.blob",
                    "task_path",
                    "task_id",
                    "idx",
                ],
                num_results=1000,  # Increased limit for batch loading
            )

            batch_results = await self.checkpoint_writes_index.search(batch_query)

            # Group results by parent checkpoint ID
            writes_by_checkpoint: Dict[str, List[Any]] = {}
            for doc in batch_results.docs:
                parent_checkpoint_id = from_storage_safe_id(doc.checkpoint_id)
                if parent_checkpoint_id not in writes_by_checkpoint:
                    writes_by_checkpoint[parent_checkpoint_id] = []
                writes_by_checkpoint[parent_checkpoint_id].append(doc)

            # Sort and format results for each parent checkpoint
            for parent_checkpoint_id in parent_checkpoint_ids:
                batch_key = (thread_id, checkpoint_ns, parent_checkpoint_id)
                writes = writes_by_checkpoint.get(parent_checkpoint_id, [])

                # Sort results by task_path, task_id, idx
                sorted_writes = sorted(
                    writes,
                    key=lambda x: (
                        getattr(x, "task_path", ""),
                        getattr(x, "task_id", ""),
                        getattr(x, "idx", 0),
                    ),
                )

                # Extract type and blob pairs
                # Handle both direct attribute access and JSON path access
                results_map[batch_key] = [
                    (
                        getattr(doc, "type", ""),
                        getattr(doc, "$.blob", getattr(doc, "blob", b"")),
                    )
                    for doc in sorted_writes
                ]

        return results_map

    async def _abatch_load_pending_writes(
        self, batch_keys: List[Tuple[str, str, str]]
    ) -> Dict[Tuple[str, str, str], List[PendingWrite]]:
        """Batch load pending writes for multiple checkpoints.

        Args:
            batch_keys: List of (thread_id, checkpoint_ns, checkpoint_id) tuples

        Returns:
            Dict mapping batch_key -> list of PendingWrite objects
        """
        if not batch_keys:
            return {}

        results_map: Dict[Tuple[str, str, str], List[PendingWrite]] = {}

        # If we have a key registry, use it for efficient batch loading
        if self._key_registry:
            # First, collect all write keys for all checkpoints in parallel
            pipeline = self._redis.pipeline(transaction=False)

            # Add all ZCARD operations to pipeline to check write counts
            for thread_id, checkpoint_ns, checkpoint_id in batch_keys:
                zset_key = self._key_registry.make_write_keys_zset_key(
                    thread_id, checkpoint_ns, checkpoint_id
                )
                pipeline.zcard(zset_key)

            # Execute all ZCARD operations at once
            write_counts = await pipeline.execute()

            # Now get the actual keys for checkpoints that have writes
            pipeline = self._redis.pipeline(transaction=False)
            checkpoints_with_writes = []

            for i, (thread_id, checkpoint_ns, checkpoint_id) in enumerate(batch_keys):
                if write_counts[i] > 0:
                    checkpoints_with_writes.append(
                        (thread_id, checkpoint_ns, checkpoint_id)
                    )
                    zset_key = self._key_registry.make_write_keys_zset_key(
                        thread_id, checkpoint_ns, checkpoint_id
                    )
                    pipeline.zrange(zset_key, 0, -1)
                else:
                    # No writes for this checkpoint
                    batch_key = (thread_id, checkpoint_ns, checkpoint_id)
                    results_map[batch_key] = []

            if checkpoints_with_writes:
                # Get all write keys at once
                all_write_keys_results = await pipeline.execute()

                # Now fetch all the actual write data in a single pipeline
                pipeline = self._redis.pipeline(transaction=False)
                write_key_mapping = {}  # Maps pipeline index to checkpoint info
                pipeline_index = 0

                for i, (thread_id, checkpoint_ns, checkpoint_id) in enumerate(
                    checkpoints_with_writes
                ):
                    write_keys = all_write_keys_results[i]
                    if write_keys:
                        decoded_keys = [
                            key.decode() if isinstance(key, bytes) else key
                            for key in write_keys
                        ]
                        for key in decoded_keys:
                            pipeline.json().get(key)
                            write_key_mapping[pipeline_index] = (
                                thread_id,
                                checkpoint_ns,
                                checkpoint_id,
                                key,
                            )
                            pipeline_index += 1

                # Execute all JSON.GET operations at once
                if pipeline_index > 0:
                    all_writes_data = await pipeline.execute()

                    # Group results by checkpoint
                    writes_by_checkpoint: Dict[
                        Tuple[str, str, str], Dict[Tuple[str, str], Dict[str, Any]]
                    ] = {}

                    for idx, write_data in enumerate(all_writes_data):
                        if write_data:
                            thread_id, checkpoint_ns, checkpoint_id, key = (
                                write_key_mapping[idx]
                            )
                            batch_key = (thread_id, checkpoint_ns, checkpoint_id)

                            if batch_key not in writes_by_checkpoint:
                                writes_by_checkpoint[batch_key] = {}

                            task_id = write_data.get("task_id", "")
                            idx_val = str(write_data.get("idx", 0))
                            writes_by_checkpoint[batch_key][(task_id, idx_val)] = {
                                "task_id": task_id,
                                "idx": idx_val,
                                "channel": write_data.get("channel", ""),
                                "type": write_data.get("type", ""),
                                "blob": write_data.get("blob", b""),
                            }

                    # Deserialize and store results
                    for batch_key, writes_dict in writes_by_checkpoint.items():
                        results_map[batch_key] = BaseRedisSaver._load_writes(
                            self.serde, writes_dict
                        )
        else:
            # Fallback to batch search
            # Group by thread_id and checkpoint_ns for efficient querying
            grouped_keys: Dict[Tuple[str, str], List[str]] = {}
            for thread_id, checkpoint_ns, checkpoint_id in batch_keys:
                group_key = (thread_id, checkpoint_ns)
                if group_key not in grouped_keys:
                    grouped_keys[group_key] = []
                grouped_keys[group_key].append(checkpoint_id)

            # Batch query for each group
            for (thread_id, checkpoint_ns), checkpoint_ids in grouped_keys.items():
                storage_safe_thread_id = to_storage_safe_id(thread_id)
                storage_safe_checkpoint_ns = to_storage_safe_str(checkpoint_ns)
                storage_safe_checkpoint_ids = [
                    to_storage_safe_id(cid) for cid in checkpoint_ids
                ]

                # Build batch query
                thread_filter = Tag("thread_id") == storage_safe_thread_id
                ns_filter = Tag("checkpoint_ns") == storage_safe_checkpoint_ns
                checkpoint_filter = Tag("checkpoint_id") == storage_safe_checkpoint_ids

                batch_query = FilterQuery(
                    filter_expression=thread_filter & ns_filter & checkpoint_filter,
                    return_fields=[
                        "checkpoint_id",
                        "task_id",
                        "idx",
                        "channel",
                        "type",
                        "$.blob",
                    ],
                    num_results=5000,  # Increased limit for batch
                )

                batch_results = await self.checkpoint_writes_index.search(batch_query)

                # Group results by checkpoint ID
                fallback_writes_by_checkpoint: Dict[
                    str, Dict[Tuple[str, str], Dict[str, Any]]
                ] = {}
                for doc in batch_results.docs:
                    checkpoint_id = from_storage_safe_id(doc.checkpoint_id)
                    if checkpoint_id not in fallback_writes_by_checkpoint:
                        fallback_writes_by_checkpoint[checkpoint_id] = {}

                    task_id = getattr(doc, "task_id", "")
                    idx_str = str(getattr(doc, "idx", 0))
                    blob = getattr(doc, "$.blob", getattr(doc, "blob", b""))

                    fallback_writes_by_checkpoint[checkpoint_id][(task_id, idx_str)] = {
                        "task_id": task_id,
                        "idx": idx_str,
                        "channel": getattr(doc, "channel", ""),
                        "type": getattr(doc, "type", ""),
                        "blob": blob,
                    }

                # Process results for each checkpoint
                for checkpoint_id in checkpoint_ids:
                    batch_key = (thread_id, checkpoint_ns, checkpoint_id)
                    writes_dict = fallback_writes_by_checkpoint.get(checkpoint_id, {})
                    results_map[batch_key] = BaseRedisSaver._load_writes(
                        self.serde, writes_dict
                    )

        return results_map

    async def adelete_thread(self, thread_id: str) -> None:
        """Delete all checkpoints and writes associated with a specific thread ID.

        Args:
            thread_id: The thread ID whose checkpoints should be deleted.
        """
        storage_safe_thread_id = to_storage_safe_id(thread_id)

        # Delete all checkpoints for this thread
        checkpoint_query = FilterQuery(
            filter_expression=Tag("thread_id") == storage_safe_thread_id,
            return_fields=["checkpoint_ns", "checkpoint_id"],
            num_results=10000,  # Get all checkpoints for this thread
        )

        checkpoint_results = await self.checkpoints_index.search(checkpoint_query)

        # Collect all keys to delete
        keys_to_delete = []
        checkpoint_namespaces = set()

        for doc in checkpoint_results.docs:
            checkpoint_ns = getattr(doc, "checkpoint_ns", "")
            checkpoint_id = getattr(doc, "checkpoint_id", "")

            # Track unique namespaces for latest pointer cleanup
            checkpoint_namespaces.add(checkpoint_ns)

            # Delete checkpoint key
            checkpoint_key = self._make_redis_checkpoint_key(
                storage_safe_thread_id, checkpoint_ns, checkpoint_id
            )
            keys_to_delete.append(checkpoint_key)

        # Add latest checkpoint pointers to deletion list
        for checkpoint_ns in checkpoint_namespaces:
            latest_pointer_key = f"checkpoint_latest:{storage_safe_thread_id}:{to_storage_safe_str(checkpoint_ns)}"
            keys_to_delete.append(latest_pointer_key)

        # Channel values are stored inline — no separate blob keys to clean up.

        # Delete all writes for this thread
        writes_query = FilterQuery(
            filter_expression=Tag("thread_id") == storage_safe_thread_id,
            return_fields=["checkpoint_ns", "checkpoint_id", "task_id", "idx"],
            num_results=10000,
        )

        writes_results = await self.checkpoint_writes_index.search(writes_query)

        for doc in writes_results.docs:
            checkpoint_ns = getattr(doc, "checkpoint_ns", "")
            checkpoint_id = getattr(doc, "checkpoint_id", "")
            task_id = getattr(doc, "task_id", "")
            idx = getattr(doc, "idx", 0)

            write_key = self._make_redis_checkpoint_writes_key(
                storage_safe_thread_id, checkpoint_ns, checkpoint_id, task_id, idx
            )
            keys_to_delete.append(write_key)

        # Delete the registry sorted sets for each checkpoint
        if self._key_registry:
            # Get unique checkpoints from the results we already have
            processed_checkpoints = set()
            for doc in checkpoint_results.docs:
                checkpoint_ns = getattr(doc, "checkpoint_ns", "")
                checkpoint_id = getattr(doc, "checkpoint_id", "")
                checkpoint_key = (thread_id, checkpoint_ns, checkpoint_id)

                if checkpoint_key not in processed_checkpoints:
                    processed_checkpoints.add(checkpoint_key)
                    # Add the write registry key for this checkpoint
                    zset_key = self._key_registry.make_write_keys_zset_key(
                        thread_id, checkpoint_ns, checkpoint_id
                    )
                    keys_to_delete.append(zset_key)

        # Execute all deletions based on cluster mode
        if self.cluster_mode:
            # For cluster mode, delete keys individually
            for key in keys_to_delete:
                await self._redis.delete(key)
        else:
            # For non-cluster mode, use pipeline for efficiency
            pipeline = self._redis.pipeline()
            for key in keys_to_delete:
                pipeline.delete(key)
            await pipeline.execute()

    async def aprune(
        self,
        thread_ids: Sequence[str],
        *,
        strategy: str = "keep_latest",
        keep_last: Optional[int] = None,
        max_results: int = 10000,
    ) -> None:
        """Prune old checkpoints for the given threads per namespace.

        Retains the most-recent checkpoints **per checkpoint namespace** and
        removes the rest, along with their associated write keys and
        key-registry sorted sets.

        Each namespace (root ``""`` and any subgraph namespaces) is treated as
        an independent checkpoint chain.  Channel values are stored inline
        within each checkpoint document, so they are automatically removed
        when the checkpoint document is deleted.

        Args:
            thread_ids: Thread IDs whose old checkpoints should be pruned.
            strategy: Pruning strategy.  ``"keep_latest"`` retains only the
                most recent checkpoint per namespace (default).  ``"delete"``
                removes all checkpoints for the thread.
            keep_last: Optional override — number of recent checkpoints to
                retain per namespace.  When provided, takes precedence over
                ``strategy``.  Use ``keep_last=0`` to remove all checkpoints.
            max_results: Maximum number of checkpoints fetched from the index
                per thread in a single query.  Defaults to 10 000.
        """
        # Resolve keep_last from strategy if not explicitly provided
        if keep_last is None:
            if strategy == "delete":
                keep_last = 0
            else:
                keep_last = 1

        # Validate inputs
        if not thread_ids:
            raise ValueError("``thread_ids`` must be a non-empty sequence")
        if keep_last < 0:
            raise ValueError(f"``keep_last`` must be >= 0, got {keep_last}")
        if max_results < 1:
            raise ValueError(f"``max_results`` must be >= 1, got {max_results}")

        for thread_id in thread_ids:
            storage_safe_thread_id = to_storage_safe_id(thread_id)

            # Fetch all checkpoints for this thread across all namespaces
            checkpoint_query = FilterQuery(
                filter_expression=Tag("thread_id") == storage_safe_thread_id,
                return_fields=["checkpoint_ns", "checkpoint_id"],
                num_results=max_results,
            )
            checkpoint_results = await self.checkpoints_index.search(checkpoint_query)

            if not checkpoint_results.docs:
                continue

            # Group by namespace — each namespace is an independent checkpoint chain
            # (root graph vs. subgraph checkpoints must be evicted independently).
            by_ns: Dict[str, list] = defaultdict(list)
            for doc in checkpoint_results.docs:
                ns = getattr(doc, "checkpoint_ns", "")
                by_ns[ns].append(doc)

            # Within each namespace sort newest-first (ULIDs are lex time-ordered)
            # and collect checkpoints that fall outside the keep_last window.
            # Track namespaces where every checkpoint is evicted so we can clean
            # up the checkpoint_latest:{thread}:{ns} pointer key too
            to_evict = []
            fully_evicted_ns: set = set()
            for ns, ns_docs in by_ns.items():
                ns_sorted = sorted(
                    ns_docs,
                    key=lambda d: getattr(d, "checkpoint_id", ""),
                    reverse=True,
                )
                ns_evicted = ns_sorted[keep_last:]
                to_evict.extend(ns_evicted)
                if len(ns_evicted) == len(ns_docs):  # nothing left in this namespace
                    fully_evicted_ns.add(ns)

            if not to_evict:
                continue

            keys_to_delete = []
            for doc in to_evict:
                checkpoint_ns = getattr(doc, "checkpoint_ns", "")
                checkpoint_id = getattr(doc, "checkpoint_id", "")

                # Evict checkpoint document
                keys_to_delete.append(
                    self._make_redis_checkpoint_key(
                        storage_safe_thread_id, checkpoint_ns, checkpoint_id
                    )
                )

                # Evict all write documents for this checkpoint
                writes_query = FilterQuery(
                    filter_expression=(
                        (Tag("thread_id") == storage_safe_thread_id)
                        & (Tag("checkpoint_id") == checkpoint_id)
                    ),
                    return_fields=["checkpoint_ns", "checkpoint_id", "task_id", "idx"],
                    num_results=max_results,
                )
                writes_results = await self.checkpoint_writes_index.search(writes_query)
                for wdoc in writes_results.docs:
                    keys_to_delete.append(
                        self._make_redis_checkpoint_writes_key(
                            storage_safe_thread_id,
                            getattr(wdoc, "checkpoint_ns", ""),
                            getattr(wdoc, "checkpoint_id", ""),
                            getattr(wdoc, "task_id", ""),
                            int(getattr(wdoc, "idx", 0)),
                        )
                    )

                # Evict key-registry sorted set for this checkpoint
                if self._key_registry:
                    keys_to_delete.append(
                        self._key_registry.make_write_keys_zset_key(
                            thread_id, checkpoint_ns, checkpoint_id
                        )
                    )

            for ns in fully_evicted_ns:
                keys_to_delete.append(
                    f"checkpoint_latest:{storage_safe_thread_id}:{ns}"
                )

            if self.cluster_mode:
                for key in keys_to_delete:
                    await self._redis.delete(key)
            else:
                pipeline = self._redis.pipeline()
                for key in keys_to_delete:
                    pipeline.delete(key)
                await pipeline.execute()
