#include "Common.hpp"
#include "ParseUtils.hpp"
#include "Cleanup.hpp"
#include "Protocol.hpp"
#include "Solution.hpp"
#include "Xoshiro.hpp"
#include "FastRandom.hpp"
#include "TimeUtils.hpp"
#include "ProcessId.hpp"
#include <array>
#include <minisat/core/Solver.h>

// Requires the zlib development headers: sudo apt-get install zlib1g-dev

using std::string;
using std::vector;
using std::array;

typedef Xoshiro256StarStar Gen;

/// Index pair naming one move: `planIdx` selects the plan, `moveIdx` the step
/// within that plan. Both indices are zero unless given explicitly.
class PlanMove
{
public:
    PlanMove() = default;
    PlanMove(uint64_t planIdx_, uint64_t moveIdx_)
        : planIdx{planIdx_}
        , moveIdx{moveIdx_}
    {
    }

    /// Equal exactly when both indices match.
    bool operator==(const PlanMove& rhs) const
    {
        return moveIdx == rhs.moveIdx && planIdx == rhs.planIdx;
    }

    /// Logical negation of operator==.
    bool operator!=(const PlanMove& rhs) const
    {
        return !(*this == rhs);
    }

    uint64_t planIdx = 0;
    uint64_t moveIdx = 0;
};

using Minisat::Var;
using Minisat::mkLit;
using Minisat::lbool;
typedef Minisat::vec<Minisat::Lit> LitVec;

/// Constrain `literals` so that exactly one of them holds in any model.
///
/// Direct (pairwise) encoding: one "at least one" clause over all literals,
/// plus a binary clause (~a | ~b) for every unordered pair, so the clause
/// count grows quadratically with literals.size(). An empty `literals`
/// yields an empty clause, which makes the solver unsatisfiable.
void exactlyOneTrue(Minisat::Solver& solver,
                    const LitVec& literals)
{
    // At least one literal is true.
    solver.addClause(literals);

    // At most one: no two literals may be true together.
    const size_t count = literals.size();
    for (size_t hi = 1; hi < count; ++hi)
    {
        const Minisat::Lit notHi = ~literals[hi];
        for (size_t lo = 0; lo != hi; ++lo)
        {
            solver.addClause(notHi, ~literals[lo]);
        }
    }
}

/// Print the command-line help text to stdout, one line per entry.
void usage()
{
    static const char* const kHelpLines[] = {
        "Usage: sat <problem>\n",
        "  <problem>\n",
        "        Problem name\n",
        "Options:\n",
        "  -h    Print usage information and exit\n",
        "  -d    Dry run: don't actually submit\n",
    };

    for (const char* line : kHelpLines)
    {
        printf("%s", line);
    }
}

