/*
 * Decompiled with CFR 0.152.
 */
package org.jkiss.dbeaver.model.stm;

import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import java.util.stream.StreamSupport;
import org.antlr.v4.runtime.atn.ATN;
import org.antlr.v4.runtime.atn.ATNState;
import org.antlr.v4.runtime.atn.AtomTransition;
import org.antlr.v4.runtime.atn.RangeTransition;
import org.antlr.v4.runtime.atn.SetTransition;
import org.antlr.v4.runtime.atn.Transition;
import org.antlr.v4.runtime.misc.IntSet;
import org.antlr.v4.runtime.misc.Interval;
import org.antlr.v4.runtime.misc.IntervalSet;
import org.antlr.v4.runtime.tree.RuleNode;
import org.antlr.v4.runtime.tree.TerminalNode;
import org.jkiss.code.NotNull;
import org.jkiss.code.Nullable;
import org.jkiss.dbeaver.model.impl.sql.BasicSQLDialect;
import org.jkiss.dbeaver.model.lsm.sql.impl.syntax.SQLStandardParser;
import org.jkiss.dbeaver.model.sql.SQLDialect;
import org.jkiss.dbeaver.model.stm.STMTreeNode;
import org.jkiss.dbeaver.model.stm.STMTreeTermErrorNode;
import org.jkiss.dbeaver.model.stm.STMTreeTermNode;
import org.jkiss.dbeaver.model.stm.STMUtils;
import org.jkiss.dbeaver.utils.ListNode;
import org.jkiss.utils.Pair;

/**
 * Syntax inspections over an {@link STMTreeNode} tree produced by {@link SQLStandardParser}.
 *
 * <p>Two kinds of inspection are offered:
 * <ul>
 *     <li>{@link #prepareAbstractSyntaxInspection(int)} walks the parser ATN forward from the state
 *     associated with a text position and reports which reserved words may follow there, which of a
 *     fixed set of grammar rules are reachable, and which are already on the current rule stack;</li>
 *     <li>{@link #collectNameNodes(int)} collects the chain of name parts (separated by tokens of
 *     type 176) that ends right before a text position.</li>
 * </ul>
 *
 * <p>Rule and token numbers used below are indices into {@link SQLStandardParser#ruleNames} and the
 * parser vocabulary; their meaning is implied by how the inspection methods consume them.
 */
public class LSMInspections {
    private static final Pattern anyWordPattern = Pattern.compile("^\\w+$");
    /** Token types accepted as parts of a (possibly qualified) name by {@link #collectNameNodes(int)}. */
    private static final Set<Integer> KNOWN_IDENTIFIER_PART_TOKENS = Set.of(202, 1, 208);
    /**
     * Token types that, when the position lies inside them, make the syntax inspection start
     * after the preceding term rather than at the separator itself.
     */
    public static final Set<Integer> KNOWN_SEPARATOR_TOKENS = Set.of(161, 162, 163, 164, 167, 168, 169, 170, 171, 172, 173, 174, 175, 177, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 191, 192);
    /** Reserved words of the basic SQL dialect; only vocabulary entries found here are reported as predicted words. */
    @NotNull
    private static final Set<String> knownReservedWords = new HashSet<>(BasicSQLDialect.INSTANCE.getReservedWords());
    /** Rules tested for presence on the rule stack that encloses the inspected position. */
    @NotNull
    private static final Set<Integer> presenceTestRules = Set.of(45, 87, 19, 120, 133);
    /** Rules tested for reachability while walking the ATN forward from the inspected position. */
    @NotNull
    private static final Set<Integer> reachabilityTestRules = Set.of(45, 87, 19, 33, 120, 113, 73);
    /**
     * Rule-stack prefixes (innermost rule first): when the stack starts with any listed prefix,
     * the keyed rule is marked as present / reachable.
     */
    private static final Map<Integer, List<List<Integer>>> subtreeTests = Map.of(
        87, List.of(List.of(279, 53), List.of(279, 112)),
        45, List.of(List.of(279, 121))
    );
    /**
     * Rules whose ATN sub-graphs are not entered while collecting following terms
     * (their reachability is still recorded).
     */
    @NotNull
    private static final Set<Integer> knownReservedWordsExcludeRules = Set.of(45, 87, 19, 33, 281, 279, 278, 271, 95, 280);
    @NotNull
    private final SQLDialect dialect;
    @NotNull
    private final STMTreeNode root;
    /** All leaf terms of {@link #root} in document order, error terms included. */
    @NotNull
    private final List<STMTreeNode> allTerms;
    /** All non-error leaf terms of {@link #root} in document order. */
    @NotNull
    private final List<STMTreeTermNode> allNonErrorTerms;
    // Must stay declared after the static sets above: it is computed during class initialization and reads them.
    private static final SyntaxInspectionResult offqueryInspectionResult = prepareOffquerySyntaxInspectionInternal();

