#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import sys
from typing import Dict, Optional, TYPE_CHECKING

from pyspark.sql.column import Column
from pyspark.sql.utils import to_scala_map

if TYPE_CHECKING:
    from pyspark.sql.dataframe import DataFrame

__all__ = ["MergeIntoWriter"]


class MergeIntoWriter:
    """
    `MergeIntoWriter` provides methods to define and execute merge actions based
    on specified conditions.

    .. versionadded:: 4.0.0
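
    Examples
    --------
    A minimal sketch; ``source`` and ``target`` are placeholder table names, and the
    target table's data source must support ``MERGE INTO``, so the example is skipped
    by doctests.

    >>> from pyspark.sql.functions import col
    >>> (spark.table("source")
    ...     .mergeInto("target", col("source.id") == col("target.id"))
    ...     .whenMatched()
    ...     .updateAll()
    ...     .whenNotMatched()
    ...     .insertAll()
    ...     .merge())  # doctest: +SKIP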
    """

    def __init__(self, df: "DataFrame", table: str, condition: Column):
        self._spark = df.sparkSession
        from pyspark.sql.classic.column import _to_java_column
        self._jwriter = df._jdf.mergeInto(table, _to_java_column(condition))

    def whenMatched(self, condition: Optional[Column] = None) -> "MergeIntoWriter.WhenMatched":
        """
        Initialize a `WhenMatched` action with a condition.

        This `WhenMatched` action will be executed when a source row matches a target table row
        based on the merge condition and the specified `condition` is satisfied.

        This `WhenMatched` can be followed by one of the following merge actions:

          - `updateAll`: Update all the matched target table rows with source dataset rows.
          - `update(Dict)`: Update all the matched target table rows while changing only
            a subset of columns based on the provided assignments.
          - `delete`: Delete all target rows that have a match in the source table.
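
        Examples
        --------
        A minimal sketch with placeholder table names; assumes the target's data source
        supports ``MERGE INTO`` (skipped by doctests). Delete matched rows whose
        ``salary`` is 100:

        >>> from pyspark.sql.functions import col
        >>> (spark.table("source")
        ...     .mergeInto("target", col("source.id") == col("target.id"))
        ...     .whenMatched(col("salary") == 100)
        ...     .delete()
        ...     .merge())  # doctest: +SKIP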
        """
        return self.WhenMatched(self, condition)

    def whenNotMatched(
        self, condition: Optional[Column] = None
    ) -> "MergeIntoWriter.WhenNotMatched":
        """
        Initialize a `WhenNotMatched` action with a condition.

        This `WhenNotMatched` action will be executed when a source row does not match any target
        row based on the merge condition and the specified `condition` is satisfied.

        This `WhenNotMatched` can be followed by one of the following merge actions:

          - `insertAll`: Insert all rows from the source that are not already in the target table.
          - `insert(Dict)`: Insert all rows from the source that are not already in the target
            table, inserting only the specified columns based on the provided assignments.
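
        Examples
        --------
        A minimal sketch with placeholder table names (skipped by doctests); insert every
        source row that has no match in the target:

        >>> from pyspark.sql.functions import col
        >>> (spark.table("source")
        ...     .mergeInto("target", col("source.id") == col("target.id"))
        ...     .whenNotMatched()
        ...     .insertAll()
        ...     .merge())  # doctest: +SKIP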
        """
        return self.WhenNotMatched(self, condition)

    def whenNotMatchedBySource(
        self, condition: Optional[Column] = None
    ) -> "MergeIntoWriter.WhenNotMatchedBySource":
        """
        Initialize a `WhenNotMatchedBySource` action with a condition.

        This `WhenNotMatchedBySource` action will be executed when a target row does not match any
        rows in the source table based on the merge condition and the specified `condition`
        is satisfied.

        This `WhenNotMatchedBySource` can be followed by one of the following merge actions:

          - `updateAll`: Update all columns of the target table rows that have no match in the
            source dataset.
          - `update(Dict)`: Update the unmatched target table rows while changing only the
            specified columns based on the provided assignments.
          - `delete`: Delete all target rows that have no matches in the source table.
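
        Examples
        --------
        A minimal sketch with placeholder table names (skipped by doctests); delete every
        target row that has no match in the source:

        >>> from pyspark.sql.functions import col
        >>> (spark.table("source")
        ...     .mergeInto("target", col("source.id") == col("target.id"))
        ...     .whenNotMatchedBySource()
        ...     .delete()
        ...     .merge())  # doctest: +SKIP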
        """
        return self.WhenNotMatchedBySource(self, condition)

    def withSchemaEvolution(self) -> "MergeIntoWriter":
        """
        Enable automatic schema evolution for this merge operation.
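
        Examples
        --------
        A minimal sketch with placeholder table names (skipped by doctests); new columns
        in the source can be picked up by the target schema during the merge:

        >>> from pyspark.sql.functions import col
        >>> (spark.table("source")
        ...     .mergeInto("target", col("source.id") == col("target.id"))
        ...     .withSchemaEvolution()
        ...     .whenMatched()
        ...     .updateAll()
        ...     .merge())  # doctest: +SKIP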
        """
        self._jwriter = self._jwriter.withSchemaEvolution()
        return self

    def merge(self) -> None:
        """
        Execute the merge operation.
        """
        self._jwriter.merge()

    class WhenMatched:
        """
        A class for defining actions to be taken when matching rows in a DataFrame during
        a merge operation.
        """

        def __init__(self, writer: "MergeIntoWriter", condition: Optional[Column]):
            self.writer = writer
            if condition is None:
                self.when_matched = writer._jwriter.whenMatched()
            else:
                from pyspark.sql.classic.column import _to_java_column
                self.when_matched = writer._jwriter.whenMatched(_to_java_column(condition))

        def updateAll(self) -> "MergeIntoWriter":
            """
            Specifies an action to update all matched rows in the DataFrame.
            """
            self.writer._jwriter = self.when_matched.updateAll()
            return self.writer

        def update(self, assignments: Dict[str, Column]) -> "MergeIntoWriter":
            """
            Specifies an action to update matched rows in the DataFrame with the provided column
            assignments.
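
            Examples
            --------
            A minimal sketch with placeholder names (skipped by doctests); update only the
            ``salary`` column of matched target rows:

            >>> from pyspark.sql.functions import col
            >>> (spark.table("source")
            ...     .mergeInto("target", col("source.id") == col("target.id"))
            ...     .whenMatched()
            ...     .update({"salary": col("source.salary")})
            ...     .merge())  # doctest: +SKIP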
            """
            jvm = self.writer._spark._jvm
            from pyspark.sql.classic.column import _to_java_column
            jmap = to_scala_map(jvm, {k: _to_java_column(v) for k, v in assignments.items()})
            self.writer._jwriter = self.when_matched.update(jmap)
            return self.writer

        def delete(self) -> "MergeIntoWriter":
            """
            Specifies an action to delete matched rows from the DataFrame.
            """
            self.writer._jwriter = self.when_matched.delete()
            return self.writer

    class WhenNotMatched:
        """
        A class for defining actions to be taken when no matching rows are found in a DataFrame
        during a merge operation.
        """

        def __init__(self, writer: "MergeIntoWriter", condition: Optional[Column]):
            self.writer = writer
            if condition is None:
                self.when_not_matched = writer._jwriter.whenNotMatched()
            else:
                from pyspark.sql.classic.column import _to_java_column
                self.when_not_matched = writer._jwriter.whenNotMatched(_to_java_column(condition))

        def insertAll(self) -> "MergeIntoWriter":
            """
            Specifies an action to insert all non-matched rows into the DataFrame.
            """
            self.writer._jwriter = self.when_not_matched.insertAll()
            return self.writer

        def insert(self, assignments: Dict[str, Column]) -> "MergeIntoWriter":
            """
            Specifies an action to insert non-matched rows into the DataFrame with the provided
            column assignments.
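
            Examples
            --------
            A minimal sketch with placeholder names (skipped by doctests); insert only the
            listed columns for source rows with no match in the target:

            >>> from pyspark.sql.functions import col
            >>> (spark.table("source")
            ...     .mergeInto("target", col("source.id") == col("target.id"))
            ...     .whenNotMatched()
            ...     .insert({"id": col("source.id"), "salary": col("source.salary")})
            ...     .merge())  # doctest: +SKIP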
            """
            jvm = self.writer._spark._jvm
            from pyspark.sql.classic.column import _to_java_column
            jmap = to_scala_map(jvm, {k: _to_java_column(v) for k, v in assignments.items()})
            self.writer._jwriter = self.when_not_matched.insert(jmap)
            return self.writer

    class WhenNotMatchedBySource:
        """
        A class for defining actions to be performed when there is no match by source
        during a merge operation in a MergeIntoWriter.
        """

        def __init__(self, writer: "MergeIntoWriter", condition: Optional[Column]):
            self.writer = writer
            if condition is None:
                self.when_not_matched_by_source = writer._jwriter.whenNotMatchedBySource()
            else:
                from pyspark.sql.classic.column import _to_java_column
                self.when_not_matched_by_source = writer._jwriter.whenNotMatchedBySource(
                    _to_java_column(condition)
                )

        def updateAll(self) -> "MergeIntoWriter":
            """
            Specifies an action to update all columns of the target rows that have no match
            in the source.
            """
            self.writer._jwriter = self.when_not_matched_by_source.updateAll()
            return self.writer

        def update(self, assignments: Dict[str, Column]) -> "MergeIntoWriter":
            """
            Specifies an action to update target rows that have no match in the source, using
            the provided column assignments.
            """
            jvm = self.writer._spark._jvm
            from pyspark.sql.classic.column import _to_java_column
            jmap = to_scala_map(jvm, {k: _to_java_column(v) for k, v in assignments.items()})
            self.writer._jwriter = self.when_not_matched_by_source.update(jmap)
            return self.writer

        def delete(self) -> "MergeIntoWriter":
            """
            Specifies an action to delete target rows that have no match in the source.
            """
            self.writer._jwriter = self.when_not_matched_by_source.delete()
            return self.writer


def _test() -> None:
    import doctest
    import os
    import py4j
    from pyspark.core.context import SparkContext
    from pyspark.sql import SparkSession
    import pyspark.sql.merge
    os.chdir(os.environ["SPARK_HOME"])
    globs = pyspark.sql.merge.__dict__.copy()
    sc = SparkContext("local[4]", "PythonTest")
    try:
        spark = SparkSession._getActiveSessionOrCreate()
    except py4j.protocol.Py4JError:
        spark = SparkSession(sc)
    globs["spark"] = spark
    (failure_count, test_count) = doctest.testmod(
        pyspark.sql.merge,
        globs=globs,
        optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE | doctest.REPORT_NDIFF,
    )
    spark.stop()
    if failure_count:
        sys.exit(-1)


if __name__ == "__main__":
    _test()