From 3eeb06ae5cd8e3be30f125b0450da5c6a9ff4d84 Mon Sep 17 00:00:00 2001 From: "oleksii.tumanov" Date: Fri, 13 Mar 2026 19:32:58 -0500 Subject: [PATCH] feat: add OptimalBinarySearchTree algorithm --- .../OptimalBinarySearchTree.java | 108 ++++++++++++++++++ .../OptimalBinarySearchTreeTest.java | 73 ++++++++++++ 2 files changed, 181 insertions(+) create mode 100644 src/main/java/com/thealgorithms/dynamicprogramming/OptimalBinarySearchTree.java create mode 100644 src/test/java/com/thealgorithms/dynamicprogramming/OptimalBinarySearchTreeTest.java diff --git a/src/main/java/com/thealgorithms/dynamicprogramming/OptimalBinarySearchTree.java b/src/main/java/com/thealgorithms/dynamicprogramming/OptimalBinarySearchTree.java new file mode 100644 index 000000000000..a727c9c5af60 --- /dev/null +++ b/src/main/java/com/thealgorithms/dynamicprogramming/OptimalBinarySearchTree.java @@ -0,0 +1,108 @@ +package com.thealgorithms.dynamicprogramming; + +import java.util.Arrays; +import java.util.Comparator; + +/** + * Computes the minimum search cost of an optimal binary search tree. + * + *

The algorithm sorts the keys, preserves the corresponding search frequencies, and uses + * dynamic programming with Knuth's optimization to compute the minimum weighted search cost. + * + *

Reference: + * https://en.wikipedia.org/wiki/Optimal_binary_search_tree + */ +public final class OptimalBinarySearchTree { + private OptimalBinarySearchTree() { + } + + /** + * Computes the minimum weighted search cost for the given keys and search frequencies. + * + * @param keys the BST keys + * @param frequencies the search frequencies associated with the keys + * @return the minimum search cost + * @throws IllegalArgumentException if the input is invalid + */ + public static long findOptimalCost(int[] keys, int[] frequencies) { + validateInput(keys, frequencies); + if (keys.length == 0) { + return 0L; + } + + int[][] sortedNodes = sortNodes(keys, frequencies); + int nodeCount = sortedNodes.length; + long[] prefixSums = buildPrefixSums(sortedNodes); + long[][] optimalCost = new long[nodeCount][nodeCount]; + int[][] root = new int[nodeCount][nodeCount]; + + for (int index = 0; index < nodeCount; index++) { + optimalCost[index][index] = sortedNodes[index][1]; + root[index][index] = index; + } + + for (int length = 2; length <= nodeCount; length++) { + for (int start = 0; start <= nodeCount - length; start++) { + int end = start + length - 1; + long frequencySum = prefixSums[end + 1] - prefixSums[start]; + optimalCost[start][end] = Long.MAX_VALUE; + + int leftBoundary = root[start][end - 1]; + int rightBoundary = root[start + 1][end]; + for (int currentRoot = leftBoundary; currentRoot <= rightBoundary; currentRoot++) { + long leftCost = currentRoot > start ? optimalCost[start][currentRoot - 1] : 0L; + long rightCost = currentRoot < end ? optimalCost[currentRoot + 1][end] : 0L; + long currentCost = frequencySum + leftCost + rightCost; + + if (currentCost < optimalCost[start][end]) { + optimalCost[start][end] = currentCost; + root[start][end] = currentRoot; + } + } + } + } + + return optimalCost[0][nodeCount - 1]; + } + + private static void validateInput(int[] keys, int[] frequencies) { + if (keys == null || frequencies == null) { + throw new IllegalArgumentException("Keys and frequencies cannot be null"); + } + if (keys.length != frequencies.length) { + throw new IllegalArgumentException("Keys and frequencies must have the same length"); + } + + for (int frequency : frequencies) { + if (frequency < 0) { + throw new IllegalArgumentException("Frequencies cannot be negative"); + } + } + } + + private static int[][] sortNodes(int[] keys, int[] frequencies) { + int[][] sortedNodes = new int[keys.length][2]; + for (int index = 0; index < keys.length; index++) { + sortedNodes[index][0] = keys[index]; + sortedNodes[index][1] = frequencies[index]; + } + + Arrays.sort(sortedNodes, Comparator.comparingInt(node -> node[0])); + + for (int index = 1; index < sortedNodes.length; index++) { + if (sortedNodes[index - 1][0] == sortedNodes[index][0]) { + throw new IllegalArgumentException("Keys must be distinct"); + } + } + + return sortedNodes; + } + + private static long[] buildPrefixSums(int[][] sortedNodes) { + long[] prefixSums = new long[sortedNodes.length + 1]; + for (int index = 0; index < sortedNodes.length; index++) { + prefixSums[index + 1] = prefixSums[index] + sortedNodes[index][1]; + } + return prefixSums; + } +} diff --git a/src/test/java/com/thealgorithms/dynamicprogramming/OptimalBinarySearchTreeTest.java b/src/test/java/com/thealgorithms/dynamicprogramming/OptimalBinarySearchTreeTest.java new file mode 100644 index 000000000000..17ff3ec728dc --- /dev/null +++ b/src/test/java/com/thealgorithms/dynamicprogramming/OptimalBinarySearchTreeTest.java @@ -0,0 +1,73 @@ +package com.thealgorithms.dynamicprogramming; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; + +import java.util.Arrays; +import java.util.stream.Stream; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; + +class OptimalBinarySearchTreeTest { + + @ParameterizedTest + @MethodSource("validTestCases") + void testFindOptimalCost(int[] keys, int[] frequencies, long expectedCost) { + assertEquals(expectedCost, OptimalBinarySearchTree.findOptimalCost(keys, frequencies)); + } + + private static Stream validTestCases() { + return Stream.of(Arguments.of(new int[] {}, new int[] {}, 0L), Arguments.of(new int[] {15}, new int[] {9}, 9L), Arguments.of(new int[] {10, 12}, new int[] {34, 50}, 118L), Arguments.of(new int[] {20, 10, 30}, new int[] {50, 34, 8}, 134L), + Arguments.of(new int[] {12, 10, 20, 42, 25, 37}, new int[] {8, 34, 50, 3, 40, 30}, 324L), Arguments.of(new int[] {1, 2, 3}, new int[] {0, 0, 0}, 0L)); + } + + @ParameterizedTest + @MethodSource("crossCheckTestCases") + void testFindOptimalCostAgainstBruteForce(int[] keys, int[] frequencies) { + assertEquals(bruteForceOptimalCost(keys, frequencies), OptimalBinarySearchTree.findOptimalCost(keys, frequencies)); + } + + private static Stream crossCheckTestCases() { + return Stream.of(Arguments.of(new int[] {3, 1, 2}, new int[] {4, 2, 6}), Arguments.of(new int[] {5, 2, 8, 6}, new int[] {3, 7, 1, 4}), Arguments.of(new int[] {9, 4, 11, 2}, new int[] {1, 8, 2, 5})); + } + + @ParameterizedTest + @MethodSource("invalidTestCases") + void testFindOptimalCostInvalidInput(int[] keys, int[] frequencies) { + assertThrows(IllegalArgumentException.class, () -> OptimalBinarySearchTree.findOptimalCost(keys, frequencies)); + } + + private static Stream invalidTestCases() { + return Stream.of(Arguments.of(null, new int[] {}), Arguments.of(new int[] {}, null), Arguments.of(new int[] {1, 2}, new int[] {3}), Arguments.of(new int[] {1, 1}, new int[] {2, 3}), Arguments.of(new int[] {1, 2}, new int[] {3, -1})); + } + + private static long bruteForceOptimalCost(int[] keys, int[] frequencies) { + int[][] sortedNodes = new int[keys.length][2]; + for (int index = 0; index < keys.length; index++) { + sortedNodes[index][0] = keys[index]; + sortedNodes[index][1] = frequencies[index]; + } + Arrays.sort(sortedNodes, java.util.Comparator.comparingInt(node -> node[0])); + + int[] sortedFrequencies = new int[sortedNodes.length]; + for (int index = 0; index < sortedNodes.length; index++) { + sortedFrequencies[index] = sortedNodes[index][1]; + } + + return bruteForceOptimalCost(sortedFrequencies, 0, sortedFrequencies.length - 1, 1); + } + + private static long bruteForceOptimalCost(int[] frequencies, int start, int end, int depth) { + if (start > end) { + return 0L; + } + + long minimumCost = Long.MAX_VALUE; + for (int root = start; root <= end; root++) { + long currentCost = (long) depth * frequencies[root] + bruteForceOptimalCost(frequencies, start, root - 1, depth + 1) + bruteForceOptimalCost(frequencies, root + 1, end, depth + 1); + minimumCost = Math.min(minimumCost, currentCost); + } + return minimumCost; + } +}