    /**
     * Checks whether the whole string consists of word characters ({@code \w}).
     *
     * @param str string to test
     * @return {@code true} if {@code str} is a non-empty run of word characters
     */
    public static boolean matchesAnyWord(String str) {
        return anyWordPattern.matcher(str).matches();
    }

    /**
     * Creates inspections over the given parsed tree.
     *
     * @param dialect dialect whose reserved words are consulted by {@link #collectNameNodes(int)}
     * @param root    root of the parsed tree; its leaf terms are collected eagerly
     */
    public LSMInspections(@NotNull SQLDialect dialect, @NotNull STMTreeNode root) {
        this.dialect = dialect;
        this.root = root;
        Pair<List<STMTreeNode>, List<STMTreeTermNode>> termLists = prepareTerms(root);
        this.allTerms = termLists.getFirst();
        this.allNonErrorTerms = termLists.getSecond();
    }

    /**
     * Returns the shared inspection result for a position outside any query, computed once from the
     * start state of the parser's first rule.
     */
    @NotNull
    public static SyntaxInspectionResult prepareOffquerySyntaxInspection() {
        return offqueryInspectionResult;
    }

    /**
     * Computes the inspection result starting at the start state of the parser's first rule with an
     * empty rule stack. Prefer {@link #prepareOffquerySyntaxInspection()}, which caches this value.
     */
    @NotNull
    public static SyntaxInspectionResult prepareOffquerySyntaxInspectionInternal() {
        ATN atn = SQLStandardParser._ATN;
        ListNode<Integer> emptyStack = ListNode.of(null);
        ATNState initialState = atn.states.get(atn.ruleToStartState[0].stateNumber);
        return inspectAbstractSyntaxAtState(null, emptyStack, initialState);
    }

    /**
     * Inspects what the grammar allows at the given text position.
     *
     * <p>Positions before the tree are inspected as "outside any query". Positions after the tree are
     * inspected right after the last non-error term. Otherwise the nearest non-error term starting at or
     * before {@code position} is located, and the inspection starts either at that term's ATN state
     * (position inside the term), after it (position past its end), or after the preceding term
     * (position inside a separator token).
     *
     * @param position text offset to inspect
     * @return the inspection result, {@link SyntaxInspectionResult#EMPTY} if the tree has no non-error
     *         terms, or {@code null} if the chosen anchor node carries no ATN state
     */
    @Nullable
    public SyntaxInspectionResult prepareAbstractSyntaxInspection(int position) {
        Interval range = this.root.getRealInterval();
        if (position < range.a) {
            return prepareOffquerySyntaxInspection();
        }
        if (this.allNonErrorTerms.isEmpty()) {
            // Nothing valid to anchor on (previously this crashed for positions inside the range).
            return SyntaxInspectionResult.EMPTY;
        }
        if (position > range.b) {
            return inspectAfterTerm(this.allNonErrorTerms.get(this.allNonErrorTerms.size() - 1));
        }

        int index = STMUtils.binarySearchByKey(this.allNonErrorTerms, t -> t.getRealInterval().a, position, Comparator.comparingInt(k -> k));
        if (index < 0) {
            // ~index is the insertion point, so ~index - 1 is the last term starting before position.
            // Clamp to the first term when position precedes all of them; that case is then handled
            // by the "position before the term" branch below instead of indexing at -1.
            index = Math.max(~index - 1, 0);
        }
        STMTreeTermNode node = this.allNonErrorTerms.get(index);
        Interval nodeRange = node.getRealInterval();
        ATN atn = SQLStandardParser._ATN;

        if (nodeRange.a <= position) {
            if (nodeRange.b < position) {
                return inspectAfterTerm(node);
            }
            if (KNOWN_SEPARATOR_TOKENS.contains(node.symbol.getType()) && index > 0) {
                return inspectAfterTerm(this.allNonErrorTerms.get(index - 1));
            }
            return inspectAbstractSyntaxAtTreeState(node.getParentNode(), atn.states.get(node.getAtnState()));
        }
        if (index > 0) {
            return inspectAfterTerm(this.allNonErrorTerms.get(index - 1));
        }
        STMTreeNode subroot = node.getParentNode();
        return inspectAbstractSyntaxAtTreeState(subroot, atn.states.get(subroot.getAtnState()));
    }

