import { MarkType, Node, ResolvedPos } from "prosemirror-model";
import { EditorState, SelectionRange, TextSelection, Transaction } from "prosemirror-state";
import { EditorView } from "prosemirror-view";
import { schema } from "../../schema";

/**
 * Reports whether `type` may be applied somewhere inside at least one of `ranges`.
 *
 * A range qualifies when it contains a node with inline content whose type allows the mark.
 * A range whose start sits at depth 0 additionally qualifies when the document node itself
 * has inline content and allows the mark (matches the current upstream `markApplies` in
 * prosemirror-commands, which this helper was copied from).
 */
function markApplies(doc: Node, ranges: readonly SelectionRange[], type: MarkType): boolean {
  for (const { $from, $to } of ranges) {
    // nodesBetween only visits descendants, so the top node has to be checked separately.
    // Allowing the mark is not enough: the doc must actually hold inline content.
    let can = $from.depth === 0 ? doc.inlineContent && doc.type.allowsMarkType(type) : false;
    doc.nodesBetween($from.pos, $to.pos, (node) => {
      // Already found a suitable node; don't descend any further.
      if (can) return false;
      can = node.inlineContent && node.type.allowsMarkType(type);
    });
    if (can) return true;
  }
  return false;
}

/**
 * Toggles `markType` over `$from..$to` while leaving the atom nodes that start at
 * `atomNodeStartPositions` untouched.
 *
 * - If `fullRangeHasMark` is true, the mark is removed from the whole range (removing a mark
 *   from an atom that does not carry it is harmless).
 * - Otherwise the mark is added to every gap between atoms: from `$from` to the first atom,
 *   between consecutive atoms, and from just after the last atom to `$to`.
 *
 * @param atomNodeStartPositions Start positions of atoms inside the range, in document order.
 */
function applyMarkToAllowedNodesInRange(
  tr: Transaction,
  atomNodeStartPositions: number[],
  fullRangeHasMark: boolean,
  $from: ResolvedPos,
  $to: ResolvedPos,
  markType: MarkType,
  attrs?:
    | {
        [key: string]: any;
      }
    | undefined,
): void {
  if (fullRangeHasMark) {
    tr.removeMark($from.pos, $to.pos, markType);
    return;
  }
  // Marks are immutable, so one instance can be shared by every addMark call.
  const mark = markType.create(attrs);
  let segmentStart = $from.pos;
  for (const atomPos of atomNodeStartPositions) {
    // Empty gaps (range starting on an atom, adjacent atoms) have nothing to mark.
    if (segmentStart < atomPos) tr.addMark(segmentStart, atomPos, mark);
    // `+ 1` skips the atom itself; NOTE(review): assumes each atom has nodeSize 1 (true for leaf nodes) — confirm for reference/image.
    segmentStart = atomPos + 1;
  }
  if (segmentStart < $to.pos) tr.addMark(segmentStart, $to.pos, mark);
}

/**
 * Returns true when every single-position slice `[p, p + 1)` for `p` in `[from, to)`
 * reports `markType` via `doc.rangeHasMark`; an empty range is vacuously true.
 */
function fullTextRangeHasMark(doc: Node, from: number, to: number, markType: MarkType) {
  let cursor = from;
  while (cursor < to) {
    const next = cursor + 1;
    // One position without the mark is enough to answer "no".
    if (!doc.rangeHasMark(cursor, next, markType)) return false;
    cursor = next;
  }
  return true;
}

/**
 * Scans the nodes overlapping `$from..$to` (via `nodesBetween`, i.e. in document order) and returns:
 * - `atomNodeStartPositions`: start positions of `reference` / `image` nodes, which must not be marked;
 * - `fullRangeHasMark`: whether every text node in the range that carries neither a `link` nor a
 *   `hashtag` mark already carries `markType` at every position.
 *
 * NOTE(review): if the range contains no such text node (e.g. only a link, or only atoms),
 * `fullRangeHasMark` stays `true` and the toggle will remove rather than add the mark — confirm intended.
 */
