Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
package com.thealgorithms.dynamicprogramming;

import java.util.Arrays;
import java.util.Comparator;

/**
* Computes the minimum search cost of an optimal binary search tree.
*
* <p>The algorithm sorts the keys, preserves the corresponding search frequencies, and uses
* dynamic programming with Knuth's optimization to compute the minimum weighted search cost.
*
* <p>Reference:
* https://en.wikipedia.org/wiki/Optimal_binary_search_tree
*/
public final class OptimalBinarySearchTree {
private OptimalBinarySearchTree() {
}

/**
* Computes the minimum weighted search cost for the given keys and search frequencies.
*
* @param keys the BST keys
* @param frequencies the search frequencies associated with the keys
* @return the minimum search cost
* @throws IllegalArgumentException if the input is invalid
*/
public static long findOptimalCost(int[] keys, int[] frequencies) {
validateInput(keys, frequencies);
if (keys.length == 0) {
return 0L;
}

int[][] sortedNodes = sortNodes(keys, frequencies);
int nodeCount = sortedNodes.length;
long[] prefixSums = buildPrefixSums(sortedNodes);
long[][] optimalCost = new long[nodeCount][nodeCount];
int[][] root = new int[nodeCount][nodeCount];

for (int index = 0; index < nodeCount; index++) {
optimalCost[index][index] = sortedNodes[index][1];
root[index][index] = index;
}

for (int length = 2; length <= nodeCount; length++) {
for (int start = 0; start <= nodeCount - length; start++) {
int end = start + length - 1;
long frequencySum = prefixSums[end + 1] - prefixSums[start];
optimalCost[start][end] = Long.MAX_VALUE;

int leftBoundary = root[start][end - 1];
int rightBoundary = root[start + 1][end];
for (int currentRoot = leftBoundary; currentRoot <= rightBoundary; currentRoot++) {
long leftCost = currentRoot > start ? optimalCost[start][currentRoot - 1] : 0L;
long rightCost = currentRoot < end ? optimalCost[currentRoot + 1][end] : 0L;
long currentCost = frequencySum + leftCost + rightCost;

if (currentCost < optimalCost[start][end]) {
optimalCost[start][end] = currentCost;
root[start][end] = currentRoot;
}
}
}
}

return optimalCost[0][nodeCount - 1];
}

private static void validateInput(int[] keys, int[] frequencies) {
if (keys == null || frequencies == null) {
throw new IllegalArgumentException("Keys and frequencies cannot be null");
}
if (keys.length != frequencies.length) {
throw new IllegalArgumentException("Keys and frequencies must have the same length");
}

for (int frequency : frequencies) {
if (frequency < 0) {
throw new IllegalArgumentException("Frequencies cannot be negative");
}
}
}

private static int[][] sortNodes(int[] keys, int[] frequencies) {
int[][] sortedNodes = new int[keys.length][2];
for (int index = 0; index < keys.length; index++) {
sortedNodes[index][0] = keys[index];
sortedNodes[index][1] = frequencies[index];
}

Arrays.sort(sortedNodes, Comparator.comparingInt(node -> node[0]));

for (int index = 1; index < sortedNodes.length; index++) {
if (sortedNodes[index - 1][0] == sortedNodes[index][0]) {
throw new IllegalArgumentException("Keys must be distinct");
}
}

return sortedNodes;
}

private static long[] buildPrefixSums(int[][] sortedNodes) {
long[] prefixSums = new long[sortedNodes.length + 1];
for (int index = 0; index < sortedNodes.length; index++) {
prefixSums[index + 1] = prefixSums[index] + sortedNodes[index][1];
}
return prefixSums;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
package com.thealgorithms.dynamicprogramming;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;

import java.util.Arrays;
import java.util.stream.Stream;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;

class OptimalBinarySearchTreeTest {

@ParameterizedTest
@MethodSource("validTestCases")
void testFindOptimalCost(int[] keys, int[] frequencies, long expectedCost) {
assertEquals(expectedCost, OptimalBinarySearchTree.findOptimalCost(keys, frequencies));
}

private static Stream<Arguments> validTestCases() {
return Stream.of(Arguments.of(new int[] {}, new int[] {}, 0L), Arguments.of(new int[] {15}, new int[] {9}, 9L), Arguments.of(new int[] {10, 12}, new int[] {34, 50}, 118L), Arguments.of(new int[] {20, 10, 30}, new int[] {50, 34, 8}, 134L),
Arguments.of(new int[] {12, 10, 20, 42, 25, 37}, new int[] {8, 34, 50, 3, 40, 30}, 324L), Arguments.of(new int[] {1, 2, 3}, new int[] {0, 0, 0}, 0L));
}

@ParameterizedTest
@MethodSource("crossCheckTestCases")
void testFindOptimalCostAgainstBruteForce(int[] keys, int[] frequencies) {
assertEquals(bruteForceOptimalCost(keys, frequencies), OptimalBinarySearchTree.findOptimalCost(keys, frequencies));
}

private static Stream<Arguments> crossCheckTestCases() {
return Stream.of(Arguments.of(new int[] {3, 1, 2}, new int[] {4, 2, 6}), Arguments.of(new int[] {5, 2, 8, 6}, new int[] {3, 7, 1, 4}), Arguments.of(new int[] {9, 4, 11, 2}, new int[] {1, 8, 2, 5}));
}

@ParameterizedTest
@MethodSource("invalidTestCases")
void testFindOptimalCostInvalidInput(int[] keys, int[] frequencies) {
assertThrows(IllegalArgumentException.class, () -> OptimalBinarySearchTree.findOptimalCost(keys, frequencies));
}

private static Stream<Arguments> invalidTestCases() {
return Stream.of(Arguments.of(null, new int[] {}), Arguments.of(new int[] {}, null), Arguments.of(new int[] {1, 2}, new int[] {3}), Arguments.of(new int[] {1, 1}, new int[] {2, 3}), Arguments.of(new int[] {1, 2}, new int[] {3, -1}));
}

private static long bruteForceOptimalCost(int[] keys, int[] frequencies) {
int[][] sortedNodes = new int[keys.length][2];
for (int index = 0; index < keys.length; index++) {
sortedNodes[index][0] = keys[index];
sortedNodes[index][1] = frequencies[index];
}
Arrays.sort(sortedNodes, java.util.Comparator.comparingInt(node -> node[0]));

int[] sortedFrequencies = new int[sortedNodes.length];
for (int index = 0; index < sortedNodes.length; index++) {
sortedFrequencies[index] = sortedNodes[index][1];
}

return bruteForceOptimalCost(sortedFrequencies, 0, sortedFrequencies.length - 1, 1);
}

private static long bruteForceOptimalCost(int[] frequencies, int start, int end, int depth) {
if (start > end) {
return 0L;
}

long minimumCost = Long.MAX_VALUE;
for (int root = start; root <= end; root++) {
long currentCost = (long) depth * frequencies[root] + bruteForceOptimalCost(frequencies, start, root - 1, depth + 1) + bruteForceOptimalCost(frequencies, root + 1, end, depth + 1);
minimumCost = Math.min(minimumCost, currentCost);
}
return minimumCost;
}
}
Loading