    /** Inspects the state reached by the first transition out of {@code term}'s ATN state, in the context of its parent. */
    @Nullable
    private static SyntaxInspectionResult inspectAfterTerm(@NotNull STMTreeTermNode term) {
        ATNState afterTerm = SQLStandardParser._ATN.states.get(term.getAtnState()).getTransitions()[0].target;
        return inspectAbstractSyntaxAtTreeState(term.getParentNode(), afterTerm);
    }

    /**
     * Collects the leaf terms of a tree in document (pre-order, left-to-right) order.
     *
     * @param root tree to traverse
     * @return pair of (all leaf terms including error terms, non-error leaf terms only)
     */
    @NotNull
    public static Pair<List<STMTreeNode>, List<STMTreeTermNode>> prepareTerms(@NotNull STMTreeNode root) {
        List<STMTreeNode> allTerms = new ArrayList<>();
        List<STMTreeTermNode> allNonErrorTerms = new ArrayList<>();
        ListNode<STMTreeNode> stack = ListNode.of(root);
        while (ListNode.hasAny(stack)) {
            STMTreeNode node = stack.data;
            stack = stack.next;
            if (node instanceof STMTreeTermNode term) {
                allTerms.add(term);
                allNonErrorTerms.add(term);
            } else if (node instanceof STMTreeTermErrorNode error) {
                allTerms.add(error);
            } else {
                // Push children last-to-first so the first child is popped next.
                for (int i = node.getChildCount() - 1; i >= 0; i--) {
                    stack = ListNode.push(stack, node.getChildNode(i));
                }
            }
        }
        return Pair.of(allTerms, allNonErrorTerms);
    }

    /**
     * Builds the rule stack enclosing {@code node} (innermost rule on top) and inspects the syntax
     * reachable from {@code initialState} with that stack.
     *
     * @return the inspection result, or {@code null} if {@code node} has no ATN state
     */
    @Nullable
    private static SyntaxInspectionResult inspectAbstractSyntaxAtTreeState(@NotNull STMTreeNode node, @NotNull ATNState initialState) {
        if (node.getAtnState() < 0) {
            return null;
        }
        // Collect enclosing rule nodes outermost-first, so pushing them leaves the innermost on top.
        ArrayDeque<RuleNode> path = new ArrayDeque<>();
        STMTreeNode current = node instanceof TerminalNode ? node.getParentNode() : node;
        while (current instanceof RuleNode ruleNode) {
            path.addFirst(ruleNode);
            current = current.getParentNode();
        }
        ListNode<Integer> stack = ListNode.of(null);
        for (RuleNode ruleNode : path) {
            stack = ListNode.push(stack, ruleNode.getRuleContext().getRuleIndex());
        }
        return inspectAbstractSyntaxAtState(node, stack, initialState);
    }

