import dolfinx
import numpy as np
from dolfinx import fem
from mpi4py import MPI
from petsc4py import PETSc


class Shell:
    """Context object for a PETSc ``python``-type matrix that acts as the identity.

    An instance is handed to ``PETSc.Mat().createPython(..., context=...)``;
    the matrix then delegates its action to :meth:`mult`.
    """

    def mult(self, mat, X, Y):
        """Apply the identity operator: ``Y = X``.

        Parameters
        ----------
        mat : PETSc.Mat
            The matrix that owns this context (unused; the identity needs no
            matrix data).
        X : PETSc.Vec
            Input vector.
        Y : PETSc.Vec
            Output vector, overwritten with a copy of ``X``.
        """
        # A single VecCopy replaces the previous zero-then-axpy sequence,
        # which touched Y twice to produce the same values.
        X.copy(Y)


# Problem setup: a first-order Lagrange space on an N x N unit-square mesh.
N = 32
mesh = dolfinx.mesh.create_unit_square(MPI.COMM_WORLD, N, N)
V = fem.FunctionSpace(mesh, ("CG", 1))

dim_global = V.dofmap.index_map.size_global
if mesh.comm.rank == 0:
    print("DOFs: ", dim_global)

u = fem.Function(V)

# Right-hand side b: every owned entry stores its own global index, which
# makes the parallel layout easy to inspect. Each rank generates only its
# owned slice [local_range[0], local_range[1]) rather than building the
# whole global range and slicing it.
b = u.vector
imap = V.dofmap.index_map
b.setArray(np.arange(*imap.local_range, dtype=PETSc.ScalarType))
b.assemble()
# Push the owned values out to the ghost copies held by other ranks.
b.ghostUpdate(addv=PETSc.InsertMode.INSERT, mode=PETSc.ScatterMode.FORWARD)
print(f"rank {mesh.comm.rank}: {imap.local_range}")

x = b.duplicate()

shell = Shell()

# Matrix-free operator whose action is delegated to ``shell.mult``.
# Sizes are ((local rows, global rows), (local cols, global cols)) and
# follow the index map's ownership, so A is compatible with b and x.
A = PETSc.Mat().createPython(
    ((imap.size_local, dim_global), (imap.size_local, dim_global)),
    context=shell,
    comm=mesh.comm,
)

# Equivalent step-by-step construction:
#   A = PETSc.Mat().create(mesh.comm)
#   A.setType(PETSc.Mat.Type.PYTHON)
#   A.setSizes(((imap.size_local, imap.size_global),
#               (imap.size_local, imap.size_global)))
#   A.setPythonContext(shell)

A.setUp()

print(f"rank {mesh.comm.rank}: A: {A.getLocalSize()}")

# The solver must live on the same communicator as its operator.
ksp = PETSc.KSP().create(mesh.comm)
ksp.setOperators(A)
ksp.setType("cg")
pc = ksp.getPC()
pc.setType("none")
# ksp.setFromOptions()  # enable to allow command-line solver overrides

# Solve I x = b; the exact solution is x == b.
ksp.solve(b, x)
reason = ksp.getConvergedReason()
if reason <= 0:
    # A non-positive reason means PETSc did not report convergence.
    raise RuntimeError(f"KSP did not converge (reason {reason})")

# Vec.norm() is collective, so every rank must call it.
x_norm = x.norm()
b_norm = b.norm()

if mesh.comm.rank == 0:
    print("global:")
    print("x", x_norm)
    print("b", b_norm)
    print("\nlocal:")

# Per-rank norms of the owned entries only (.array excludes ghosts).
print(f"rank {mesh.comm.rank}: |b| = {np.linalg.norm(b.array)}")
print(f"rank {mesh.comm.rank}: |x| = {np.linalg.norm(x.array)}")
assert np.allclose(x.array, b.array)