#include "xor_solver.h"

#include "../utils/logging.h"

using namespace std;

namespace parity_potentials {
// Takes ownership of the augmented matrix (coefficient columns followed by a
// single right-hand-side column as the last column). No starting row for back
// substitution is known yet, which is encoded as first_non_zero_row == -1.
XorSolver::XorSolver(AugmentedMatrix &&source)
    : AugmentedMatrix(std::move(source)),
      first_non_zero_row(-1) {
}

void XorSolver::print_row(const Bitset &row) const {
    cout << row[row.size()-1] << " | ";
    for (int i = row.size()-2; i >= 0; --i) {
        cout << row[i];
    }
    cout << endl;
}

void XorSolver::zero_columns_below(int current_row, int current_column) {
    for (int row_index = current_row + 1;
             row_index < num_rows;
             ++row_index) {
        if (matrix[row_index][current_column])
            matrix[row_index] ^= matrix[current_row];
    }
}

// Returns the index of the first row at or below `current_row` that has a 1
// in column `current_column`, or -1 if no such row exists.
int XorSolver::find_pivot_row(int current_row, int current_column) const {
    int candidate = current_row;
    while (candidate < num_rows && !matrix[candidate][current_column])
        ++candidate;
    return candidate < num_rows ? candidate : -1;
}

void XorSolver::gaussian_elimination() {
    int row_index = 0;
    int column_index = 0;
    while (row_index < num_rows &&
        column_index < num_cols-1) {
        int pivot_row = find_pivot_row(row_index, column_index);

        // Make sure that pivot row is on current row, swap if necessary.
        if (pivot_row < 0) {
            ++column_index;
            continue;
        } else if (pivot_row != row_index) {
            swap(matrix[row_index], matrix[pivot_row]);
        }
        // Eliminate all 1s in the pivot column.
        zero_columns_below(row_index, column_index);

        ++row_index;
        ++column_index;
    }
}

// Returns the parity (XOR) of the coefficient bits of `row` from
// `starting_column` up to num_cols-2 inclusive. The right-hand-side bit at
// index num_cols-1 is excluded. `row` is not modified.
bool XorSolver::xor_row(Bitset &row, int starting_column) const {
    bool parity = false;
    for (int col = starting_column; col < num_cols - 1; ++col) {
        if (row[col])
            parity = !parity;
    }
    return parity;
}

// Solves the row-echelon system produced by gaussian_elimination() by back
// substitution, starting at row first_non_zero_row and moving upwards.
//
// Returns a Bitset of size num_cols: bit j (j < num_cols-1) is the value of
// variable j, and the last bit (the right-hand-side position) is always set.
// Only pivot variables are ever assigned; free (non-pivot) variables stay 0.
Bitset XorSolver::back_substitution() const {
    // Setting the RHS bit makes `matrix[i] & solution` keep each row's RHS
    // value alongside the already-determined variables.
    Bitset solution(num_cols);
    solution.set(num_cols-1);
    Bitset result_with_known_values(num_cols);

    for (int i = first_non_zero_row; i >= 0; --i) {
        // Locate the pivot: the first 1 among the coefficient columns.
        // TODO: we can avoid doing this by storing the pivot locations during
        // the gaussian elimination
        int pivot_column = -1;
        for (int j = 0; j < num_cols-1; ++j) {
            if (matrix[i][j]) {
                pivot_column = j;
                break;
            }
        }
        // A row without a pivot (e.g. when every row of the matrix is zero)
        // determines no variable; skip it instead of indexing bit -1.
        if (pivot_column < 0)
            continue;

        result_with_known_values = matrix[i] & solution;
        // Variables right of the pivot are already fixed; the pivot variable
        // must make the row's parity equal its RHS bit.
        bool xor_of_result =
            xor_row(result_with_known_values, pivot_column + 1);
        if (result_with_known_values[num_cols-1] ^ xor_of_result)
            solution.set(pivot_column);
    }
    utils::g_log << "done!" << endl;
    return solution;
}

bool XorSolver::is_unsolvable() {
    Bitset unsolvability_mask(num_cols);
    Bitset zero_mask(num_cols);

    unsolvability_mask.set(num_cols-1);

    //cout << "before:\n" << *this;
    gaussian_elimination();
    //cout << "after:\n" << *this;

    int i = 1;
    Bitset* current_row;
    do {
        // Check if the row is all zeros except the most significant bit.
        // If yes, there is no solution.
        current_row = &matrix[num_rows - i];
        if (*current_row == unsolvability_mask)
            return false;
        else
            ++i;
    } while (i <= num_rows && *current_row == zero_mask);
    // This will be the starting point for the back substitution
    first_non_zero_row = num_rows - (--i);
    return true;
}

// Returns the solution vector computed by back_substitution().
//
// first_non_zero_row is -1 from construction and is only overwritten by
// is_unsolvable() when the system is consistent. A negative value therefore
// means no usable starting row exists. In that case this prints a message
// and returns an empty Bitset. (The former assert made this branch
// unreachable in debug builds.)
Bitset XorSolver::compute_solution_weights() {
    if (first_non_zero_row < 0) {
        cout << "Task is not solvable or unsolvability has not been\n"
             << "checked yet. Cannot compute solution weights." << endl;
        return Bitset(0);
    }
    return back_substitution();
}

// Checks `solution` against every row of the (current) matrix: the XOR of the
// coefficient bits selected by `solution` must equal the row's RHS bit (the
// last column). Returns false on the first violated row, true otherwise.
bool XorSolver::verify_solution(const Bitset &solution) {
    for (const Bitset &equation : matrix) {
        Bitset selected_terms = equation & solution;
        const bool lhs = xor_row(selected_terms, 0);
        const bool rhs = equation[num_cols - 1];
        if (lhs != rhs)
            return false;
    }
    return true;
}
}