    /**
     * Collects the name parts (e.g. {@code a.b.c}) that end right before {@code position}.
     *
     * <p>The term covering {@code position - 1} is located; walking backwards from it, terms are
     * accepted as name parts while they alternate with tokens of type 176.
     *
     * @param position text offset (typically the caret)
     * @return collected name parts in document order, whether the covering term is a type-176 token,
     *         the covering term if it is a plain word, and the offset to inspect (the covering term's
     *         start if it is a reserved word of the dialect, otherwise {@code position})
     */
    @NotNull
    public NameInspectionResult collectNameNodes(int position) {
        ArrayDeque<STMTreeNode> nameNodes = new ArrayDeque<>();
        int index = STMUtils.binarySearchByKey(this.allTerms, t -> t.getRealInterval().a, position, Comparator.comparingInt(k -> k));
        if (index < 0) {
            index = ~index - 1;
        }
        if (index < 0) {
            return new NameInspectionResult(nameNodes, false, null, position);
        }

        STMTreeNode immTerm = this.allTerms.get(index);
        if (immTerm.getRealInterval().a >= position) {
            if (index == 0) {
                return new NameInspectionResult(nameNodes, false, null, position);
            }
            index--;
            immTerm = this.allTerms.get(index);
        }
        if (!immTerm.getRealInterval().properlyContains(Interval.of(position - 1, position - 1))) {
            return new NameInspectionResult(nameNodes, false, null, position);
        }

        String immText = immTerm.getTextContent();
        STMTreeNode currentTerm = anyWordPattern.matcher(immText).matches() ? immTerm : null;
        // Locale.ROOT: default-locale upper-casing breaks keyword lookup in e.g. Turkish locales.
        int positionToInspect = this.dialect.getReservedWords().contains(immText.toUpperCase(Locale.ROOT))
            ? immTerm.getRealInterval().a
            : position;
        boolean hasPeriod = false;
        if (immTerm instanceof STMTreeTermNode immTermNode && immTermNode.symbol.getType() == 176) {
            hasPeriod = true;
            index--;
        }

        // Walk backwards over "<part> 176 <part> 176 ..." until the chain breaks.
        for (int i = index; i >= 0; i -= 2) {
            STMTreeNode term = this.allTerms.get(i);
            if (!isNamePart(term)) {
                break;
            }
            nameNodes.addFirst(term);
            if (i == 0) {
                break;
            }
            if (this.allTerms.get(i - 1) instanceof STMTreeTermNode separator && separator.symbol.getType() != 176) {
                break;
            }
        }
        return new NameInspectionResult(nameNodes, hasPeriod, currentTerm, positionToInspect);
    }

    /** Whether {@code term} may be part of a name: a known identifier token, a child of rule kind 281, or an error term. */
    private static boolean isNamePart(@NotNull STMTreeNode term) {
        if (term instanceof STMTreeTermNode termNode && KNOWN_IDENTIFIER_PART_TOKENS.contains(termNode.symbol.getType())) {
            return true;
        }
        STMTreeNode parent = term.getParentNode();
        return (parent != null && parent.getNodeKindId() == 281) || term instanceof STMTreeTermErrorNode;
    }

    /** Creates a mutable map with every rule of {@code rules} mapped to {@code false}. */
    @NotNull
    private static Map<Integer, Boolean> newTestMap(@NotNull Set<Integer> rules) {
        Map<Integer, Boolean> tests = new HashMap<>(rules.size());
        for (Integer rule : rules) {
            tests.put(rule, false);
        }
        return tests;
    }

    /**
     * Marks each presence-test rule that occurs anywhere on {@code stateStack}, or whose subtree test
     * matches the top of the stack.
     */
    @NotNull
    private static Map<Integer, Boolean> performPresenceTests(@NotNull ListNode<Integer> stateStack) {
        Map<Integer, Boolean> presenceTests = newTestMap(presenceTestRules);
        for (Integer ruleIndex : stateStack) {
            presenceTests.computeIfPresent(ruleIndex, (k, v) -> true);
        }
        performSubtreeTests(presenceTests, stateStack);
        return presenceTests;
    }

