import java.util.Arrays;

public class Main {
	public static void main(String[] args) {
		int[][] test = { {2, 7, 6},     // Sample 3x3 magic square; isMagicSquare
				{9, 5, 1},				// itself accepts an array of any shape
				{4, 3, 8} };
		print2DArray(test);
		System.out.println(isMagicSquare(test));
	}

	/**
	 * Tests whether {@code arr} is a normal magic square: an n x n grid (n >= 1)
	 * holding each of the values 1..n*n exactly once, in which every row, every
	 * column and both diagonals add up to the same total.
	 *
	 * @param arr the grid to test; may be {@code null}, empty or jagged
	 * @return {@code true} only for a well-formed magic square; {@code false}
	 *         for {@code null}, empty, or non-square input
	 */
	public static boolean isMagicSquare(int[][] arr) {
		// The sum helpers index arr[row][col] with both indices below arr.length,
		// so the shape must be verified before any of them run.
		if (!isSquare(arr)) {
			return false;
		}

		// Row 0 provides the reference total every other line must match.
		int sum = 0;
		for (int value : arr[0]) {
			sum += value;
		}
		return rowsOK(arr, sum) && colsOK(arr, sum) && diagsOK(arr, sum)
				&& correctElements(arr);
	}

	/** Returns true when arr is non-null, non-empty and every row has arr.length elements. */
	private static boolean isSquare(int[][] arr) {
		if (arr == null || arr.length == 0) {
			return false;
		}
		for (int[] row : arr) {
			if (row == null || row.length != arr.length) {
				return false;
			}
		}
		return true;
	}

	/**
	 * Checks that every row adds up to {@code sum}.
	 *
	 * @param arr an n x n grid; only the first {@code arr.length} entries of
	 *            each row are read
	 * @param sum the total each row must reach
	 * @return {@code true} if all rows match {@code sum}
	 */
	public static boolean rowsOK(int[][] arr, int sum) {
		for (int row = 0; row < arr.length; row++) {
			int rowSum = 0;
			for (int col = 0; col < arr.length; col++) {
				rowSum += arr[row][col];
			}
			if (rowSum != sum) {
				return false;
			}
		}
		return true;
	}

	/**
	 * Checks that every column adds up to {@code sum}.
	 *
	 * @param arr an n x n grid; only the first {@code arr.length} columns are read
	 * @param sum the total each column must reach
	 * @return {@code true} if all columns match {@code sum}
	 */
	public static boolean colsOK(int[][] arr, int sum) {
		for (int col = 0; col < arr.length; col++) {
			int colSum = 0;
			for (int row = 0; row < arr.length; row++) {
				colSum += arr[row][col];
			}
			if (colSum != sum) {
				return false;
			}
		}
		return true;
	}

	/**
	 * Checks that the main diagonal and the anti-diagonal both add up to {@code sum}.
	 *
	 * @param arr an n x n grid
	 * @param sum the total each diagonal must reach
	 * @return {@code true} if both diagonals match {@code sum}
	 */
	public static boolean diagsOK(int[][] arr, int sum) {
		int diagSum = 0, diagSum2 = 0;
		for (int row = 0; row < arr.length; row++) {
			diagSum += arr[row][row];                   // top-left to bottom-right
			diagSum2 += arr[row][arr.length - row - 1]; // top-right to bottom-left
		}
		return (diagSum == sum) && (diagSum2 == sum);
	}

	/**
	 * Checks that the grid contains each of the values 1..N exactly once, where
	 * N is the total number of elements. Each row is copied using its own
	 * length, so rectangular and jagged grids are handled as well as square ones.
	 *
	 * @param arr the grid to inspect; rows must not be {@code null}
	 * @return {@code true} if the elements are exactly 1..N (vacuously
	 *         {@code true} when the grid has no elements)
	 */
	public static boolean correctElements(int[][] arr) {
		int total = 0;
		for (int[] row : arr) {
			total += row.length;
		}

		// Flatten row by row so every element lands in the buffer exactly once.
		int[] flatArray = new int[total];
		int index = 0;
		for (int[] row : arr) {
			System.arraycopy(row, 0, flatArray, index, row.length);
			index += row.length;
		}

		Arrays.sort(flatArray);   // ascending order: position i must now hold i+1

		for (int i = 0; i < flatArray.length; i++) {
			if (flatArray[i] != i + 1) {
				return false;
			}
		}
		return true;
	}

	/**
	 * Prints a 2-dimensional array of ints, one row per line, with a single
	 * space after every element. Each row is printed using its own length.
	 *
	 * @param arr the array to print; rows must not be {@code null}
	 */
	public static void print2DArray(int[][] arr) {
		for (int[] row : arr) {
			for (int value : row) {
				System.out.print(value + " ");
			}
			System.out.println();
		}
	}
}