function getRangeDetails(state: EditorState, $from: ResolvedPos, $to: ResolvedPos, markType: MarkType) {
  const atomNodeStartPositions: number[] = [];
  let fullRangeHasMark = true;
  // Built once per call instead of once per visited node.
  // Edit atomTypes if marks should ever be applied to atoms.
  const atomTypes = [schema.nodes.reference, schema.nodes.image];
  const ignoredTextMarkTypes = [schema.marks.hashtag, schema.marks.link];
  state.doc.nodesBetween($from.pos, $to.pos, (node, pos) => {
    if (atomTypes.includes(node.type)) {
      atomNodeStartPositions.push(pos);
      return;
    }
    // Once one unmarked text node has been found the answer is settled, so skip the
    // per-position scan for the rest (atoms above are still collected).
    if (!fullRangeHasMark || node.type !== schema.nodes.text) return;
    // Links and hashtags do not take part in deciding whether to add or remove the mark.
    if (node.marks.some((mark) => ignoredTextMarkTypes.includes(mark.type))) return;
    if (!fullTextRangeHasMark(state.doc, pos, pos + node.nodeSize, markType)) fullRangeHasMark = false;
  });
  return { atomNodeStartPositions, fullRangeHasMark };
}

// https://github.com/ProseMirror/prosemirror-commands/blob/master/src/commands.ts
// adapted from the toggleMark command in prosemirror-commands
export const toggleAllowedMarks =
  (
    markType: MarkType,
    attrs?:
      | {
          [key: string]: any;
        }
      | undefined,
  ) =>
  (state: EditorState, dispatch?: EditorView["dispatch"]): boolean => {
    const { selection } = state;
    const { $cursor } = selection as TextSelection;
    // Nothing to toggle for an empty non-cursor selection, or where the mark is not allowed.
    const applicable = !(selection.empty && !$cursor) && markApplies(state.doc, selection.ranges, markType);
    if (!applicable) return false;
    // Dry run: the command applies, but the caller only wants to know that.
    if (!dispatch) return true;

    if ($cursor) {
      // Collapsed cursor: toggle the stored marks, which ProseMirror applies to the next input.
      const active = markType.isInSet(state.storedMarks || $cursor.marks());
      dispatch(active ? state.tr.removeStoredMark(markType) : state.tr.addStoredMark(markType.create(attrs)));
      return true;
    }

    // Real selection: each range decides independently whether to add or remove the mark,
    // skipping atom nodes when adding.
    const tr = state.tr;
    for (const { $from, $to } of selection.ranges) {
      const details = getRangeDetails(state, $from, $to, markType);
      applyMarkToAllowedNodesInRange(
        tr,
        details.atomNodeStartPositions,
        details.fullRangeHasMark,
        $from,
        $to,
        markType,
        attrs,
      );
    }
    dispatch(tr.scrollIntoView());
    return true;
  };

/** Formatting mark types stripped by `clearAllowedMarks`; any other marks are left in place. */
const clearableMarkTypes = [
  schema.marks.bold,
  schema.marks.italic,
  schema.marks.strikethrough,
  schema.marks.underline,
];

/**
 * Command that strips the mark types listed in `clearableMarkTypes` from the selection.
 *
 * - With a collapsed text cursor there is no text to strip, so the marks are removed from
 *   the transaction's stored marks instead (ProseMirror applies stored marks to the next
 *   input), meaning text typed next will not carry them.
 * - With a non-empty selection the marks are removed from every selected range.
 *
 * Follows the ProseMirror command contract: returns whether the command applies, and only
 * builds and dispatches a transaction when `dispatch` is given.
 */
export const clearAllowedMarks = (state: EditorState, dispatch?: EditorView["dispatch"]): boolean => {
  const { empty, ranges } = state.selection;
  const $cursor = (state.selection as TextSelection).$cursor;
  // An empty selection that is not a text cursor has nothing to clear.
  if (empty && !$cursor) return false;
  // Dry run (e.g. to enable/disable a menu item): report applicability only.
  if (!dispatch) return true;

  const tr = state.tr;
  if ($cursor) {
    for (const markType of clearableMarkTypes) tr.removeStoredMark(markType);
  } else {
    for (const { $from, $to } of ranges) {
      for (const markType of clearableMarkTypes) tr.removeMark($from.pos, $to.pos, markType);
    }
  }
  // All changes accumulate on one transaction, which is dispatched exactly once.
  dispatch(tr.scrollIntoView());
  return true;
};