    /**
     * Walks the ATN from {@code initialState} and summarizes what may follow.
     *
     * @param node         tree node the inspection is anchored at, or {@code null} outside any query
     * @param stack        enclosing rule stack, innermost rule on top
     * @param initialState ATN state to start from
     * @return predicted reserved words / token ids plus the rule-based expectation flags
     */
    @NotNull
    private static SyntaxInspectionResult inspectAbstractSyntaxAtState(@Nullable STMTreeNode node, @NotNull ListNode<Integer> stack, @NotNull ATNState initialState) {
        Map<Integer, Boolean> presenceTests = performPresenceTests(stack);
        Map<Integer, Boolean> reachabilityTests = newTestMap(reachabilityTestRules);
        Collection<Transition> followingTerms = collectFollowingTerms(stack, initialState, knownReservedWordsExcludeRules, reachabilityTests);

        Set<String> predictedWords = new HashSet<>();
        Set<Integer> predictedTokenIds = new HashSet<>();
        for (Interval interval : getTransitionTokens(followingTerms).getIntervals()) {
            for (int tokenType = interval.a; tokenType <= interval.b; tokenType++) {
                String word = SQLStandardParser.VOCABULARY.getDisplayName(tokenType);
                if (word != null && knownReservedWords.contains(word)) {
                    predictedTokenIds.add(tokenType);
                    predictedWords.add(word);
                }
            }
        }

        // All keys read below are members of presenceTestRules / reachabilityTestRules, so unboxing is safe.
        boolean expectingTableName = reachabilityTests.get(45) || presenceTests.get(45);
        boolean expectingColumnName = reachabilityTests.get(33);
        boolean expectingColumnReference = reachabilityTests.get(87) || presenceTests.get(87);
        boolean expectingIdentifier = reachabilityTests.get(19) || presenceTests.get(19);
        boolean expectingTableSourceIntroduction = expectingTableName && (reachabilityTests.get(120) || presenceTests.get(120));
        boolean expectingColumnIntroduction = expectingColumnReference && reachabilityTests.get(113);
        boolean expectingValue = reachabilityTests.get(73);
        boolean expectingJoinCondition = presenceTests.get(133)
            && node instanceof STMTreeTermNode term
            && term.getSymbol().getType() == 93;
        return new SyntaxInspectionResult(
            predictedTokenIds,
            predictedWords,
            reachabilityTests,
            expectingTableName,
            expectingColumnName,
            expectingColumnReference,
            expectingIdentifier,
            expectingTableSourceIntroduction,
            expectingColumnIntroduction,
            expectingValue,
            expectingJoinCondition
        );
    }

    /**
     * Unions the token types matched by the given terminal transitions. Not-set and wildcard
     * transitions contribute nothing.
     *
     * @throws UnsupportedOperationException for any other transition type
     */
    @NotNull
    private static IntervalSet getTransitionTokens(@NotNull Collection<Transition> transitions) {
        IntervalSet tokens = new IntervalSet();
        for (Transition transition : transitions) {
            switch (transition.getSerializationType()) {
                case Transition.ATOM -> tokens.add(((AtomTransition) transition).label);
                case Transition.RANGE -> {
                    RangeTransition rangeTransition = (RangeTransition) transition;
                    tokens.add(rangeTransition.from, rangeTransition.to);
                }
                case Transition.SET -> tokens.addAll(((SetTransition) transition).set);
                case Transition.NOT_SET, Transition.WILDCARD -> {
                    // intentionally ignored: they do not name specific tokens
                }
                default -> throw new UnsupportedOperationException("Unrecognized ATN transition type.");
            }
        }
        return tokens;
    }

    /** Debug helper: renders a rule stack as comma-separated rule names. */
    private static String collectStack(ListNode<Integer> stack) {
        return StreamSupport.stream(stack.spliterator(), false)
            .map(ruleIndex -> ruleIndex == null ? "<NULL>" : SQLStandardParser.ruleNames[ruleIndex])
            .collect(Collectors.joining(", "));
    }

    /**
     * Collects all terminal transitions reachable from {@code initialState} through epsilon-like
     * transitions, tracking the rule stack to return only to the proper caller.
     *
     * <p>Entering a rule marks it in {@code reachabilityTests} and runs the subtree tests on the new
     * stack; rules in {@code exceptRules} are marked but not entered.
     *
     * @throws UnsupportedOperationException for an unknown transition type
     */
    @NotNull
    private static Collection<Transition> collectFollowingTerms(
        @NotNull ListNode<Integer> stateStack,
        @NotNull ATNState initialState,
        @NotNull Set<Integer> exceptRules,
        @NotNull Map<Integer, Boolean> reachabilityTests
    ) {
        Set<Pair<ATNState, ListNode<Integer>>> visited = new HashSet<>();
        Set<Transition> results = new HashSet<>();
        ArrayDeque<Pair<ATNState, ListNode<Integer>>> pending = new ArrayDeque<>();
        Pair<ATNState, ListNode<Integer>> start = Pair.of(initialState, stateStack);
        visited.add(start);
        pending.addLast(start);

        while (!pending.isEmpty()) {
            Pair<ATNState, ListNode<Integer>> current = pending.removeLast();
            ATNState state = current.getFirst();
            ListNode<Integer> stack = current.getSecond();
            for (Transition transition : state.getTransitions()) {
                switch (transition.getSerializationType()) {
                    case Transition.RANGE, Transition.ATOM, Transition.SET, Transition.NOT_SET, Transition.WILDCARD ->
                        results.add(transition);
                    case Transition.EPSILON, Transition.RULE, Transition.PREDICATE, Transition.ACTION, Transition.PRECEDENCE -> {
                        ListNode<Integer> transitionStack = stack;
                        int stateType = state.getStateType();
                        if (stateType == ATNState.RULE_STOP) {
                            // Only follow the return edge leading back into the calling rule.
                            if (!returnsToCaller(stack, transition)) {
                                continue;
                            }
                            transitionStack = stack.next;
                        } else if (stateType == ATNState.RULE_START) {
                            reachabilityTests.computeIfPresent(state.ruleIndex, (k, v) -> true);
                            transitionStack = ListNode.push(stack, state.ruleIndex);
                            performSubtreeTests(reachabilityTests, transitionStack);
                            if (exceptRules.contains(state.ruleIndex)) {
                                continue;
                            }
                        }
                        Pair<ATNState, ListNode<Integer>> next = Pair.of(transition.target, transitionStack);
                        if (visited.add(next)) {
                            pending.addLast(next);
                        }
                    }
                    default -> throw new UnsupportedOperationException("Unrecognized ATN transition type.");
                }
            }
        }
        return results;
    }

