Find distance between two nodes of a Binary Tree

Find the distance between two keys in a binary tree; no parent pointers are given. The distance between two nodes is the minimum number of edges that must be traversed to reach one node from the other.
Â
Â
The distance between two nodes can be expressed in terms of their lowest common ancestor (LCA). The formula is:
Dist(n1, n2) = Dist(root, n1) + Dist(root, n2) - 2*Dist(root, lca)
where 'n1' and 'n2' are the two given keys, 'root' is the root of the given binary tree, 'lca' is the lowest common ancestor of n1 and n2, and Dist(a, b) is the number of edges on the path between a and b.
Following is the implementation of the above approach. The implementation is adapted from the last code provided in the Lowest Common Ancestor post.
C++
/* C++ program to find the distance between two keys n1 and n2
   of a binary tree */
#include <iostream>
using namespace std;

// A Binary Tree Node
struct Node
{
    struct Node *left, *right;
    int key;
};

// Utility function to create a new tree Node holding the given key.
// The node is allocated with new and has no children.
Node* newNode(int key)
{
    Node *temp = new Node;
    temp->key = key;
    temp->left = temp->right = NULL;
    return temp;
}

// Returns the level of key k in the tree rooted at root, where root
// itself is at the given level. Returns -1 if k is not present.
int findLevel(Node *root, int k, int level)
{
    // Base Case
    if (root == NULL)
        return -1;

    // Key found at this node: report its level
    if (root->key == k)
        return level;

    // Search the left subtree first and fall back to the right one
    int l = findLevel(root->left, k, level+1);
    return (l != -1)? l : findLevel(root->right, k, level+1);
}

// Returns a pointer to the LCA of the two given values n1 and n2.
// It also sets d1, d2 and dist if neither key is an ancestor of the other.
// d1   --> receives the level of n1 (left untouched if n1 is not reached)
// d2   --> receives the level of n2 (left untouched if n2 is not reached)
// dist --> receives the distance between n1 and n2
// lvl  --> level of the current node
// The search does not descend below a node that matches n1 or n2, so when
// one key is an ancestor of the other only the ancestor's level is set.
Node *findDistUtil(Node* root, int n1, int n2, int &d1,
                   int &d2, int &dist, int lvl)
{
    // Base case
    if (root == NULL) return NULL;

    // If either n1 or n2 matches root's key, report the presence by
    // returning root (if a key is an ancestor of the other, the
    // ancestor key becomes the LCA)
    if (root->key == n1)
    {
        d1 = lvl;
        return root;
    }
    if (root->key == n2)
    {
        d2 = lvl;
        return root;
    }

    // Look for n1 and n2 in the left and right subtrees
    Node *left_lca = findDistUtil(root->left, n1, n2,
                                  d1, d2, dist, lvl+1);
    Node *right_lca = findDistUtil(root->right, n1, n2,
                                   d1, d2, dist, lvl+1);

    // If both calls return non-NULL, one key is in one subtree and the
    // other key is in the other subtree, so this node is the LCA
    if (left_lca && right_lca)
    {
        dist = d1 + d2 - 2*lvl;
        return root;
    }

    // Otherwise pass up whichever subtree reported a match (or NULL)
    return (left_lca != NULL)? left_lca: right_lca;
}

// The main function that returns the distance between n1 and n2.
// Returns -1 if either n1 or n2 is not present in the binary tree.
int findDistance(Node *root, int n1, int n2)
{
    // d1/d2 hold the levels of n1/n2 (-1 = not reached). dist starts at
    // -1 so it never holds an indeterminate value.
    int d1 = -1, d2 = -1, dist = -1;
    Node *lca = findDistUtil(root, n1, n2, d1, d2, dist, 1);

    // Both keys were reached in different subtrees of the LCA
    if (d1 != -1 && d2 != -1)
        return dist;

    // Only n1 was reached: n2 is either below n1 or absent. Treat n1 as
    // the root and look for n2 in its subtree (findLevel gives -1 if absent)
    if (d1 != -1)
    {
        dist = findLevel(lca, n2, 0);
        return dist;
    }

    // Only n2 was reached: same reasoning with the roles swapped
    if (d2 != -1)
    {
        dist = findLevel(lca, n1, 0);
        return dist;
    }

    return -1;
}
// Driver program to test above functionsint main(){    // Let us create binary tree given in the    // above example    Node * root = newNode(1);    root->left = newNode(2);    root->right = newNode(3);    root->left->left = newNode(4);    root->left->right = newNode(5);    root->right->left = newNode(6);    root->right->right = newNode(7);    root->right->left->right = newNode(8);    cout << "Dist(4, 5) = " << findDistance(root, 4, 5);    cout << "\nDist(4, 6) = " << findDistance(root, 4, 6);    cout << "\nDist(3, 4) = " << findDistance(root, 3, 4);    cout << "\nDist(2, 4) = " << findDistance(root, 2, 4);    cout << "\nDist(8, 5) = " << findDistance(root, 8, 5);    return 0;} |
Java
// A Java program to find the distance between two keys
// n1 and n2 of a binary tree
class GFG
{
    // Java passes ints by value, so the values that the C++ version
    // returns through reference parameters are kept in static fields.
    // findDistance() resets them at the start of every query.
    static int d1 = -1;
    static int d2 = -1;
    static int dist = 0;

    // A Binary Tree Node
    static class Node
    {
        Node left, right;
        int key;

        // constructor
        Node(int key)
        {
            this.key = key;
            left = null;
            right = null;
        }
    }

    // Returns the level of key k in the tree rooted at root, where root
    // itself is at the given level. Returns -1 if k is not present.
    static int findLevel(Node root, int k, int level)
    {
        // Base Case
        if (root == null)
        {
            return -1;
        }

        // Key found at this node: report its level
        if (root.key == k)
        {
            return level;
        }

        // Search the left subtree first and fall back to the right one
        int l = findLevel(root.left, k, level + 1);
        return (l != -1) ? l : findLevel(root.right, k, level + 1);
    }

    // Returns the LCA of the two given values n1 and n2. It also sets
    // d1, d2 and dist if neither key is an ancestor of the other.
    // d1   --> level of n1 (left untouched if n1 is not reached)
    // d2   --> level of n2 (left untouched if n2 is not reached)
    // dist --> distance between n1 and n2
    // lvl  --> level of the current node
    public static Node findDistUtil(Node root, int n1, int n2, int lvl)
    {
        // Base case
        if (root == null)
        {
            return null;
        }

        // If either n1 or n2 matches root's key, report the presence by
        // returning root (if a key is an ancestor of the other, the
        // ancestor key becomes the LCA)
        if (root.key == n1)
        {
            d1 = lvl;
            return root;
        }
        if (root.key == n2)
        {
            d2 = lvl;
            return root;
        }

        // Look for n1 and n2 in the left and right subtrees
        Node left_lca = findDistUtil(root.left, n1, n2, lvl + 1);
        Node right_lca = findDistUtil(root.right, n1, n2, lvl + 1);

        // If both calls return non-null, one key is in one subtree and
        // the other key is in the other subtree, so this node is the LCA
        if (left_lca != null && right_lca != null)
        {
            dist = (d1 + d2) - 2 * lvl;
            return root;
        }

        // Otherwise pass up whichever subtree reported a match (or null)
        return (left_lca != null) ? left_lca : right_lca;
    }

    // The main function that returns the distance between n1 and n2.
    // Returns -1 if either n1 or n2 is not present in the binary tree.
    public static int findDistance(Node root, int n1, int n2)
    {
        // Reset the shared state left over from any previous query
        d1 = -1;
        d2 = -1;
        dist = 0;
        Node lca = findDistUtil(root, n1, n2, 1);

        // Both keys were reached in different subtrees of the LCA
        if (d1 != -1 && d2 != -1)
        {
            return dist;
        }

        // Only n1 was reached: n2 is either below n1 or absent. Treat n1
        // as the root and look for n2 in its subtree
        if (d1 != -1)
        {
            dist = findLevel(lca, n2, 0);
            return dist;
        }

        // Only n2 was reached: same reasoning with the roles swapped
        if (d2 != -1)
        {
            dist = findLevel(lca, n1, 0);
            return dist;
        }
        return -1;
    }

    // Driver Code
    public static void main(String[] args)
    {
        // Let us create the binary tree given in the above example
        Node root = new Node(1);
        root.left = new Node(2);
        root.right = new Node(3);
        root.left.left = new Node(4);
        root.left.right = new Node(5);
        root.right.left = new Node(6);
        root.right.right = new Node(7);
        root.right.left.right = new Node(8);

        System.out.println("Dist(4, 5) = " + findDistance(root, 4, 5));
        System.out.println("Dist(4, 6) = " + findDistance(root, 4, 6));
        System.out.println("Dist(3, 4) = " + findDistance(root, 3, 4));
        System.out.println("Dist(2, 4) = " + findDistance(root, 2, 4));
        System.out.println("Dist(8, 5) = " + findDistance(root, 8, 5));
    }
}

// This code is contributed by gauravrajput1
Python3
# Python program to find the distance between two keys of a binary
# tree by comparing their root-to-node paths


class Node:
    """A binary tree node holding `data` and two child links."""

    def __init__(self, data):
        self.data = data
        self.right = None
        self.left = None


def pathToNode(root, path, k):
    """Append to `path` the keys on the way from `root` down to key `k`.

    Returns True if `k` was found (then `path` ends with `k`), otherwise
    False (then `path` is left as it was on entry).
    """
    # base case handling
    if root is None:
        return False

    # tentatively put this node on the path
    path.append(root.data)

    # see if k is this node's data
    if root.data == k:
        return True

    # check whether k is found in the left or right subtree
    # (a missing child is handled by the base case above)
    if pathToNode(root.left, path, k) or pathToNode(root.right, path, k):
        return True

    # k is not in the subtree rooted here: take this node off the path
    path.pop()
    return False


def distance(root, data1, data2):
    """Return the number of edges between the nodes holding data1 and data2.

    Returns -1 when either key is not present in the tree (this includes
    the empty tree).
    """
    # root-to-node paths for both keys
    path1 = []
    path2 = []
    if not pathToNode(root, path1, data1) or not pathToNode(root, path2, data2):
        # Without this check a missing key left its path empty and the
        # formula below returned the length of the other key's path.
        return -1

    # length of the common prefix of both paths (it ends at the LCA)
    i = 0
    while i < len(path1) and i < len(path2) and path1[i] == path2[i]:
        i = i + 1

    # nodes of both paths minus the shared prefix counted twice
    # equals the number of edges between the two nodes
    return len(path1) + len(path2) - 2 * i
# Driver code to test the above functions
root = Node(1)
root.left = Node(2)
root.right = Node(3)
root.left.left = Node(4)
root.right.right = Node(7)
root.right.left = Node(6)
root.left.right = Node(5)
root.right.left.right = Node(8)

dist = distance(root, 4, 5)
print("Distance between node {} & {}: {}".format(4, 5, dist))

dist = distance(root, 4, 6)
print("Distance between node {} & {}: {}".format(4, 6, dist))

dist = distance(root, 3, 4)
print("Distance between node {} & {}: {}".format(3, 4, dist))

dist = distance(root, 2, 4)
print("Distance between node {} & {}: {}".format(2, 4, dist))

dist = distance(root, 8, 5)
print("Distance between node {} & {}: {}".format(8, 5, dist))

# This program is contributed by Aartee
C#
// A C# program to find the distance between two keys
// n1 and n2 of a binary tree
using System;

class GFG
{
    // The values that the C++ version returns through reference
    // parameters are kept in static fields here.
    // findDistance() resets them at the start of every query.
    public static int d1 = -1;
    public static int d2 = -1;
    public static int dist = 0;

    // A Binary Tree Node
    public class Node
    {
        public Node left, right;
        public int key;

        // constructor
        public Node(int key)
        {
            this.key = key;
            left = null;
            right = null;
        }
    }

    // Returns the level of key k in the tree rooted at root, where root
    // itself is at the given level. Returns -1 if k is not present.
    public static int findLevel(Node root, int k, int level)
    {
        // Base Case
        if (root == null)
        {
            return -1;
        }

        // Key found at this node: report its level
        if (root.key == k)
        {
            return level;
        }

        // Search the left subtree first and fall back to the right one
        int l = findLevel(root.left, k, level + 1);
        return (l != -1) ? l : findLevel(root.right, k, level + 1);
    }

    // Returns the LCA of the two given values n1 and n2. It also sets
    // d1, d2 and dist if neither key is an ancestor of the other.
    // d1   --> level of n1 (left untouched if n1 is not reached)
    // d2   --> level of n2 (left untouched if n2 is not reached)
    // dist --> distance between n1 and n2
    // lvl  --> level of the current node
    public static Node findDistUtil(Node root, int n1, int n2, int lvl)
    {
        // Base case
        if (root == null)
        {
            return null;
        }

        // If either n1 or n2 matches root's key, report the presence by
        // returning root (if a key is an ancestor of the other, the
        // ancestor key becomes the LCA)
        if (root.key == n1)
        {
            d1 = lvl;
            return root;
        }
        if (root.key == n2)
        {
            d2 = lvl;
            return root;
        }

        // Look for n1 and n2 in the left and right subtrees
        Node left_lca = findDistUtil(root.left, n1, n2, lvl + 1);
        Node right_lca = findDistUtil(root.right, n1, n2, lvl + 1);

        // If both calls return non-null, one key is in one subtree and
        // the other key is in the other subtree, so this node is the LCA
        if (left_lca != null && right_lca != null)
        {
            dist = (d1 + d2) - 2 * lvl;
            return root;
        }

        // Otherwise pass up whichever subtree reported a match (or null)
        return (left_lca != null) ? left_lca : right_lca;
    }

    // The main function that returns the distance between n1 and n2.
    // Returns -1 if either n1 or n2 is not present in the binary tree.
    public static int findDistance(Node root, int n1, int n2)
    {
        // Reset the shared state left over from any previous query
        d1 = -1;
        d2 = -1;
        dist = 0;
        Node lca = findDistUtil(root, n1, n2, 1);

        // Both keys were reached in different subtrees of the LCA
        if (d1 != -1 && d2 != -1)
        {
            return dist;
        }

        // Only n1 was reached: n2 is either below n1 or absent. Treat n1
        // as the root and look for n2 in its subtree
        if (d1 != -1)
        {
            dist = findLevel(lca, n2, 0);
            return dist;
        }

        // Only n2 was reached: same reasoning with the roles swapped
        if (d2 != -1)
        {
            dist = findLevel(lca, n1, 0);
            return dist;
        }

        return -1;
    }

    // Driver Code
    public static void Main(string[] args)
    {
        // Let us create the binary tree given in the above example
        Node root = new Node(1);
        root.left = new Node(2);
        root.right = new Node(3);
        root.left.left = new Node(4);
        root.left.right = new Node(5);
        root.right.left = new Node(6);
        root.right.right = new Node(7);
        root.right.left.right = new Node(8);

        Console.WriteLine("Dist(4, 5) = " + findDistance(root, 4, 5));
        Console.WriteLine("Dist(4, 6) = " + findDistance(root, 4, 6));
        Console.WriteLine("Dist(3, 4) = " + findDistance(root, 3, 4));
        Console.WriteLine("Dist(2, 4) = " + findDistance(root, 2, 4));
        Console.WriteLine("Dist(8, 5) = " + findDistance(root, 8, 5));
    }
}

// This code is contributed by Shrikant13
Javascript
// A JavaScript program to find the distance between two keys
// n1 and n2 of a binary tree

// Module-level state shared between findDistUtil() and findDistance().
// findDistance() resets these at the start of every query.
let d1 = -1;
let d2 = -1;
let dist = 0;

// A Binary Tree Node
class Node
{
    constructor(key)
    {
        this.left = null;
        this.right = null;
        this.key = key;
    }
}

// Returns the level of key k in the tree rooted at root, where root
// itself is at the given level. Returns -1 if k is not present.
function findLevel(root, k, level)
{
    // Base Case
    if (root == null)
    {
        return -1;
    }

    // Key found at this node: report its level
    if (root.key == k)
    {
        return level;
    }

    // Search the left subtree first and fall back to the right one
    let l = findLevel(root.left, k, level + 1);
    return (l != -1) ? l : findLevel(root.right, k, level + 1);
}

// Returns the LCA of the two given values n1 and n2. It also sets
// d1, d2 and dist if neither key is an ancestor of the other.
// d1   --> level of n1 (left untouched if n1 is not reached)
// d2   --> level of n2 (left untouched if n2 is not reached)
// dist --> distance between n1 and n2
// lvl  --> level of the current node
function findDistUtil(root, n1, n2, lvl)
{
    // Base case
    if (root == null)
    {
        return null;
    }

    // If either n1 or n2 matches root's key, report the presence by
    // returning root (if a key is an ancestor of the other, the
    // ancestor key becomes the LCA)
    if (root.key == n1)
    {
        d1 = lvl;
        return root;
    }
    if (root.key == n2)
    {
        d2 = lvl;
        return root;
    }

    // Look for n1 and n2 in the left and right subtrees
    let left_lca = findDistUtil(root.left, n1, n2, lvl + 1);
    let right_lca = findDistUtil(root.right, n1, n2, lvl + 1);

    // If both calls return non-null, one key is in one subtree and
    // the other key is in the other subtree, so this node is the LCA
    if (left_lca != null && right_lca != null)
    {
        dist = (d1 + d2) - 2 * lvl;
        return root;
    }

    // Otherwise pass up whichever subtree reported a match (or null)
    return (left_lca != null) ? left_lca : right_lca;
}

// The main function that returns the distance between n1 and n2.
// Returns -1 if either n1 or n2 is not present in the binary tree.
function findDistance(root, n1, n2)
{
    // Reset the shared state left over from any previous query
    d1 = -1;
    d2 = -1;
    dist = 0;
    let lca = findDistUtil(root, n1, n2, 1);

    // Both keys were reached in different subtrees of the LCA
    if (d1 != -1 && d2 != -1)
    {
        return dist;
    }

    // Only n1 was reached: n2 is either below n1 or absent. Treat n1
    // as the root and look for n2 in its subtree
    if (d1 != -1)
    {
        dist = findLevel(lca, n2, 0);
        return dist;
    }

    // Only n2 was reached: same reasoning with the roles swapped
    if (d2 != -1)
    {
        dist = findLevel(lca, n1, 0);
        return dist;
    }
    return -1;
}
// Let us create the binary tree given in the above example
let root = new Node(1);
root.left = new Node(2);
root.right = new Node(3);
root.left.left = new Node(4);
root.left.right = new Node(5);
root.right.left = new Node(6);
root.right.right = new Node(7);
root.right.left.right = new Node(8);

// console.log already ends each message with a newline, so no HTML
// line-break tag is appended to the printed text.
console.log("Dist(4, 5) = " + findDistance(root, 4, 5));
console.log("Dist(4, 6) = " + findDistance(root, 4, 6));
console.log("Dist(3, 4) = " + findDistance(root, 3, 4));
console.log("Dist(2, 4) = " + findDistance(root, 2, 4));
console.log("Dist(8, 5) = " + findDistance(root, 8, 5));
Dist(4, 5) = 2 Dist(4, 6) = 4 Dist(3, 4) = 3 Dist(2, 4) = 1 Dist(8, 5) = 5
Time Complexity: O(n), As the method does a single tree traversal. Here n is the number of elements in the tree.
Auxiliary Space: O(h), Here h is the height of the tree and the extra space is used in recursion call stack.
Thanks to Atul Singh for providing the initial solution for this post.
Better Solution:
We first find the LCA of the two nodes. Then we find the distance from the LCA to each of the two nodes and add the two distances.
C++
/* C++ program to find the distance between n1 and n2 by first
   locating their LCA and then measuring the depth of each key
   below it */
#include <iostream>
using namespace std;

// A Binary Tree Node
struct Node {
    struct Node *left, *right;
    int key;
};

// Utility function to create a new tree Node holding the given key.
// The node is allocated with new and has no children.
Node* newNode(int key)
{
    Node* temp = new Node;
    temp->key = key;
    temp->left = temp->right = NULL;
    return temp;
}

// Returns the lowest common ancestor of keys n1 and n2 in the tree
// rooted at root. If only one of the keys is reachable, the node holding
// that key is returned; if neither is, NULL is returned.
Node* LCA(Node* root, int n1, int n2)
{
    if (root == NULL)
        return root;
    if (root->key == n1 || root->key == n2)
        return root;

    Node* left = LCA(root->left, n1, n2);
    Node* right = LCA(root->right, n1, n2);

    // One key on each side: this node is the LCA
    if (left != NULL && right != NULL)
        return root;

    // Reuse the results computed above. Calling LCA() again on the
    // subtree would repeat the whole search at every level and make the
    // running time exponential on a skewed tree.
    return (left != NULL) ? left : right;
}

// Returns the level of key k in the tree rooted at root, where root
// itself is at the given level. Returns -1 if k is not present.
int findLevel(Node* root, int k, int level)
{
    if (root == NULL)
        return -1;
    if (root->key == k)
        return level;

    int left = findLevel(root->left, k, level + 1);
    if (left == -1)
        return findLevel(root->right, k, level + 1);
    return left;
}

// Returns the number of edges between keys a and b, or -1 if either key
// is not present in the tree.
int findDistance(Node* root, int a, int b)
{
    Node* lca = LCA(root, a, b);

    int d1 = findLevel(lca, a, 0);
    int d2 = findLevel(lca, b, 0);

    // A missing key yields -1 from findLevel; do not add it into the sum
    if (d1 == -1 || d2 == -1)
        return -1;

    return d1 + d2;
}
// Driver program to test above functionsint main(){    // Let us create binary tree given in    // the above example    Node* root = newNode(1);    root->left = newNode(2);    root->right = newNode(3);    root->left->left = newNode(4);    root->left->right = newNode(5);    root->right->left = newNode(6);    root->right->right = newNode(7);    root->right->left->right = newNode(8);    cout << "Dist(4, 5) = " << findDistance(root, 4, 5);    cout << "\nDist(4, 6) = " << findDistance(root, 4, 6);    cout << "\nDist(3, 4) = " << findDistance(root, 3, 4);    cout << "\nDist(2, 4) = " << findDistance(root, 2, 4);    cout << "\nDist(8, 5) = " << findDistance(root, 8, 5);    return 0;} |
Java
/* Java program to find the distance between n1 and n2 by first
   locating their LCA and then measuring the depth of each key
   below it */
public class GFG {

    // A Binary Tree Node
    public static class Node {
        int value;
        Node left;
        Node right;

        public Node(int value) { this.value = value; }
    }

    // Returns the lowest common ancestor of keys n1 and n2 in the tree
    // rooted at root. If only one of the keys is reachable, the node
    // holding that key is returned; if neither is, null is returned.
    public static Node LCA(Node root, int n1, int n2)
    {
        if (root == null)
            return root;
        if (root.value == n1 || root.value == n2)
            return root;

        Node left = LCA(root.left, n1, n2);
        Node right = LCA(root.right, n1, n2);

        // One key on each side: this node is the LCA
        if (left != null && right != null)
            return root;

        // Otherwise pass up whichever side reported a match (or null)
        return (left != null) ? left : right;
    }

    // Returns the level of key a in the tree rooted at root, where root
    // itself is at the given level. Returns -1 if a is not present.
    public static int findLevel(Node root, int a, int level)
    {
        if (root == null)
            return -1;
        if (root.value == a)
            return level;
        int left = findLevel(root.left, a, level + 1);
        if (left == -1)
            return findLevel(root.right, a, level + 1);
        return left;
    }

    // Returns the number of edges between keys a and b, or -1 if either
    // key is not present in the tree.
    public static int findDistance(Node root, int a, int b)
    {
        Node lca = LCA(root, a, b);

        int d1 = findLevel(lca, a, 0);
        int d2 = findLevel(lca, b, 0);

        // A missing key yields -1 from findLevel; do not add it into the sum
        if (d1 == -1 || d2 == -1)
            return -1;

        return d1 + d2;
    }

    // Driver program to test the above functions
    public static void main(String[] args)
    {
        // Let us create the binary tree given in the above example
        Node root = new Node(1);
        root.left = new Node(2);
        root.right = new Node(3);
        root.left.left = new Node(4);
        root.left.right = new Node(5);
        root.right.left = new Node(6);
        root.right.right = new Node(7);
        root.right.left.right = new Node(8);

        System.out.println("Dist(4, 5) = " + findDistance(root, 4, 5));
        System.out.println("Dist(4, 6) = " + findDistance(root, 4, 6));
        System.out.println("Dist(3, 4) = " + findDistance(root, 3, 4));
        System.out.println("Dist(2, 4) = " + findDistance(root, 2, 4));
        System.out.println("Dist(8, 5) = " + findDistance(root, 8, 5));
    }
}

// This code is contributed by Srinivasan Jayaraman.
Python3
"""A Python program to find the distance between n1 and n2 in a binary tree."""


# binary tree node
class Node:
    # Constructor to create a new node
    def __init__(self, data):
        self.data = data
        self.left = self.right = None


# This function returns the LCA of the two given values n1 and n2.
# If only one of the keys is reachable, the node holding that key is
# returned; if neither is, None is returned.
def find_least_common_ancestor(root: Node, n1: int, n2: int) -> Node:

    # Base case
    if root is None:
        return root

    # If either n1 or n2 matches root's key, report the presence
    # by returning root
    if root.data == n1 or root.data == n2:
        return root

    # Look for the keys in the left and right subtrees
    left = find_least_common_ancestor(root.left, n1, n2)
    right = find_least_common_ancestor(root.right, n1, n2)

    # One key on each side: this node is the LCA
    if left and right:
        return root

    # Otherwise pass up whichever side reported a match (or None)
    if left:
        return left
    else:
        return right


# Function to find the distance (in edges) from `root` down to the node
# holding `data`; returns -1 if `data` is not in that subtree.
def find_distance_from_ancestor_node(root: Node, data: int) -> int:

    # case when we step past a leaf or when the tree is empty
    if root is None:
        return -1

    # node is found: it is 0 edges away from itself
    if root.data == data:
        return 0

    left = find_distance_from_ancestor_node(root.left, data)
    right = find_distance_from_ancestor_node(root.right, data)

    # a side that does not contain `data` reports -1, so max() picks the
    # side that found it (if any)
    distance = max(left, right)
    return distance + 1 if distance >= 0 else -1


# Function to find the distance between two nodes in a binary tree.
# Returns -1 if either key is not present.
def find_distance_between_two_nodes(root: Node, n1: int, n2: int):

    lca = find_least_common_ancestor(root, n1, n2)
    if not lca:
        return -1

    d1 = find_distance_from_ancestor_node(lca, n1)
    d2 = find_distance_from_ancestor_node(lca, n2)

    # a missing key yields -1; do not add it into the sum
    if d1 < 0 or d2 < 0:
        return -1
    return d1 + d2
Â
# Driver program to test the above functions
root = Node(1)
root.left = Node(2)
root.right = Node(3)
root.left.left = Node(4)
root.left.right = Node(5)
root.right.left = Node(6)
root.right.right = Node(7)
root.right.left.right = Node(8)

print("Dist(4,5) = ", find_distance_between_two_nodes(root, 4, 5))
print("Dist(4,6) = ", find_distance_between_two_nodes(root, 4, 6))
print("Dist(3,4) = ", find_distance_between_two_nodes(root, 3, 4))
print("Dist(2,4) = ", find_distance_between_two_nodes(root, 2, 4))
print("Dist(8,5) = ", find_distance_between_two_nodes(root, 8, 5))

# This article is contributed by Shweta Singh.
# This article is improved by Sreeramachandra
C#
using System;

/* C# program to find the distance between n1 and n2 by first
   locating their LCA and then measuring the depth of each key
   below it */
public class GFG {

    // A Binary Tree Node
    public class Node {
        public int value;
        public Node left;
        public Node right;

        public Node(int value) { this.value = value; }
    }

    // Returns the lowest common ancestor of keys n1 and n2 in the tree
    // rooted at root. If only one of the keys is reachable, the node
    // holding that key is returned; if neither is, null is returned.
    public static Node LCA(Node root, int n1, int n2)
    {
        if (root == null) {
            return root;
        }
        if (root.value == n1 || root.value == n2) {
            return root;
        }

        Node left = LCA(root.left, n1, n2);
        Node right = LCA(root.right, n1, n2);

        // One key on each side: this node is the LCA
        if (left != null && right != null) {
            return root;
        }

        // Reuse the results computed above. Calling LCA() again on the
        // subtree would repeat the whole search at every level and make
        // the running time exponential on a skewed tree.
        return (left != null) ? left : right;
    }

    // Returns the level of key a in the tree rooted at root, where root
    // itself is at the given level. Returns -1 if a is not present.
    public static int findLevel(Node root, int a, int level)
    {
        if (root == null) {
            return -1;
        }
        if (root.value == a) {
            return level;
        }
        int left = findLevel(root.left, a, level + 1);
        if (left == -1) {
            return findLevel(root.right, a, level + 1);
        }
        return left;
    }

    // Returns the number of edges between keys a and b, or -1 if either
    // key is not present in the tree.
    public static int findDistance(Node root, int a, int b)
    {
        Node lca = LCA(root, a, b);

        int d1 = findLevel(lca, a, 0);
        int d2 = findLevel(lca, b, 0);

        // A missing key yields -1 from findLevel; do not add it into the sum
        if (d1 == -1 || d2 == -1) {
            return -1;
        }

        return d1 + d2;
    }

    // Driver program to test the above functions
    public static void Main(string[] args)
    {
        // Let us create the binary tree given in the above example
        Node root = new Node(1);
        root.left = new Node(2);
        root.right = new Node(3);
        root.left.left = new Node(4);
        root.left.right = new Node(5);
        root.right.left = new Node(6);
        root.right.right = new Node(7);
        root.right.left.right = new Node(8);

        Console.WriteLine("Dist(4, 5) = " + findDistance(root, 4, 5));
        Console.WriteLine("Dist(4, 6) = " + findDistance(root, 4, 6));
        Console.WriteLine("Dist(3, 4) = " + findDistance(root, 3, 4));
        Console.WriteLine("Dist(2, 4) = " + findDistance(root, 2, 4));
        Console.WriteLine("Dist(8, 5) = " + findDistance(root, 8, 5));
    }
}

// This code is contributed by Shrikant13
Javascript
<script>Â
// JavaScript program to find the distance between n1 and n2 by first
// locating their LCA and then measuring the depth of each key below it

// A Binary Tree Node
class Node {
    constructor(value)
    {
        this.value = value;
        this.left = null;
        this.right = null;
    }
}

// Returns the lowest common ancestor of keys n1 and n2 in the tree
// rooted at root. If only one of the keys is reachable, the node
// holding that key is returned; if neither is, null is returned.
function LCA(root, n1, n2)
{
    if (root == null)
    {
        return root;
    }
    if (root.value == n1 || root.value == n2)
    {
        return root;
    }

    var left = LCA(root.left, n1, n2);
    var right = LCA(root.right, n1, n2);

    // One key on each side: this node is the LCA
    if (left != null && right != null)
    {
        return root;
    }

    // Reuse the results computed above. Calling LCA() again on the
    // subtree would repeat the whole search at every level and make
    // the running time exponential on a skewed tree.
    return (left != null) ? left : right;
}

// Returns the level of key a in the tree rooted at root, where root
// itself is at the given level. Returns -1 if a is not present.
function findLevel(root, a, level)
{
    if (root == null)
    {
        return -1;
    }
    if (root.value == a)
    {
        return level;
    }
    var left = findLevel(root.left, a, level + 1);

    if (left == -1)
    {
        return findLevel(root.right, a, level + 1);
    }
    return left;
}

// Returns the number of edges between keys a and b, or -1 if either
// key is not present in the tree.
function findDistance(root, a, b)
{
    var lca = LCA(root, a, b);
    var d1 = findLevel(lca, a, 0);
    var d2 = findLevel(lca, b, 0);

    // A missing key yields -1 from findLevel; do not add it into the sum
    if (d1 == -1 || d2 == -1)
    {
        return -1;
    }

    return d1 + d2;
}
// Driver code

// Let us create the binary tree given in the above example
var root = new Node(1);
root.left = new Node(2);
root.right = new Node(3);
root.left.left = new Node(4);
root.left.right = new Node(5);
root.right.left = new Node(6);
root.right.right = new Node(7);
root.right.left.right = new Node(8);

document.write("Dist(4, 5) = " + findDistance(root, 4, 5) + "<br>");
document.write("Dist(4, 6) = " + findDistance(root, 4, 6) + "<br>");
document.write("Dist(3, 4) = " + findDistance(root, 3, 4) + "<br>");
document.write("Dist(2, 4) = " + findDistance(root, 2, 4) + "<br>");
document.write("Dist(8, 5) = " + findDistance(root, 8, 5) + "<br>");

// This code is contributed by rdtank
</script> |
Dist(4, 5) = 2 Dist(4, 6) = 4 Dist(3, 4) = 3 Dist(2, 4) = 1 Dist(8, 5) = 5
Time Complexity: O(n), As the method does a single tree traversal. Here n is the number of elements in the tree.
Auxiliary Space: O(h), Here h is the height of the tree and the extra space is used in recursion call stack.
Thanks to NILMADHAB MONDAL for suggesting this solution.
Another Better Solution (one pass):
We know that the distance between two nodes n1 and n2 = (distance between the LCA and n1) + (distance between the LCA and n2).
A general solution using the above formula that may come to mind is:
// Naive one-pass attempt, shown to illustrate the pitfall discussed below.
// A node holding n1 or n2 reports 1 immediately, without looking at its own
// subtree; a node with hits on both sides reports their sum; a node with a
// hit on one side reports that value plus one; everything else reports 0.
int findDistance(Node* root, int n1, int n2) {
    if (!root) {
        return 0;
    }

    const bool isTarget = (root->data == n1) || (root->data == n2);
    if (isTarget) {
        return 1;
    }

    const int fromLeft = findDistance(root->left, n1, n2);
    const int fromRight = findDistance(root->right, n1, n2);

    if (fromLeft != 0 && fromRight != 0) {
        return fromLeft + fromRight;
    }
    if (fromLeft == 0 && fromRight == 0) {
        return 0;
    }
    return max(fromLeft, fromRight) + 1;
}
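But this solution has a flaw (a missing edge case) when n2 is a descendant of n1 or n1 is a descendant of n2: the recursion returns 1 as soon as it reaches the ancestor key and never visits the descendant below it.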
Below is dry run of above code with edge case example :
In the above binary tree expected output is 2 but the function will give output as 3. This situation is overcome in the solution code given below :
Note : both n1 and n2 should be present in Binary Tree.
C++
/* C++ Program to find distance between n1 and n2
   using one traversal */
#include <algorithm>
#include <iostream>
using namespace std;

// A Binary Tree Node
struct Node {
    struct Node *left, *right;
    int key;
};

// Utility function to create a new tree Node.
// The node is allocated with `new`; this code never frees it.
Node* newNode(int key)
{
    Node* temp = new Node;
    temp->key = key;
    temp->left = temp->right = nullptr;
    return temp;
}

// Global variable to store distance between n1 and n2.
// Written by _findDistance(), reset and read by findDistance().
int ans;

// Post-order helper for findDistance().
//
// Returns 0 when the subtree rooted at `root` contains neither key, or when
// this call has just stored the final distance in `ans`. Otherwise a key was
// found in this subtree and the return value is its distance to `root` plus
// one, i.e. its distance as seen from root's parent.
// (Assumes keys are unique in the tree.)
int _findDistance(Node* root, int n1, int n2)
{
    if (!root)
        return 0;

    int left = _findDistance(root->left, n1, n2);
    int right = _findDistance(root->right, n1, n2);

    // Current node is one of the two keys.
    if (root->key == n1 || root->key == n2) {
        if (left || right) {
            // The other key is a descendant: the non-zero side already
            // equals the number of edges from this node down to it.
            ans = std::max(left, right);
            return 0;
        }
        // First sighting of a key: one edge up to the parent.
        return 1;
    }

    // One key on each side: the current node is the LCA.
    if (left && right) {
        ans = left + right;
        return 0;
    }

    // A key on one side only: pass its distance upwards, one edge longer.
    if (left || right)
        return std::max(left, right) + 1;

    // Neither key below this node.
    return 0;
}

// Returns the number of edges between the nodes holding n1 and n2.
// The result is 0 when the tree is empty, when n1 == n2, or when the two
// keys are never both found.
int findDistance(Node* root, int n1, int n2)
{
    ans = 0;
    _findDistance(root, n1, n2);
    return ans;
}

// Driver program to test above functionsint main(){    // Let us create binary tree given in    // the above example    Node* root = newNode(1);    root->left = newNode(2);    root->right = newNode(3);    root->left->left = newNode(4);    root->left->right = newNode(5);    root->right->left = newNode(6);    root->right->right = newNode(7);    root->right->left->right = newNode(8);    cout << "Dist(4, 5) = " << findDistance(root, 4, 5);    cout << "\nDist(4, 6) = " << findDistance(root, 4, 6);    cout << "\nDist(3, 4) = " << findDistance(root, 3, 4);    cout << "\nDist(2, 4) = " << findDistance(root, 2, 4);    cout << "\nDist(8, 5) = " << findDistance(root, 8, 5);    return 0;} |
Java
/* Java Program to find distance between n1 and n2
   using one traversal */
public class Main {

    /** Binary tree node: a value plus left and right child links. */
    public static class Node {
        int value;
        Node left;
        Node right;

        public Node(int value) { this.value = value; }
    }

    // Distance between n1 and n2; written by _findDistance,
    // reset and read by findDistance.
    public static int ans;

    /**
     * Post-order helper for {@link #findDistance}.
     *
     * @return 0 when neither key is in this subtree or when this call has
     *         just stored the result in {@code ans}; otherwise the distance
     *         from the key found in this subtree to {@code root}, plus one
     */
    public static int _findDistance(Node root, int n1, int n2)
    {
        if (root == null) {
            return 0;
        }

        final int fromLeft = _findDistance(root.left, n1, n2);
        final int fromRight = _findDistance(root.right, n1, n2);
        final boolean keyBelow = fromLeft != 0 || fromRight != 0;

        // Current node is one of the two keys.
        if (root.value == n1 || root.value == n2) {
            if (!keyBelow) {
                return 1;
            }
            // The other key is a descendant; the non-zero side is the answer.
            ans = Math.max(fromLeft, fromRight);
            return 0;
        }

        // One key on each side: this node is the LCA.
        if (fromLeft != 0 && fromRight != 0) {
            ans = fromLeft + fromRight;
            return 0;
        }

        // A key on one side only is passed upwards one edge longer;
        // otherwise there is nothing to report.
        return keyBelow ? Math.max(fromLeft, fromRight) + 1 : 0;
    }

    /** Returns the number of edges between the nodes holding n1 and n2. */
    public static int findDistance(Node root, int n1, int n2)
    {
        ans = 0;
        _findDistance(root, n1, n2);
        return ans;
    }

    // Driver program to test above functions
    public static void main(String[] args)
    {
        // Build the sample tree:
        //          1
        //        /   \
        //       2     3
        //      / \   / \
        //     4   5 6   7
        //            \
        //             8
        Node root = new Node(1);
        Node two = root.left = new Node(2);
        Node three = root.right = new Node(3);
        two.left = new Node(4);
        two.right = new Node(5);
        Node six = three.left = new Node(6);
        three.right = new Node(7);
        six.right = new Node(8);

        final String[] labels = {
            "Dist(4, 5) = ",
            "Dist(4, 6) = ",
            "Dist(3, 4) = ",
            "Dist(2, 4) = ",
            "Dist(8, 5) = ",
        };
        final int[][] pairs = { { 4, 5 }, { 4, 6 }, { 3, 4 }, { 2, 4 }, { 8, 5 } };

        for (int i = 0; i < pairs.length; i++) {
            System.out.println(labels[i]
                               + findDistance(root, pairs[i][0], pairs[i][1]));
        }
    }
}
Python3
# Python Program to find distance between n1 and n2 using one traversalÂ
# A Binary Tree Node
class Node:
    """Binary tree node: a key plus left and right child links."""

    def __init__(self, key):
        self.key = key
        # A new node starts without children.
        self.left = self.right = None

# Global variable to store distance between n1 and n2.
# Written by _findDistance, reset and read by findDistance.
ans = 0


# Function that finds distance between two node.
def _findDistance(root, n1, n2):
    """Post-order helper for findDistance.

    Returns 0 when neither key is in the subtree rooted at ``root`` or when
    this call has just stored the result in the global ``ans``. Otherwise
    returns the distance from the key found in this subtree to ``root``,
    plus one.
    """
    global ans
    if not root:
        return 0
    left = _findDistance(root.left, n1, n2)
    right = _findDistance(root.right, n1, n2)

    # if any node(n1 or n2) is found
    if root.key == n1 or root.key == n2:
        # If the other key is a descendant, the non-zero side is already
        # the number of edges from this node down to it.
        if left or right:
            ans = max(left, right)
            return 0
        return 1

    # if current root is LCA of n1 and n2.
    if left and right:
        ans = left + right
        return 0

    # if there is a descendant(n1 or n2), increment its distance
    if left or right:
        return max(left, right) + 1

    # if neither n1 nor n2 exist as descendant.
    return 0


# The main function that returns distance between n1 and n2.
def findDistance(root, n1, n2):
    """Return the number of edges between the nodes holding n1 and n2.

    Returns 0 when the tree is empty, when n1 == n2, or when the two keys
    are never both found.
    """
    global ans
    # Reset first: _findDistance only assigns ans when both keys are found,
    # so without this a previous query's result would be returned.
    ans = 0
    _findDistance(root, n1, n2)
    return ans

# Driver program to test above functions
if __name__ == '__main__':

    # Build the sample tree:
    #          1
    #        /   \
    #       2     3
    #      / \   / \
    #     4   5 6   7
    #            \
    #             8
    root = Node(1)
    root.left, root.right = Node(2), Node(3)
    root.left.left, root.left.right = Node(4), Node(5)
    root.right.left, root.right.right = Node(6), Node(7)
    root.right.left.right = Node(8)

    print("Dist(4, 5) =", findDistance(root, n1=4, n2=5))
    print("Dist(4, 6) =", findDistance(root, n1=4, n2=6))
    print("Dist(3, 4) =", findDistance(root, n1=3, n2=4))
    print("Dist(2, 4) =", findDistance(root, n1=2, n2=4))
    print("Dist(8, 5) =", findDistance(root, n1=8, n2=5))

# This code is contributed by Tapesh(tapeshdua420) |
C#
/* C# Program to find distance between n1 and n2Â Â Â using one traversal */using System;Â
class Program {

    // Binary tree node: a value plus left and right child links.
    public class Node {
        public int value;
        public Node left;
        public Node right;

        public Node(int value) { this.value = value; }
    }

    // Distance between n1 and n2; written by _findDistance,
    // reset and read by findDistance.
    public static int ans = 0;

    // Post-order helper for findDistance.
    // Returns 0 when neither key is in this subtree or when this call has
    // just stored the result in ans; otherwise the distance from the key
    // found in this subtree to root, plus one.
    public static int _findDistance(Node root, int n1,
                                    int n2)
    {
        if (root == null) {
            return 0;
        }

        int fromLeft = _findDistance(root.left, n1, n2);
        int fromRight = _findDistance(root.right, n1, n2);
        bool keyBelow = fromLeft != 0 || fromRight != 0;

        // Current node is one of the two keys.
        if (root.value == n1 || root.value == n2) {
            if (!keyBelow) {
                return 1;
            }
            // The other key is a descendant; the non-zero side is the
            // answer.
            ans = Math.Max(fromLeft, fromRight);
            return 0;
        }

        // One key on each side: this node is the LCA.
        if (fromLeft != 0 && fromRight != 0) {
            ans = fromLeft + fromRight;
            return 0;
        }

        // A key on one side only is passed upwards one edge longer;
        // otherwise there is nothing to report.
        return keyBelow ? Math.Max(fromLeft, fromRight) + 1 : 0;
    }

    // Returns the number of edges between the nodes holding n1 and n2.
    public static int findDistance(Node root, int n1,
                                   int n2)
    {
        ans = 0;
        _findDistance(root, n1, n2);
        return ans;
    }

    // Driver program to test above functions
    public static void Main(string[] args)
    {
        // Build the sample tree:
        //          1
        //        /   \
        //       2     3
        //      / \   / \
        //     4   5 6   7
        //            \
        //             8
        Node root = new Node(1);
        Node two = root.left = new Node(2);
        Node three = root.right = new Node(3);
        two.left = new Node(4);
        two.right = new Node(5);
        Node six = three.left = new Node(6);
        three.right = new Node(7);
        six.right = new Node(8);

        int[,] queries = { { 4, 5 }, { 4, 6 }, { 3, 4 }, { 2, 4 }, { 8, 5 } };
        for (int i = 0; i < queries.GetLength(0); i++) {
            int a = queries[i, 0];
            int b = queries[i, 1];
            Console.WriteLine("Dist({0}, {1}) = {2}", a, b,
                              findDistance(root, a, b));
        }
    }
}

// This code is contributed by Tapesh(tapeshdua420) |
Javascript
<script>class Node {constructor(key) {this.key = key;this.left = null;this.right = null;}}Â
// Global variable to store distance between n1 and n2.
// Written by _findDistance, reset and read by findDistance.
let ans = 0;

// Post-order helper for findDistance.
// Returns 0 when neither key is in the subtree rooted at `root` or when this
// call has just stored the result in `ans`; otherwise the distance from the
// key found in this subtree to `root`, plus one.
function _findDistance(root, n1, n2) {
    if (!root) {
        return 0;
    }
    let left = _findDistance(root.left, n1, n2);
    let right = _findDistance(root.right, n1, n2);

    // if any node(n1 or n2) is found
    if (root.key === n1 || root.key === n2) {
        // If the other key is a descendant, the non-zero side is already
        // the number of edges from this node down to it.
        if (left || right) {
            ans = Math.max(left, right);
            return 0;
        }
        return 1;
    }

    // if current root is LCA of n1 and n2.
    if (left && right) {
        ans = left + right;
        return 0;
    }

    // if there is a descendant(n1 or n2), increment its distance
    if (left || right) {
        return Math.max(left, right) + 1;
    }

    // if neither n1 nor n2 exist as descendant.
    return 0;
}

// The main function that returns distance between n1 and n2.
// Returns 0 when the tree is empty, when n1 === n2, or when the two keys are
// never both found.
function findDistance(root, n1, n2) {
    // Reset first: _findDistance only assigns ans when both keys are found,
    // so without this a previous query's result would be returned.
    ans = 0;
    _findDistance(root, n1, n2);
    return ans;
}

// Driver program let root = new Node(1);root.left = new Node(2);root.right = new Node(3);root.left.left = new Node(4);root.left.right = new Node(5);root.right.left = new Node(6);root.right.right = new Node(7);root.right.left.right = new Node(8);document.write("Dist(4, 5) = ", findDistance(root, 4, 5));document.write("Dist(4, 6) = ", findDistance(root, 4, 6));document.write("Dist(3, 4) = ", findDistance(root, 3, 4));document.write("Dist(2, 4) = ", findDistance(root, 2, 4));document.write("Dist(8, 5) = ", findDistance(root, 8, 5));//This code is contributed by Potta Lokesh</script> |
Dist(4, 5) = 2 Dist(4, 6) = 4 Dist(3, 4) = 3 Dist(2, 4) = 1 Dist(8, 5) = 5
Time Complexity: O(n), where n is the number of nodes in the binary tree.
Auxiliary Space: O(h), where h is the height of the binary tree.
Thanks to Gurudev Singh for suggesting this solution.
Please write comments if you find anything incorrect, or you want to share more information about the topic discussed above
Â
Ready to dive in? Explore our Free Demo Content and join our DSA course, trusted by over 100,000 zambiatek!




