#pragma once

#include "binder/expression/node_expression.h"
#include "logical_operator_visitor.h"
#include "planner/logical_plan/logical_plan.h"

namespace kuzu {
namespace optimizer {

class FilterPushDownOptimizer {
public:
    FilterPushDownOptimizer() { predicateSet = std::make_unique<PredicateSet>(); }

    void rewrite(planner::LogicalPlan* plan);

private:
    std::shared_ptr<planner::LogicalOperator> visitOperator(
        std::shared_ptr<planner::LogicalOperator> op);
    // Collect predicates in FILTER
    std::shared_ptr<planner::LogicalOperator> visitFilterReplace(
        std::shared_ptr<planner::LogicalOperator> op);
    // Push primary key lookup into CROSS_PRODUCT
    // E.g.
    //      Filter(a.ID=b.ID)
    //      CrossProduct                   to           IndexNestedLoopJoin(b)
    //   S(a)           S(b)                            S(a)
    // This is a temporary solution in the absence of a generic hash join operator.
    std::shared_ptr<planner::LogicalOperator> visitCrossProductReplace(
        std::shared_ptr<planner::LogicalOperator> op);

    // Push FILTER before SCAN_NODE_PROPERTY.
    // Push index lookup into SCAN_NODE_ID.
    std::shared_ptr<planner::LogicalOperator> visitScanNodePropertyReplace(
        std::shared_ptr<planner::LogicalOperator> op);

    // Rewrite SCAN_NODE_ID->SCAN_NODE_PROPERTY->FILTER as
    // SCAN_NODE_ID->(SCAN_NODE_PROPERTY->FILTER)*->SCAN_NODE_PROPERTY
    // so that filter with higher selectivity is applied before scanning.
    std::shared_ptr<planner::LogicalOperator> pushDownToScanNode(
        std::shared_ptr<binder::NodeExpression> node, std::shared_ptr<binder::Expression> predicate,
        std::shared_ptr<planner::LogicalOperator> child);

    // Finish the current push down optimization by apply remaining predicates as a single filter.
    // And heuristically reorder equality predicates first in the filter.
    std::shared_ptr<planner::LogicalOperator> finishPushDown(
        std::shared_ptr<planner::LogicalOperator> op);

    std::shared_ptr<planner::LogicalOperator> appendScanNodeProperty(
        std::shared_ptr<binder::NodeExpression> node, binder::expression_vector properties,
        std::shared_ptr<planner::LogicalOperator> child);
    std::shared_ptr<planner::LogicalOperator> appendFilter(
        std::shared_ptr<binder::Expression> predicate,
        std::shared_ptr<planner::LogicalOperator> child);

    struct PredicateSet {
        binder::expression_vector equalityPredicates;
        binder::expression_vector nonEqualityPredicates;

        inline bool isEmpty() const {
            return equalityPredicates.empty() && nonEqualityPredicates.empty();
        }
        inline void clear() {
            equalityPredicates.clear();
            nonEqualityPredicates.clear();
        }

        void addPredicate(std::shared_ptr<binder::Expression> predicate);
        std::shared_ptr<binder::Expression> popNodePKEqualityComparison(
            const binder::NodeExpression& node);
    };

private:
    std::unique_ptr<PredicateSet> predicateSet;
};

} // namespace optimizer
} // namespace kuzu