    /** Whether a rule-stop {@code transition} targets the rule just below the top of {@code stack} (the caller). */
    private static boolean returnsToCaller(@Nullable ListNode<Integer> stack, @NotNull Transition transition) {
        return stack != null
            && stack.data != null
            && stack.next != null
            && stack.next.data != null
            && transition.target.ruleIndex == stack.next.data;
    }

    /** Marks each subtree-test rule whose any prefix matches the top of {@code stack}. */
    private static void performSubtreeTests(@NotNull Map<Integer, Boolean> reachabilityTest, @Nullable ListNode<Integer> stack) {
        for (Map.Entry<Integer, List<List<Integer>>> subtreeTest : subtreeTests.entrySet()) {
            for (List<Integer> subpath : subtreeTest.getValue()) {
                if (stackStartsWith(stack, subpath)) {
                    reachabilityTest.computeIfPresent(subtreeTest.getKey(), (k, v) -> true);
                    break;
                }
            }
        }
    }

    /** Whether the topmost entries of {@code stack} equal {@code prefix}; a too-short stack never matches. */
    private static boolean stackStartsWith(@Nullable ListNode<Integer> stack, @NotNull List<Integer> prefix) {
        ListNode<Integer> item = stack;
        for (Integer expected : prefix) {
            if (item == null || !expected.equals(item.data)) {
                return false;
            }
            item = item.next;
        }
        return true;
    }

    /**
     * Result of {@link #collectNameNodes(int)}.
     *
     * @param nameNodes         collected name parts in document order
     * @param hasPeriod         whether the term covering the position is a type-176 token
     * @param currentTerm       the covering term if it is a plain word, otherwise {@code null}
     * @param positionToInspect offset to run syntax inspection at
     */
    public record NameInspectionResult(ArrayDeque<STMTreeNode> nameNodes, boolean hasPeriod, STMTreeNode currentTerm, int positionToInspect) {
    }

    /**
     * Result of an abstract syntax inspection.
     *
     * @param predictedTokenIds  token types of predicted reserved words
     * @param predictedWords     predicted reserved words
     * @param reachabilityTests  reachability flags keyed by rule index
     */
    public record SyntaxInspectionResult(@NotNull Set<Integer> predictedTokenIds, @NotNull Set<String> predictedWords, @NotNull Map<Integer, Boolean> reachabilityTests, boolean expectingTableReference, boolean expectingColumnName, boolean expectingColumnReference, boolean expectingIdentifier, boolean expectingTableSourceIntroduction, boolean expectingColumnIntroduction, boolean expectingValue, boolean expectingJoinCondition) {
        /** Result predicting nothing. */
        public static final SyntaxInspectionResult EMPTY = new SyntaxInspectionResult(Collections.emptySet(), Collections.emptySet(), Collections.emptyMap(), false, false, false, false, false, false, false, false);

        /** Returns the reachability flags keyed by parser rule name instead of rule index. */
        @NotNull
        public Map<String, Boolean> getReachabilityByName() {
            return this.reachabilityTests.entrySet().stream()
                .collect(Collectors.toMap(e -> SQLStandardParser.ruleNames[e.getKey()], Map.Entry::getValue));
        }
    }
}

