二叉树的最低公共祖先(路径回溯法)

LeetCode 236 题解笔记

Posted by BY on April 26, 2026

原文先后给了两版”路径栈 + 比较”的写法。这里把题意、概念和两版差异说明一下,代码保留原样以便对照。

题目背景

LeetCode 236「二叉树的最低公共祖先」。给定一棵普通(非 BST)二叉树和两个节点 pq,求深度最大的同时是 pq 祖先的节点。所有节点值唯一,且 pq 一定都存在于树中。

概念解释

  • 最低公共祖先 (LCA):在所有同时为 pq 祖先的节点中,深度最大的那一个。任何节点也是它自己的祖先。
  • 路径回溯:DFS 过程中维护一条”从根到当前节点”的路径栈,进入子树时入栈,离开时出栈,正好对应递归的入栈/退栈过程。
  • 回溯剪枝(flag 优化):找到目标后立刻打个标记,让上层 DFS 不再继续扫描兄弟子树,节省无用搜索。

实现原理

总体思路是先分别求出 根→p根→q 两条路径,再在两条路径上找最深的公共节点。

第一版:

  1. search_tree 用 DFS 在路径栈 s 中维护当前路径;命中 target 时把整个 s 拷给 ret,作为”根到 target 的路径快照”。
  2. 拿到两条路径栈 ret1ret2 后,先把较深的一条 pop 到与另一条等高,再同步 pop 比对栈顶,遇到值相等的节点即为 LCA。

第二版(优化):

  • flag 表示”是否已经找到 target”。一旦命中,沿递归回溯时不再 s.pop(),让路径自然保留;同时其它分支的递归继续返回但不影响栈。
  • 这样省掉了拷贝 ret = s 的开销,并避免在已找到目标后还遍历无关子树。

时间复杂度 O(n),每个节点最多被访问常数次;空间复杂度 O(h),由路径栈与递归栈决定。

备注:标准答案多用「同时找 p、q 的递归一遍解」即可 O(n) 完成本题;这里保留原始的”先求两条路径再比较”写法,便于复盘当时的思考过程。

参考实现

/**
 * Definition for a binary tree node.
 * struct TreeNode {
 *     int val;
 *     TreeNode *left;
 *     TreeNode *right;
 *     TreeNode(int x) : val(x), left(NULL), right(NULL) {}
 * };
 */
class Solution {

private:
    // std::stack<TreeNode*> ret1;
    //回溯法
    //s 用来dfs遍历tree
    //ret 用来保存最终的结果,
    void search_tree(TreeNode* root, TreeNode* target, std::stack<TreeNode*>& s,
        std::stack<TreeNode*>& ret)
    {
        if(root == nullptr)
        {
            return;
        }

        s.emplace(root);

        if(root == target)
        {
            // ret1 = s;
            ret = s;
            return;
        }else
        {
            if(root->left != nullptr)
            {
                search_tree(root->left, target, s, ret);
            }

            if(root->right != nullptr)
            {
                search_tree(root->right, target, s, ret);
            }
            s.pop(); //回溯
        }

    }


public:
    TreeNode* lowestCommonAncestor(TreeNode* root, TreeNode* p, TreeNode* q) {

        std::stack<TreeNode*> s1;
        std::stack<TreeNode*> s2;

        std::stack<TreeNode*> ret1;
        std::stack<TreeNode*> ret2;

        search_tree(root, p, s1, ret1);
        search_tree(root, q, s2, ret2);

        while(!ret1.empty() && !ret2.empty())
        {
            // TreeNode* t = ret1.top();
            // std::cout << t->val << std::endl;
            // ret1.pop();

            if(ret1.size() > ret2.size())
            {
                ret1.pop();

            }else if(ret1.size() == ret2.size())
            {
                TreeNode* t1 = ret1.top();
                TreeNode* t2 = ret2.top();

                if(t1->val == t2->val)
                {
                    return t1;
                }else
                {
                    ret1.pop();
                    ret2.pop();
                }

            }else
            {
                ret2.pop();
            }

        }
        return nullptr;
    }
};

提交代码后,发现效果不是很好,需要进一步优化,想到使用flag,用来判断是否找到目标点,
如果,找到的话,则终止回溯过程,可以节省一些无用操作,优化如下:

/**
 * Definition for a binary tree node.
 * struct TreeNode {
 *     int val;
 *     TreeNode *left;
 *     TreeNode *right;
 *     TreeNode(int x) : val(x), left(NULL), right(NULL) {}
 * };
 */
class Solution {

private:
    // std::stack<TreeNode*> ret1;
    //回溯法
    //s 用来dfs遍历tree
    //ret 用来保存最终的结果,
    void search_tree(TreeNode* root, TreeNode* target, std::stack<TreeNode*>& s,
        bool& flag)
    {
        if(root == nullptr)
        {
            return;
        }

        s.emplace(root);

        if(root == target)
        {
            // ret1 = s;
            // ret = s;
            flag = true;
            return;
        }else
        {
            if(root->left != nullptr)
            {
                search_tree(root->left, target, s, flag);
            }

            if(root->right != nullptr)
            {
                search_tree(root->right, target, s, flag);
            }
            if(!flag)
            {
                s.pop(); //回溯
            }
        }

    }


public:
    TreeNode* lowestCommonAncestor(TreeNode* root, TreeNode* p, TreeNode* q) {

        std::stack<TreeNode*> s1;
        std::stack<TreeNode*> s2;

        // std::stack<TreeNode*> ret1;
        // std::stack<TreeNode*> ret2;
        bool flag1 = false;
        bool flag2 = false;

        search_tree(root, p, s1, flag1);
        search_tree(root, q, s2, flag2);

        while(!s1.empty() && !s2.empty())
        {
            // TreeNode* t = ret1.top();
            // std::cout << t->val << std::endl;
            // ret1.pop();

            if(s1.size() > s2.size())
            {
                s1.pop();

            }else if(s1.size() == s2.size())
            {
                TreeNode* t1 = s1.top();
                TreeNode* t2 = s2.top();

                if(t1->val == t2->val)
                {
                    return t1;
                }else
                {
                    s1.pop();
                    s2.pop();
                }

            }else
            {
                s2.pop();
            }

        }
        return nullptr;
    }
};