int main(int argc, char *argv[])
{
    bool gotProblemName = false;

    bool help = false;
    bool dryRun = false;
    string problemName;

    int iArg = 1;
    while (iArg < argc)
    {
        string strArg = argv[iArg++];

        if (strArg == "-h" || strArg == "--help")
        {
            help = true;
        }
        else if (strArg == "-d")
        {
            dryRun = true;
        }
        else if (!gotProblemName)
        {
            problemName = strArg;
            gotProblemName = true;
        }
        else
        {
            usage();
            return 1;
        }
    }

    if (help)
    {
        usage();
        return 0;
    }

    if (!gotProblemName)
    {
        usage();
        return 1;
    }

    uint64_t size = 0;
    bool lightning = false;

    if (problemName == "probatio") { size = 3; lightning = true; }
    else if (problemName == "primus") { size = 6; lightning = true; }
    else if (problemName == "secundus") { size = 12; lightning = true; }
    else if (problemName == "tertius") { size = 18; lightning = true; }
    else if (problemName == "quartus") { size = 24; lightning = true; }
    else if (problemName == "quintus") { size = 30; lightning = true; }
    else if (problemName == "aleph") size = 12;
    else if (problemName == "beth") size = 24;
    else if (problemName == "gimel") size = 36;
    else if (problemName == "daleth") size = 48;
    else if (problemName == "he") size = 60;
    else if (problemName == "vau") size = 18;
    else if (problemName == "zain") size = 36;
    else if (problemName == "hhet") size = 54;
    else if (problemName == "teth") size = 72;
    else if (problemName == "iod") size = 90;
    else
    {
        printf("Unknown problem name\n");
        return 1;
    }

    printf("problem: %s\n", problemName.c_str());
    printf("size: %" PRIu64 "\n", size);

    Protocol::init();
    Cleanup cleanupProtocol([](){ Protocol::cleanup(); });

    string msg;

    if (!Protocol::select(problemName, &msg))
    {
        printf("%s\n", msg.c_str());
        return 1;
    }

    Gen gen;
#if 1
    {
        uint64_t timeMS = getTimeMS();
        uint32_t pid = getProcessId();
        uint32_t seed0 = 0xddf294ea;
        uint32_t seed1 = (uint32_t)timeMS;
        uint32_t seed2 = (uint32_t)pid ^ (uint32_t)(timeMS >> 32);
        std::seed_seq seq{seed0, seed1, seed2};
        gen.seed(seq);
    }
#endif
#if 0
    {
        uint32_t seed0 = 0xddf294ea;
        std::seed_seq seq{seed0};
        gen.seed(seq);
    }
#endif

    uint64_t planCount;
    uint64_t moveCount;

    if (lightning)
    {
        planCount = 1;
        moveCount = size * 18;
    }
    else
    {
        planCount = 3;
        moveCount = size * 6;
    }

    vector<vector<uint8_t>> moves(planCount);
    vector<vector<uint8_t>> marks(planCount);
    for (uint64_t planIdx = 0; planIdx < planCount; planIdx++)
    {
        moves[planIdx].resize(moveCount);
        marks[planIdx].resize(moveCount);
        for (uint64_t moveIdx = 0; moveIdx < moveCount; moveIdx++)
        {
            moves[planIdx][moveIdx] = fastRandom((uint32_t)6, gen);
            marks[planIdx][moveIdx] = fastRandom((uint32_t)4, gen);
        }
    }

    vector<string> plans(planCount);
    for (uint64_t planIdx = 0; planIdx < planCount; planIdx++)
    {
        auto& plan = plans[planIdx];
        for (uint64_t moveIdx = 0; moveIdx < moveCount; moveIdx++)
        {
#if 0
            plan += '[';
            plan += (char)('0' + marks[planIdx][moveIdx]);
            plan += ']';
#endif
            plan += (char)('0' + moves[planIdx][moveIdx]);
        }
    }

    for (auto& plan : plans)
    {
        printf("%s\n", plan.c_str());
    }

    vector<vector<uint8_t>> results;
    uint64_t queryCount;
    if (!Protocol::explore(plans, &results, &queryCount, &msg))
    {
        printf("%s\n", msg.c_str());
        return 1;
    }

    for (auto& result : results)
    {
        for (uint8_t value : result)
        {
            printf("%" PRIu8 "", value);
        }
        printf("\n");
    }
    printf("queryCount = %" PRIu64 "\n", queryCount);

    vector<vector<uint8_t>> values(planCount);
    for (uint64_t planIdx = 0; planIdx < planCount; planIdx++)
    {
        values[planIdx].resize(moveCount + 1);
        for (uint64_t moveIdx = 0; moveIdx <= moveCount; moveIdx++)
        {
            //values[planIdx][moveIdx] = results[planIdx][moveIdx * 2];
            values[planIdx][moveIdx] = results[planIdx][moveIdx];
        }
    }

    Minisat::Solver solver;

    vector<Var> connectionVars(size * 6 * size * 6);
    for (uint64_t roomIdxI = 0; roomIdxI < size; roomIdxI++)
    {
        for (uint64_t doorIdxI = 0; doorIdxI < 6; doorIdxI++)
        {
            {
                Var var = solver.newVar();
                connectionVars[(roomIdxI * 6 + doorIdxI) * (size * 6) + (roomIdxI * 6 + doorIdxI)] = var;
            }

            for (uint64_t roomIdxJ = roomIdxI + 1; roomIdxJ < size; roomIdxJ++)
            {
                for (uint64_t doorIdxJ = 0; doorIdxJ < 6; doorIdxJ++)
                {
                    Var var = solver.newVar();
                    connectionVars[(roomIdxI * 6 + doorIdxI) * (size * 6) + (roomIdxJ * 6 + doorIdxJ)] = var;
                    connectionVars[(roomIdxJ * 6 + doorIdxJ) * (size * 6) + (roomIdxI * 6 + doorIdxI)] = var;
                }
            }
        }
    }

    vector<Var> destVars(size * 6 * size);
    for (uint64_t roomIdxI = 0; roomIdxI < size; roomIdxI++)
    {
        for (uint64_t doorIdxI = 0; doorIdxI < 6; doorIdxI++)
        {
            for (uint64_t roomIdxJ = 0; roomIdxJ < size; roomIdxJ++)
            {
                Var var = solver.newVar();
                destVars[(roomIdxI * 6 + doorIdxI) * size + roomIdxJ] = var;
            }
        }
    }

    vector<Var> roomVars(planCount * (moveCount + 1) * size);
    for (uint64_t planIdx = 0; planIdx < planCount; planIdx++)
    {
        for (uint64_t moveIdx = 0; moveIdx <= moveCount; moveIdx++)
        {
            uint64_t value = values[planIdx][moveIdx];
            for (uint64_t roomIdx = value; roomIdx < size; roomIdx += 4)
            {
                Var var = solver.newVar();
                roomVars[(planIdx * (moveCount + 1) + moveIdx) * size + roomIdx] = var;
            }
        }
    }

#if 0
    vector<Var> valueVars(size * 4);
    for (uint64_t roomIdx = 0; roomIdx < size; roomIdx++)
    {
        for (uint64_t value = 0; value < 4; value++)
        {
            Var var = solver.newVar();
            valueVars[roomIdx * 4 + value] = var;
        }
    }
#endif

    for (uint64_t roomIdxI = 0; roomIdxI < size; roomIdxI++)
    {
        for (uint64_t doorIdxI = 0; doorIdxI < 6; doorIdxI++)
        {
            LitVec literals;
            for (uint64_t roomIdxJ = 0; roomIdxJ < size; roomIdxJ++)
            {
                for (uint64_t doorIdxJ = 0; doorIdxJ < 6; doorIdxJ++)
                {
                    if (roomIdxJ == roomIdxI && doorIdxJ != doorIdxI) continue;
                    literals.push(mkLit(connectionVars[(roomIdxI * 6 + doorIdxI) * (size * 6) + (roomIdxJ * 6 + doorIdxJ)]));
                }
            }
            exactlyOneTrue(solver, literals);
        }
    }

    for (uint64_t roomIdxI = 0; roomIdxI < size; roomIdxI++)
    {
        for (uint64_t doorIdxI = 0; doorIdxI < 6; doorIdxI++)
        {
            LitVec literals;
            for (uint64_t roomIdxJ = 0; roomIdxJ < size; roomIdxJ++)
            {
                literals.push(mkLit(destVars[(roomIdxI * 6 + doorIdxI) * size + roomIdxJ]));
            }
            exactlyOneTrue(solver, literals);
        }
    }

    for (uint64_t roomIdxI = 0; roomIdxI < size; roomIdxI++)
    {
        for (uint64_t doorIdxI = 0; doorIdxI < 6; doorIdxI++)
        {
            for (uint64_t roomIdxJ = 0; roomIdxJ < size; roomIdxJ++)
            {
                LitVec literals;
                Var destVar = destVars[(roomIdxI * 6 + doorIdxI) * size + roomIdxJ];
                literals.push(~mkLit(destVar));
                for (uint64_t doorIdxJ = 0; doorIdxJ < 6; doorIdxJ++)
                {
                    if (roomIdxJ == roomIdxI && doorIdxJ != doorIdxI) continue;
                    Var connectionVar = connectionVars[(roomIdxI * 6 + doorIdxI) * (size * 6) + (roomIdxJ * 6 + doorIdxJ)];
                    literals.push(mkLit(connectionVar));
                }
                solver.addClause(literals);
            }
        }
    }

    for (uint64_t planIdx = 0; planIdx < planCount; planIdx++)
    {
        for (uint64_t moveIdx = 0; moveIdx <= moveCount; moveIdx++)
        {
            uint64_t value = values[planIdx][moveIdx];
            LitVec literals;
            for (uint64_t roomIdx = value; roomIdx < size; roomIdx += 4)
            {
                literals.push(mkLit(roomVars[(planIdx * (moveCount + 1) + moveIdx) * size + roomIdx]));
            }
            exactlyOneTrue(solver, literals);
        }
    }

#if 0
    for (uint64_t roomIdx = 0; roomIdx < size; roomIdx++)
    {
        LitVec literals;
        for (uint64_t value = 0; value < 4; value++)
        {
            literals.push(mkLit(valueVars[roomIdx * 4 + value]));
        }
        exactlyOneTrue(solver, literals);
    }
#endif

    for (uint64_t planIdx = 0; planIdx < planCount; planIdx++)
    {
        for (uint64_t moveIdx = 0; moveIdx < moveCount; moveIdx++)
        {
            uint64_t doorIdxI = moves[planIdx][moveIdx];
            uint64_t valueI = values[planIdx][moveIdx];
            for (uint64_t roomIdxI = valueI; roomIdxI < size; roomIdxI += 4)
            {
                Var roomVarI = roomVars[(planIdx * (moveCount + 1) + moveIdx) * size + roomIdxI];
                uint64_t valueJ = values[planIdx][moveIdx + 1];
                for (uint64_t roomIdxJ = valueJ; roomIdxJ < size; roomIdxJ += 4)
                {
                    Var roomVarJ = roomVars[(planIdx * (moveCount + 1) + (moveIdx + 1)) * size + roomIdxJ];
                    Var destVar = destVars[(roomIdxI * 6 + doorIdxI) * size + roomIdxJ];
                    solver.addClause(~mkLit(roomVarI), ~mkLit(roomVarJ), mkLit(destVar));
#if 0
                    LitVec literals;
                    literals.push(~mkLit(roomVarI));
                    literals.push(~mkLit(roomVarJ));
                    for (uint64_t doorIdxJ = 0; doorIdxJ < 6; doorIdxJ++)
                    {
                        if (roomIdxJ == roomIdxI && doorIdxJ != doorIdxI) continue;
                        Var connectionVar = connectionVars[(roomIdxI * 6 + doorIdxI) * (size * 6) + (roomIdxJ * 6 + doorIdxJ)];
                        literals.push(mkLit(connectionVar));
                    }
                    solver.addClause(literals);
#endif
                }
            }
        }
    }

    bool sat = solver.solve();
    if (!sat)
    {
        printf("UNSAT\n");
        return 0;
    }

    printf("SAT\n");

    Solution solution;

    for (uint64_t roomIdx = 0; roomIdx < size; roomIdx++)
    {
        uint64_t value = roomIdx % 4;
        printf("%" PRIu64 ": %" PRIu64 "\n", roomIdx, value);
        solution.rooms.push_back(value);
    }

    uint64_t startingRoomIdx = 0;
    for (uint64_t roomIdx = 0; roomIdx < size; roomIdx++)
    {
        uint64_t planIdx = 0;
        uint64_t moveIdx = 0;
        Var var = roomVars[(planIdx * (moveCount + 1) + moveIdx) * size + roomIdx];
        if (solver.modelValue(var) == l_True)
        {
            startingRoomIdx = roomIdx;
            break;
        }
    }
    printf("start: %" PRIu64 "\n", startingRoomIdx);
    solution.startingRoom = startingRoomIdx;

    for (uint64_t roomIdxI = 0; roomIdxI < size; roomIdxI++)
    {
        for (uint64_t doorIdxI = 0; doorIdxI < 6; doorIdxI++)
        {
            for (uint64_t roomIdxJ = roomIdxI; roomIdxJ < size; roomIdxJ++)
            {
                for (uint64_t doorIdxJ = 0; doorIdxJ < 6; doorIdxJ++)
                {
                    if (roomIdxJ == roomIdxI && doorIdxJ != doorIdxI) continue;
                    Var var = connectionVars[(roomIdxI * 6 + doorIdxI) * (size * 6) + (roomIdxJ * 6 + doorIdxJ)];
                    if (solver.modelValue(var) != l_True) continue;

                    printf("(%" PRIu64 ", %" PRIu64 ") - (%" PRIu64 ", %" PRIu64 ")\n", roomIdxI, doorIdxI, roomIdxJ, doorIdxJ);

                    auto& connection = solution.connections.emplace_back();
                    connection.from.room = roomIdxI;
                    connection.from.door = doorIdxI;
                    connection.to.room = roomIdxJ;
                    connection.to.door = doorIdxJ;
                }
            }
        }
    }

    if (dryRun)
    {
        printf("Dry run: not submitting\n");
    }
    else
    {
        bool correct;
        if (!Protocol::guess(solution, &correct, &msg))
        {
            printf("%s\n", msg.c_str());
            return 1;
        }

        printf("correct = %s\n", correct ? "true" : "false");
    }

    return 0;
}
