import {
  clone,
  difference,
  each,
  filter,
  findIndex,
  flatten,
  flattenDeep,
  groupBy,
  indexOf,
  map,
  reverse,
  some,
  sortBy,
  uniq,
  uniqWith,
  values,
  without,
} from 'lodash-es';
import StaticDisjointSet from 'mnemonist/static-disjoint-set';

/**
 * Partitions every label appearing in `sets` into connected components: two
 * labels end up in the same component when they appear together in some set,
 * directly or through a chain of shared labels.
 *
 * @param sets - Label sets; each inner array links all of its members.
 * @returns One array per component. Labels inside a component keep their
 *   first-appearance order from `sets`. Components are ordered by the
 *   union-find root index `groupBy` keyed them on (integer-like object keys
 *   enumerate in ascending order).
 */
const disjoint = (sets: string[][]): string[][] => {
  const members = uniq(flatten(sets));

  // Label -> position in `members`; replaces a linear indexOf on every lookup.
  const positions = new Map<string, number>();
  members.forEach((member, index) => {
    positions.set(member, index);
  });
  const positionOf = (member: string): number => {
    const position = positions.get(member);
    if (position === undefined) {
      // Unreachable: `members` is built from every label in `sets`.
      throw new Error(`disjoint: unknown member ${member}`);
    }
    return position;
  };

  const disjointForest = new StaticDisjointSet(members.length);

  each(sets, (set) => {
    const [first] = set;
    if (first === undefined) {
      return;
    }
    // Linking each member to the set's first member connects the whole set in
    // O(k) unions. It performs the same unions, in the same order, as the first
    // pass of an all-pairs loop; any later pairwise unions would join labels
    // that already share a root.
    const anchor = positionOf(first);
    each(set, (member) => {
      disjointForest.union(anchor, positionOf(member));
    });
  });

  return values(groupBy(members, (member) => disjointForest.find(positionOf(member))));
};

/** A distinct label paired with how many times it occurs across the input sets. */
interface RankedMember {
  /** Occurrence count of `member`. */
  rank: number;
  /** The label being counted. */
  member: string;
}

/**
 * Counts how often each distinct label occurs across all sets.
 *
 * @param sets - Label sets; every occurrence counts, including a label
 *   repeated inside the same set.
 * @returns One entry per distinct label, ordered by descending `rank`. Ties
 *   come out in reverse first-appearance order, because the stable ascending
 *   `sortBy` result is reversed.
 */
const rankMembersOf = (sets: string[][]): RankedMember[] => {
  const ranked: RankedMember[] = map(uniq(flatten(sets)), (member) => ({ member, rank: 0 }));

  // Index entries by label so each occurrence is an O(1) update instead of a
  // linear findIndex scan. This also removes the need for a @ts-expect-error
  // on indexed access.
  const byMember = new Map<string, RankedMember>();
  each(ranked, (entry) => {
    byMember.set(entry.member, entry);
  });

  each(sets, (set) => {
    each(set, (member) => {
      const entry = byMember.get(member);
      if (entry !== undefined) {
        entry.rank += 1;
      }
    });
  });

  return reverse(sortBy(ranked, ['rank']));
};

/**
 * Greedily strips labels from `sets`, visiting them in `order`.
 *
 * Each label is removed from every set at once, but only if doing so leaves
 * no set empty. Otherwise the label stays everywhere and the next one is
 * tried. Afterwards, any set with no member outside some earlier surviving
 * set is dropped.
 *
 * NOTE(review): that dedupe is `difference(later, earlier).length === 0`, i.e.
 * "later is a subset of earlier". It is not symmetric set equality, so the
 * result depends on set order: a superset that comes after its subset is kept
 * alongside it. Confirm this is intended.
 */
const stripMembers = (sets: string[][], order: RankedMember[]): string[][] => {
  let filteredSets = sets;
  each(order, (rankedMember) => {
    const singlePassFilteredSets = map(filteredSets, (set) => without(set, rankedMember.member));
    if (!some(singlePassFilteredSets, (set) => set.length === 0)) {
      filteredSets = singlePassFilteredSets;
    }
  });
  return uniqWith(filteredSets, (a, b) => difference(a, b).length === 0);
};

/**
 * Strips labels in the order of `rankedMembers` as given. CategorizedLabels
 * passes rankMembersOf's output, so the most frequent labels go first.
 */
const lowPass = (sets: string[][], rankedMembers: RankedMember[]) => stripMembers(sets, rankedMembers);

/**
 * Strips labels in the reverse order of `rankedMembers`: least frequent first
 * when given rankMembersOf's output. Works on a copy; `rankedMembers` is not
 * mutated.
 */
const highPass = (sets: string[][], rankedMembers: RankedMember[]) =>
  stripMembers(sets, reverse(clone(rankedMembers)));

/**
 * Subtracts every label that appears anywhere in `hiPass` from each set in
 * `labelSets`. Only remainders that still hold at least two labels are kept.
 *
 * @param labelSets - The original label sets.
 * @param hiPass - Sets whose labels are removed from `labelSets`.
 * @returns The pruned sets with more than one label left, in input order.
 */
const removeSections = (labelSets: string[][], hiPass: string[][]) => {
  const excluded = uniq(flattenDeep(hiPass));
  const remainders = map(labelSets, (labels) => difference(labels, excluded));
  return filter(remainders, (labels) => labels.length > 1);
};

/**
 * Runs the label-categorization pipeline over `labelSets` and returns every
 * intermediate result:
 *
 * - `ranked`: each distinct label with its occurrence count, highest first.
 * - `disjointed`: connected components of the raw label sets.
 * - `hiPass`: the sets after greedily stripping labels, least frequent first.
 * - `disjointHiPass`: the connected components of `hiPass`.
 * - `loPass`: the original sets minus every `hiPass` label, keeping remainders
 *   with more than one label, then stripped most frequent first.
 * - `disjointLoPass`: the connected components of `loPass`.
 */
const CategorizedLabels = (labelSets: string[][]) => {
  const ranked = rankMembersOf(labelSets);
  const hiPass = highPass(labelSets, ranked);
  const loPass = lowPass(removeSections(labelSets, hiPass), ranked);

  // Every stage is a pure function of its inputs, so the component
  // computations can be evaluated inline here.
  return {
    ranked,
    disjointed: disjoint(labelSets),
    hiPass,
    disjointHiPass: disjoint(hiPass),
    loPass,
    disjointLoPass: disjoint(loPass),
  };
};

export default CategorizedLabels;
