Queries to find the distance between two nodes of a binary tree – O(log n) method

Given a binary tree with no parent pointers, the task is to find the distance between two keys in the tree. The distance between two nodes is the minimum number of edges that must be traversed to reach one node from the other.
This problem has already been discussed in a previous post, but that approach uses three traversals of the binary tree: one to find the Lowest Common Ancestor (LCA) of the two nodes (say A and B), and two more to find the distance from the LCA to A and from the LCA to B. Each query therefore takes O(n) time. This post discusses a method that, after O(n) preprocessing, finds the LCA of two nodes in O(log n) time per query.
The distance between two nodes can be expressed in terms of their lowest common ancestor using the following formula:
Dist(n1, n2) = Dist(root, n1) + Dist(root, n2) - 2*Dist(root, lca)
where:
- 'n1' and 'n2' are the two given keys,
- 'root' is the root of the given binary tree,
- 'lca' is the lowest common ancestor of n1 and n2, and
- Dist(n1, n2) is the distance between n1 and n2.
Since Dist(root, x) is simply the level of x (with the root at level 0), the formula can also be written as:
Dist(n1, n2) = Level[n1] + Level[n2] - 2*Level[lca]
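As a small baseline for checking the formula, the sketch below finds the LCA from two root-to-node paths (the O(n)-per-query idea from the earlier approach) and then applies Level[n1] + Level[n2] - 2*Level[lca]. The names `Node`, `path_to`, and `distance` are illustrative assumptions, and node keys are assumed to be distinct.

```python
# Minimal sketch (assumption: distinct keys; helper names are hypothetical).
# Level of a node = (length of its root-to-node path) - 1.

class Node:
    def __init__(self, data, left=None, right=None):
        self.data = data
        self.left = left
        self.right = right


def path_to(root, key, path):
    """Append the keys on the root-to-key path to `path`; return True if key is found."""
    if root is None:
        return False
    path.append(root.data)
    if root.data == key:
        return True
    if path_to(root.left, key, path) or path_to(root.right, key, path):
        return True
    path.pop()  # key is not in this subtree
    return False


def distance(root, n1, n2):
    p1, p2 = [], []
    if not (path_to(root, n1, p1) and path_to(root, n2, p2)):
        return -1  # one of the keys is missing
    # The LCA is the last key shared by both root-to-node paths.
    common = 0
    while common < min(len(p1), len(p2)) and p1[common] == p2[common]:
        common += 1
    level_n1, level_n2, level_lca = len(p1) - 1, len(p2) - 1, common - 1
    return level_n1 + level_n2 - 2 * level_lca


# Example tree: 1 -> (2, 3), 2 -> (4, 5), 3 -> (6, 7), 6 -> right child 8
root = Node(1, Node(2, Node(4), Node(5)), Node(3, Node(6, None, Node(8)), Node(7)))
print(distance(root, 8, 5))  # 5
```

Each query here walks the tree, which is exactly the per-query O(n) cost the Euler tour approach avoids.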
This problem can be broken down into the following steps:
- Finding the level of each node
- Finding the Euler tour of the binary tree
- Building a segment tree for LCA queries
These steps are explained below:
- Find the level of each node using a level-order traversal.
- Find the LCA of two nodes in O(log n) time by storing the Euler tour of the binary tree in an array and computing two other arrays from the Euler tour and the node levels.
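The preprocessing arrays can be sketched in Python as follows. This is an illustrative version, not the article's code: it uses dicts keyed by node data instead of fixed-size MAX arrays, and the names `find_levels`, `euler_tour`, and `preprocess` are hypothetical. Euler and L are kept 1-indexed to match the implementation below.

```python
from collections import deque


class Node:
    def __init__(self, data, left=None, right=None):
        self.data, self.left, self.right = data, left, right


def find_levels(root):
    """Level-order traversal; returns {key: level} with the root at level 0."""
    level, q = {}, deque([(root, 0)])
    while q:
        node, lvl = q.popleft()
        level[node.data] = lvl
        for child in (node.left, node.right):
            if child is not None:
                q.append((child, lvl + 1))
    return level


def euler_tour(root, out):
    """Record a node on entry and again after returning from each child."""
    out.append(root.data)
    for child in (root.left, root.right):
        if child is not None:
            euler_tour(child, out)
            out.append(root.data)


def preprocess(root):
    """Return (level, Euler, L, H); Euler and L are 1-indexed."""
    level = find_levels(root)
    euler = [None]                 # dummy slot so the first real entry is Euler[1]
    euler_tour(root, euler)
    L = [None] + [level[v] for v in euler[1:]]
    H = {}
    for i in range(1, len(euler)):
        H.setdefault(euler[i], i)  # keep only the first occurrence
    return level, euler, L, H


root = Node(1, Node(2, Node(4), Node(5)), Node(3, Node(6, None, Node(8)), Node(7)))
level, euler, L, H = preprocess(root)
print(euler[1:])  # [1, 2, 4, 2, 5, 2, 1, 3, 6, 8, 6, 3, 7, 3, 1]
print(L[1:])      # [0, 1, 2, 1, 2, 1, 0, 1, 2, 3, 2, 1, 2, 1, 0]
print(H[8], H[5]) # 10 5
```

A tree with N nodes yields an Euler tour of 2N - 1 entries, which is why the implementation passes `2 * N - 1` as the array size.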
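The LCA step works as follows:
- First, find the Euler tour of the binary tree: record a node when it is first visited and again each time the traversal returns to it from a child.
(Figure: Euler tour of the example binary tree.) For the tree used in the implementation below, the tour is 1, 2, 4, 2, 5, 2, 1, 3, 6, 8, 6, 3, 7, 3, 1, which has 2N - 1 = 15 entries for N = 8 nodes.
- Then, store the level of each entry of the Euler array in a separate array L, i.e. L[i] = Level[Euler[i]].
- Then, store the first occurrence of every node in the Euler array in an array H, i.e. H[v] is the smallest index i with Euler[i] = v. For two nodes A and B, H[A] and H[B] bound the part of the tour that has to be searched for the minimum level, which keeps the query range small.
- Then, build a segment tree over the L array. For two nodes A and B, take their first occurrences H[A] and H[B] (swapping them if H[A] > H[B]) and query the segment tree for the minimum value, say X, in the range H[A] to H[B]. The part of the tour between the two first occurrences climbs up to the LCA but never above it, so the node at the index of X is the LCA, i.e. LCA = Euler[index(X)].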
For example, let A = 8 and B = 5 in the example tree.
- H[8] = 10 and H[5] = 5.
- Querying the segment tree for the minimum of L over indices 5 to 10 gives X = 0 at index 7.
- Then, LCA = Euler[7], i.e. LCA = 1.
- Finally, apply the distance formula: Dist(8, 5) = Level[8] + Level[5] - 2*Level[1] = 3 + 2 - 2*0 = 5.
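The segment-tree step can be sketched on its own by hard-coding the Euler, L, and level arrays of the example tree (the values listed in the walkthrough above). This is an illustrative Python sketch, not the article's implementation; `build`, `query`, and `distance` are hypothetical names. Each segment-tree node stores a (minimum level, index) tuple, so Python's tuple `min` returns both X and its position.

```python
import sys

# Arrays for the example tree, 1-indexed (index 0 is an unused placeholder).
EULER = [None, 1, 2, 4, 2, 5, 2, 1, 3, 6, 8, 6, 3, 7, 3, 1]
L = [None, 0, 1, 2, 1, 2, 1, 0, 1, 2, 3, 2, 1, 2, 1, 0]
LEVEL = {1: 0, 2: 1, 3: 1, 4: 2, 5: 2, 6: 2, 7: 2, 8: 3}
H = {}
for i in range(1, len(EULER)):
    H.setdefault(EULER[i], i)  # first occurrence of each node

N = len(L) - 1  # 2 * 8 - 1 = 15 entries
seg = [None] * (4 * N)


def build(low, high, pos):
    """Each node stores (minimum level in [low, high], index where it occurs)."""
    if low == high:
        seg[pos] = (L[low], low)
        return seg[pos]
    mid = (low + high) // 2
    seg[pos] = min(build(low, mid, 2 * pos), build(mid + 1, high, 2 * pos + 1))
    return seg[pos]


def query(qlow, qhigh, low=1, high=N, pos=1):
    if qlow <= low and high <= qhigh:  # node range lies fully inside the query
        return seg[pos]
    if qhigh < low or high < qlow:     # no overlap: neutral element for min
        return (sys.maxsize, 0)
    mid = (low + high) // 2
    return min(query(qlow, qhigh, low, mid, 2 * pos),
               query(qlow, qhigh, mid + 1, high, 2 * pos + 1))


def distance(a, b):
    lo, hi = sorted((H[a], H[b]))
    _, idx = query(lo, hi)             # idx = position of the minimum level X
    lca = EULER[idx]
    return LEVEL[a] + LEVEL[b] - 2 * LEVEL[lca]


build(1, N, 1)
print(query(H[5], H[8]))  # (0, 7)
print(distance(8, 5))     # 5
```

Building touches each of the 2N - 1 entries a constant number of times, and each query visits O(log N) segment-tree nodes.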
Implementation:
C++
```cpp
// C++ program to find distance between
// two nodes for multiple queries
#include <climits>
#include <iostream>
#include <queue>
#include <utility>
#define MAX 100001
using namespace std;

/* A tree node structure */
struct Node {
    int data;
    struct Node* left;
    struct Node* right;
};

/* Utility function to create a new Binary Tree node */
struct Node* newNode(int data)
{
    struct Node* temp = new struct Node;
    temp->data = data;
    temp->left = temp->right = NULL;
    return temp;
}

// Array to store level of each node (indexed by node key,
// so keys must be distinct integers in [1, MAX))
int level[MAX];

// Utility Function to store level of all nodes
void FindLevels(struct Node* root)
{
    if (!root)
        return;

    // queue to hold tree node with level
    queue<pair<struct Node*, int> > q;

    // let root node be at level 0
    q.push({ root, 0 });
    pair<struct Node*, int> p;

    // Do level Order Traversal of tree
    while (!q.empty()) {
        p = q.front();
        q.pop();

        // Node p.first is on level p.second
        level[p.first->data] = p.second;

        // If left child exists, put it in queue
        // with current_level +1
        if (p.first->left)
            q.push({ p.first->left, p.second + 1 });

        // If right child exists, put it in queue
        // with current_level +1
        if (p.first->right)
            q.push({ p.first->right, p.second + 1 });
    }
}

// Stores Euler Tour (1-indexed)
int Euler[MAX];

// index in Euler array
int idx = 0;

// Find Euler Tour
void eulerTree(struct Node* root)
{
    // store current node's data
    Euler[++idx] = root->data;

    // If left node exists
    if (root->left) {
        // traverse left subtree
        eulerTree(root->left);

        // store parent node's data
        Euler[++idx] = root->data;
    }

    // If right node exists
    if (root->right) {
        // traverse right subtree
        eulerTree(root->right);

        // store parent node's data
        Euler[++idx] = root->data;
    }
}

// checks for visited nodes
int vis[MAX];

// Stores level of Euler Tour
int L[MAX];

// Stores indices of first occurrence
// of nodes in Euler tour
int H[MAX];

// Preprocessing Euler Tour for finding LCA
void preprocessEuler(int size)
{
    for (int i = 1; i <= size; i++) {
        L[i] = level[Euler[i]];

        // If node is not visited before
        if (vis[Euler[i]] == 0) {
            // Add to first occurrence
            H[Euler[i]] = i;

            // Mark it visited
            vis[Euler[i]] = 1;
        }
    }
}

// Stores values and positions
pair<int, int> seg[4 * MAX];

// Utility function to find minimum of
// pair type values
pair<int, int> min(pair<int, int> a, pair<int, int> b)
{
    if (a.first <= b.first)
        return a;
    else
        return b;
}

// Utility function to build segment tree
pair<int, int> buildSegTree(int low, int high, int pos)
{
    if (low == high) {
        seg[pos].first = L[low];
        seg[pos].second = low;
        return seg[pos];
    }
    int mid = low + (high - low) / 2;
    buildSegTree(low, mid, 2 * pos);
    buildSegTree(mid + 1, high, 2 * pos + 1);
    seg[pos] = min(seg[2 * pos], seg[2 * pos + 1]);
    return seg[pos];
}

// Utility function to find LCA
pair<int, int> LCA(int qlow, int qhigh, int low, int high, int pos)
{
    if (qlow <= low && qhigh >= high)
        return seg[pos];
    if (qlow > high || qhigh < low)
        return { INT_MAX, 0 };
    int mid = low + (high - low) / 2;
    return min(LCA(qlow, qhigh, low, mid, 2 * pos),
               LCA(qlow, qhigh, mid + 1, high, 2 * pos + 1));
}

// Function to return distance between
// two nodes n1 and n2
int findDistance(int n1, int n2, int size)
{
    // Maintain original Values
    int prevn1 = n1, prevn2 = n2;

    // Get First Occurrence of n1
    n1 = H[n1];

    // Get First Occurrence of n2
    n2 = H[n2];

    // Swap if low > high
    if (n2 < n1)
        swap(n1, n2);

    // Get position of minimum value
    int lca = LCA(n1, n2, 1, size, 1).second;

    // Extract value out of Euler tour
    lca = Euler[lca];

    // return calculated distance
    return level[prevn1] + level[prevn2] - 2 * level[lca];
}

void preProcessing(Node* root, int N)
{
    // Build Euler tour
    eulerTree(root);

    // Store Levels
    FindLevels(root);

    // Find L and H array
    preprocessEuler(2 * N - 1);

    // Build segment Tree
    buildSegTree(1, 2 * N - 1, 1);
}

/* Driver function to test above functions */
int main()
{
    int N = 8; // Number of nodes

    /* Constructing tree given in the above figure */
    Node* root = newNode(1);
    root->left = newNode(2);
    root->right = newNode(3);
    root->left->left = newNode(4);
    root->left->right = newNode(5);
    root->right->left = newNode(6);
    root->right->right = newNode(7);
    root->right->left->right = newNode(8);

    // Function to do all preprocessing
    preProcessing(root, N);

    cout << "Dist(4, 5) = " << findDistance(4, 5, 2 * N - 1) << "\n";
    cout << "Dist(4, 6) = " << findDistance(4, 6, 2 * N - 1) << "\n";
    cout << "Dist(3, 4) = " << findDistance(3, 4, 2 * N - 1) << "\n";
    cout << "Dist(2, 4) = " << findDistance(2, 4, 2 * N - 1) << "\n";
    cout << "Dist(8, 5) = " << findDistance(8, 5, 2 * N - 1) << "\n";

    return 0;
}
```
Java
```java
// Java program to find distance between
// two nodes for multiple queries
import java.util.*;

class GFG {
    static int MAX = 100001;

    /* A tree node structure */
    static class Node {
        int data;
        Node left, right;

        Node(int data) {
            this.data = data;
            this.left = this.right = null;
        }
    }

    static class Pair<T, V> {
        T first;
        V second;

        Pair() {
        }

        Pair(T first, V second) {
            this.first = first;
            this.second = second;
        }
    }

    // Array to store level of each node
    static int[] level = new int[MAX];

    // Utility Function to store level of all nodes
    static void findLevels(Node root) {
        if (root == null)
            return;

        // queue to hold tree node with level
        Queue<Pair<Node, Integer>> q = new LinkedList<>();

        // let root node be at level 0
        q.add(new Pair<Node, Integer>(root, 0));
        Pair<Node, Integer> p;

        // Do level Order Traversal of tree
        while (!q.isEmpty()) {
            p = q.poll();

            // Node p.first is on level p.second
            level[p.first.data] = p.second;

            // If left child exists, put it in queue
            // with current_level +1
            if (p.first.left != null)
                q.add(new Pair<Node, Integer>(p.first.left, p.second + 1));

            // If right child exists, put it in queue
            // with current_level +1
            if (p.first.right != null)
                q.add(new Pair<Node, Integer>(p.first.right, p.second + 1));
        }
    }

    // Stores Euler Tour (1-indexed)
    static int[] Euler = new int[MAX];

    // index in Euler array
    static int idx = 0;

    // Find Euler Tour
    static void eulerTree(Node root) {
        // store current node's data
        Euler[++idx] = root.data;

        // If left node exists
        if (root.left != null) {
            // traverse left subtree
            eulerTree(root.left);

            // store parent node's data
            Euler[++idx] = root.data;
        }

        // If right node exists
        if (root.right != null) {
            // traverse right subtree
            eulerTree(root.right);

            // store parent node's data
            Euler[++idx] = root.data;
        }
    }

    // checks for visited nodes
    static int[] vis = new int[MAX];

    // Stores level of Euler Tour
    static int[] L = new int[MAX];

    // Stores indices of first occurrence
    // of nodes in Euler tour
    static int[] H = new int[MAX];

    // Preprocessing Euler Tour for finding LCA
    static void preprocessEuler(int size) {
        for (int i = 1; i <= size; i++) {
            L[i] = level[Euler[i]];

            // If node is not visited before
            if (vis[Euler[i]] == 0) {
                // Add to first occurrence
                H[Euler[i]] = i;

                // Mark it visited
                vis[Euler[i]] = 1;
            }
        }
    }

    // Stores values and positions
    @SuppressWarnings("unchecked")
    static Pair<Integer, Integer>[] seg = (Pair<Integer, Integer>[]) new Pair[4 * MAX];

    // Utility function to find minimum of
    // pair type values
    static Pair<Integer, Integer> min(Pair<Integer, Integer> a,
                                      Pair<Integer, Integer> b) {
        if (a.first <= b.first)
            return a;
        return b;
    }

    // Utility function to build segment tree
    static Pair<Integer, Integer> buildSegTree(int low, int high, int pos) {
        if (low == high) {
            seg[pos].first = L[low];
            seg[pos].second = low;
            return seg[pos];
        }
        int mid = low + (high - low) / 2;
        buildSegTree(low, mid, 2 * pos);
        buildSegTree(mid + 1, high, 2 * pos + 1);
        seg[pos] = min(seg[2 * pos], seg[2 * pos + 1]);
        return seg[pos];
    }

    // Utility function to find LCA
    static Pair<Integer, Integer> LCA(int qlow, int qhigh,
                                      int low, int high, int pos) {
        if (qlow <= low && qhigh >= high)
            return seg[pos];
        if (qlow > high || qhigh < low)
            return new Pair<Integer, Integer>(Integer.MAX_VALUE, 0);
        int mid = low + (high - low) / 2;
        return min(LCA(qlow, qhigh, low, mid, 2 * pos),
                   LCA(qlow, qhigh, mid + 1, high, 2 * pos + 1));
    }

    // Function to return distance between
    // two nodes n1 and n2
    static int findDistance(int n1, int n2, int size) {
        // Maintain original Values
        int prevn1 = n1, prevn2 = n2;

        // Get First Occurrence of n1
        n1 = H[n1];

        // Get First Occurrence of n2
        n2 = H[n2];

        // Swap if low > high
        if (n2 < n1) {
            int temp = n1;
            n1 = n2;
            n2 = temp;
        }

        // Get position of minimum value
        int lca = LCA(n1, n2, 1, size, 1).second;

        // Extract value out of Euler tour
        lca = Euler[lca];

        // return calculated distance
        return level[prevn1] + level[prevn2] - 2 * level[lca];
    }

    static void preProcessing(Node root, int N) {
        for (int i = 0; i < 4 * MAX; i++) {
            seg[i] = new Pair<>();
        }

        // Build Euler tour
        eulerTree(root);

        // Store Levels
        findLevels(root);

        // Find L and H array
        preprocessEuler(2 * N - 1);

        // Build segment Tree
        buildSegTree(1, 2 * N - 1, 1);
    }

    // Driver Code
    public static void main(String[] args) {
        // Number of nodes
        int N = 8;

        /* Constructing tree given in the above figure */
        Node root = new Node(1);
        root.left = new Node(2);
        root.right = new Node(3);
        root.left.left = new Node(4);
        root.left.right = new Node(5);
        root.right.left = new Node(6);
        root.right.right = new Node(7);
        root.right.left.right = new Node(8);

        // Function to do all preprocessing
        preProcessing(root, N);

        System.out.println("Dist(4, 5) = " + findDistance(4, 5, 2 * N - 1));
        System.out.println("Dist(4, 6) = " + findDistance(4, 6, 2 * N - 1));
        System.out.println("Dist(3, 4) = " + findDistance(3, 4, 2 * N - 1));
        System.out.println("Dist(2, 4) = " + findDistance(2, 4, 2 * N - 1));
        System.out.println("Dist(8, 5) = " + findDistance(8, 5, 2 * N - 1));
    }
}

// This code is contributed by
// sanjeev2552
```
Python3
```python
# Python3 program to find distance between
# two nodes for multiple queries
from collections import deque
from sys import maxsize as INT_MAX

MAX = 100001


# A tree node structure
class Node:
    def __init__(self, data):
        self.data = data
        self.left = None
        self.right = None


# Array to store level of each node
level = [0] * MAX


# Utility Function to store level of all nodes
def findLevels(root: Node):
    if root is None:
        return

    # queue to hold tree node with level
    q = deque()

    # let root node be at level 0
    q.append((root, 0))

    # Do level Order Traversal of tree
    while q:
        node, lvl = q.popleft()

        # Node `node` is on level `lvl`
        level[node.data] = lvl

        # If left child exists, put it in queue
        # with current_level +1
        if node.left:
            q.append((node.left, lvl + 1))

        # If right child exists, put it in queue
        # with current_level +1
        if node.right:
            q.append((node.right, lvl + 1))


# Stores Euler Tour (1-indexed)
Euler = [0] * MAX

# index in Euler array
idx = 0


# Find Euler Tour
def eulerTree(root: Node):
    global idx

    # store current node's data
    idx += 1
    Euler[idx] = root.data

    # If left node exists
    if root.left:
        # traverse left subtree
        eulerTree(root.left)

        # store parent node's data
        idx += 1
        Euler[idx] = root.data

    # If right node exists
    if root.right:
        # traverse right subtree
        eulerTree(root.right)

        # store parent node's data
        idx += 1
        Euler[idx] = root.data


# checks for visited nodes
vis = [0] * MAX

# Stores level of Euler Tour
L = [0] * MAX

# Stores indices of the first occurrence
# of nodes in Euler tour
H = [0] * MAX


# Preprocessing Euler Tour for finding LCA
def preprocessEuler(size: int):
    for i in range(1, size + 1):
        L[i] = level[Euler[i]]

        # If node is not visited before
        if vis[Euler[i]] == 0:
            # Add to first occurrence
            H[Euler[i]] = i

            # Mark it visited
            vis[Euler[i]] = 1


# Stores values and positions
seg = [[0, 0] for _ in range(4 * MAX)]


# Utility function to find minimum of
# pair type values
def minPair(a: list, b: list) -> list:
    if a[0] <= b[0]:
        return a
    else:
        return b


# Utility function to build segment tree
def buildSegTree(low: int, high: int, pos: int) -> list:
    if low == high:
        seg[pos][0] = L[low]
        seg[pos][1] = low
        return seg[pos]
    mid = low + (high - low) // 2
    buildSegTree(low, mid, 2 * pos)
    buildSegTree(mid + 1, high, 2 * pos + 1)
    seg[pos] = minPair(seg[2 * pos], seg[2 * pos + 1])
    return seg[pos]


# Utility function to find LCA
def LCA(qlow: int, qhigh: int, low: int, high: int, pos: int) -> list:
    if qlow <= low and qhigh >= high:
        return seg[pos]
    if qlow > high or qhigh < low:
        return [INT_MAX, 0]
    mid = low + (high - low) // 2
    return minPair(LCA(qlow, qhigh, low, mid, 2 * pos),
                   LCA(qlow, qhigh, mid + 1, high, 2 * pos + 1))


# Function to return distance between
# two nodes n1 and n2
def findDistance(n1: int, n2: int, size: int) -> int:
    # Maintain original Values
    prevn1 = n1
    prevn2 = n2

    # Get First Occurrence of n1
    n1 = H[n1]

    # Get First Occurrence of n2
    n2 = H[n2]

    # Swap if low > high
    if n2 < n1:
        n1, n2 = n2, n1

    # Get position of minimum value
    lca = LCA(n1, n2, 1, size, 1)[1]

    # Extract value out of Euler tour
    lca = Euler[lca]

    # return calculated distance
    return level[prevn1] + level[prevn2] - 2 * level[lca]


def preProcessing(root: Node, N: int):
    # Build Euler tour
    eulerTree(root)

    # Store Levels
    findLevels(root)

    # Find L and H array
    preprocessEuler(2 * N - 1)

    # Build segment tree
    buildSegTree(1, 2 * N - 1, 1)


# Driver Code
if __name__ == "__main__":
    # Number of nodes
    N = 8

    # Constructing tree given in the above figure
    root = Node(1)
    root.left = Node(2)
    root.right = Node(3)
    root.left.left = Node(4)
    root.left.right = Node(5)
    root.right.left = Node(6)
    root.right.right = Node(7)
    root.right.left.right = Node(8)

    # Function to do all preprocessing
    preProcessing(root, N)

    print("Dist(4, 5) =", findDistance(4, 5, 2 * N - 1))
    print("Dist(4, 6) =", findDistance(4, 6, 2 * N - 1))
    print("Dist(3, 4) =", findDistance(3, 4, 2 * N - 1))
    print("Dist(2, 4) =", findDistance(2, 4, 2 * N - 1))
    print("Dist(8, 5) =", findDistance(8, 5, 2 * N - 1))

# This code is contributed by
# sanjeev2552
```
C#
```csharp
// C# program to find distance between
// two nodes for multiple queries
using System;
using System.Collections.Generic;

class GFG
{
    static int MAX = 100001;

    /* A tree node structure */
    public class Node
    {
        public int data;
        public Node left, right;

        public Node(int data)
        {
            this.data = data;
            this.left = this.right = null;
        }
    }

    class Pair<T, V>
    {
        public T first;
        public V second;

        public Pair()
        {
        }

        public Pair(T first, V second)
        {
            this.first = first;
            this.second = second;
        }
    }

    // Array to store level of each node
    static int[] level = new int[MAX];

    // Utility Function to store level of all nodes
    static void findLevels(Node root)
    {
        if (root == null)
            return;

        // queue to hold tree node with level
        Queue<Pair<Node, int>> q = new Queue<Pair<Node, int>>();

        // let root node be at level 0
        q.Enqueue(new Pair<Node, int>(root, 0));

        // Do level Order Traversal of tree
        while (q.Count != 0)
        {
            Pair<Node, int> p = q.Dequeue();

            // Node p.first is on level p.second
            level[p.first.data] = p.second;

            // If left child exists, put it in queue
            // with current_level +1
            if (p.first.left != null)
                q.Enqueue(new Pair<Node, int>(p.first.left, p.second + 1));

            // If right child exists, put it in queue
            // with current_level +1
            if (p.first.right != null)
                q.Enqueue(new Pair<Node, int>(p.first.right, p.second + 1));
        }
    }

    // Stores Euler Tour (1-indexed)
    static int[] Euler = new int[MAX];

    // index in Euler array
    static int idx = 0;

    // Find Euler Tour
    static void eulerTree(Node root)
    {
        // store current node's data
        Euler[++idx] = root.data;

        // If left node exists
        if (root.left != null)
        {
            // traverse left subtree
            eulerTree(root.left);

            // store parent node's data
            Euler[++idx] = root.data;
        }

        // If right node exists
        if (root.right != null)
        {
            // traverse right subtree
            eulerTree(root.right);

            // store parent node's data
            Euler[++idx] = root.data;
        }
    }

    // checks for visited nodes
    static int[] vis = new int[MAX];

    // Stores level of Euler Tour
    static int[] L = new int[MAX];

    // Stores indices of first occurrence
    // of nodes in Euler tour
    static int[] H = new int[MAX];

    // Preprocessing Euler Tour for finding LCA
    static void preprocessEuler(int size)
    {
        for (int i = 1; i <= size; i++)
        {
            L[i] = level[Euler[i]];

            // If node is not visited before
            if (vis[Euler[i]] == 0)
            {
                // Add to first occurrence
                H[Euler[i]] = i;

                // Mark it visited
                vis[Euler[i]] = 1;
            }
        }
    }

    // Stores values and positions
    static Pair<int, int>[] seg = new Pair<int, int>[4 * MAX];

    // Utility function to find minimum of
    // pair type values
    static Pair<int, int> min(Pair<int, int> a, Pair<int, int> b)
    {
        if (a.first <= b.first)
            return a;
        return b;
    }

    // Utility function to build segment tree
    static Pair<int, int> buildSegTree(int low, int high, int pos)
    {
        if (low == high)
        {
            seg[pos].first = L[low];
            seg[pos].second = low;
            return seg[pos];
        }
        int mid = low + (high - low) / 2;
        buildSegTree(low, mid, 2 * pos);
        buildSegTree(mid + 1, high, 2 * pos + 1);
        seg[pos] = min(seg[2 * pos], seg[2 * pos + 1]);
        return seg[pos];
    }

    // Utility function to find LCA
    static Pair<int, int> LCA(int qlow, int qhigh, int low, int high, int pos)
    {
        if (qlow <= low && qhigh >= high)
            return seg[pos];
        if (qlow > high || qhigh < low)
            return new Pair<int, int>(int.MaxValue, 0);
        int mid = low + (high - low) / 2;
        return min(LCA(qlow, qhigh, low, mid, 2 * pos),
                   LCA(qlow, qhigh, mid + 1, high, 2 * pos + 1));
    }

    // Function to return distance between
    // two nodes n1 and n2
    static int findDistance(int n1, int n2, int size)
    {
        // Maintain original Values
        int prevn1 = n1, prevn2 = n2;

        // Get First Occurrence of n1
        n1 = H[n1];

        // Get First Occurrence of n2
        n2 = H[n2];

        // Swap if low > high
        if (n2 < n1)
        {
            int temp = n1;
            n1 = n2;
            n2 = temp;
        }

        // Get position of minimum value
        int lca = LCA(n1, n2, 1, size, 1).second;

        // Extract value out of Euler tour
        lca = Euler[lca];

        // return calculated distance
        return level[prevn1] + level[prevn2] - 2 * level[lca];
    }

    static void preProcessing(Node root, int N)
    {
        for (int i = 0; i < 4 * MAX; i++)
        {
            seg[i] = new Pair<int, int>();
        }

        // Build Euler tour
        eulerTree(root);

        // Store Levels
        findLevels(root);

        // Find L and H array
        preprocessEuler(2 * N - 1);

        // Build segment Tree
        buildSegTree(1, 2 * N - 1, 1);
    }

    // Driver Code
    public static void Main(String[] args)
    {
        // Number of nodes
        int N = 8;

        /* Constructing tree given in the above figure */
        Node root = new Node(1);
        root.left = new Node(2);
        root.right = new Node(3);
        root.left.left = new Node(4);
        root.left.right = new Node(5);
        root.right.left = new Node(6);
        root.right.right = new Node(7);
        root.right.left.right = new Node(8);

        // Function to do all preprocessing
        preProcessing(root, N);

        Console.WriteLine("Dist(4, 5) = " + findDistance(4, 5, 2 * N - 1));
        Console.WriteLine("Dist(4, 6) = " + findDistance(4, 6, 2 * N - 1));
        Console.WriteLine("Dist(3, 4) = " + findDistance(3, 4, 2 * N - 1));
        Console.WriteLine("Dist(2, 4) = " + findDistance(2, 4, 2 * N - 1));
        Console.WriteLine("Dist(8, 5) = " + findDistance(8, 5, 2 * N - 1));
    }
}

// This code is contributed by Rajput-Ji
```
Output:
```
Dist(4, 5) = 2
Dist(4, 6) = 4
Dist(3, 4) = 3
Dist(2, 4) = 1
Dist(8, 5) = 5
```
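To cross-check these results independently of the LCA machinery, a brute-force sketch can treat the example tree as an undirected graph and run a breadth-first search from one key to the other. This is a testing aid under stated assumptions (the edge list is transcribed from the example tree; `EDGES` and `bfs_distance` are hypothetical names), not part of the O(log n) method: it costs O(N) per query.

```python
from collections import deque

# Edges of the example tree: 1-2, 1-3, 2-4, 2-5, 3-6, 3-7, 6-8.
EDGES = [(1, 2), (1, 3), (2, 4), (2, 5), (3, 6), (3, 7), (6, 8)]


def bfs_distance(edges, src, dst):
    """Brute-force check: BFS over the tree viewed as an undirected graph."""
    adj = {}
    for u, v in edges:
        adj.setdefault(u, []).append(v)
        adj.setdefault(v, []).append(u)
    dist = {src: 0}
    q = deque([src])
    while q:
        u = q.popleft()
        if u == dst:
            return dist[u]
        for v in adj.get(u, []):
            if v not in dist:
                dist[v] = dist[u] + 1
                q.append(v)
    return -1  # dst is not reachable from src


for a, b in [(4, 5), (4, 6), (3, 4), (2, 4), (8, 5)]:
    print(f"Dist({a}, {b}) = {bfs_distance(EDGES, a, b)}")
```

Because a tree has exactly one simple path between any two nodes, the BFS distance equals the edge count that the Level-based formula computes.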
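Complexity Analysis:
- Time Complexity: O(N) preprocessing (level-order traversal, Euler tour, and building the segment tree over 2N - 1 entries), then O(log N) per distance query.
- Space Complexity: O(N) for the level, Euler, L, and H arrays and the segment tree (the sample code preallocates these arrays with size MAX and indexes them by node key).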