import sys
# Everything that follows a '-petsc' flag on the command line is handed to
# petsc.Initialize(); the flag and those options are stripped from sys.argv
# so the rest of the script does not see them.
if '-petsc' in sys.argv:
    i = sys.argv.index('-petsc')
    # Program name first, then only the options that came after the flag.
    argv = sys.argv[:1] + sys.argv[i + 1:]
    sys.argv[i:] = []
    import petsc
    petsc.Initialize(argv)

# --------------------------------------------------------------------

from petsc4py import PETSc
import unittest
from sys import getrefcount

# --------------------------------------------------------------------

class TestPCBase(object):

    KSP_TYPE = None
    PC_TYPE  = None
    
    def setUp(self):
        pc = PETSc.PC()
        pc.create(PETSc.COMM_SELF)
        pc.setType(self.PC_TYPE)
        self.pc = pc
        
    def tearDown(self):
        self.pc = None

    def testGetSetType(self):
        self.assertEqual(self.pc.getType(), self.PC_TYPE)
        self.pc.setType(self.PC_TYPE)
        self.assertEqual(self.pc.getType(), self.PC_TYPE)
        self.pc.setType('none')
        self.assertEqual(self.pc.getType(), 'none')
        self.pc.setType(self.PC_TYPE)
        self.assertEqual(self.pc.getType(), self.PC_TYPE)
        
    def testApply(self):
        A = PETSc.MatSeqAIJ(3)
        A.assemble()
        A.shift(10)
        x, y = A.getVecs()
        y.set(10)
        x.setRandom()
        self.pc.setOperators(A,A,'same_nz')
        self.pc.apply(y, x)
        return A, y, x
        
        

# --------------------------------------------------------------------

class TestPCNONE(TestPCBase, unittest.TestCase):
    """Run the shared PC checks against the ``NONE`` preconditioner type."""

    PC_TYPE = PETSc.PC.Type.NONE

    def testApply(self):
        """After the shared apply scenario the output must equal the input."""
        _, rhs, out = super(TestPCNONE, self).testApply()
        unchanged = rhs.equal(out)
        self.assertTrue(unchanged)


class PCShCtx_NONE(object):
    """Shell-PC context whose apply step just calls ``x.copy(y)``."""

    def setUp(self, *args):
        """Do nothing; any positional arguments are accepted and ignored."""

    def apply(self, x, y):
        """Delegate to ``x.copy(y)``; nothing is returned."""
        x.copy(y)

class PCShCtx_JACOBI(object):
    """Shell-PC context that divides its input pointwise by a diagonal."""

    def setUp(self, A, B, struct):
        """Store the diagonal of ``B`` in ``self.diag``.

        ``A`` and ``struct`` are accepted but not used.
        """
        diag = self.diag = B.getVecLeft()
        B.getDiagonal(diag)

    def apply(self, x, y):
        """Call ``y.pointwiseDivide(x, self.diag)``; nothing is returned."""
        divisor = self.diag
        y.pointwiseDivide(x, divisor)
    
class TestPCSHELL(TestPCBase, unittest.TestCase):
    """Run the shared PC checks against a shell PC driven by Python contexts.

    Besides the numerical result of applying the preconditioner, these
    tests track ``sys.getrefcount`` of the context object to check that
    the PC takes exactly one reference to it and gives that reference
    back when the context is replaced or the PC is destroyed.
    """

    PC_TYPE = PETSc.PC.Type.SHELL

    def setUp(self):
        """Create the base PC and re-wrap it as a ``PETSc.PCShell``."""
        super(TestPCSHELL, self).setUp()
        self.pc = PETSc.PCShell(self.pc)

    def _setContext(self, ctx):
        """Install *ctx* on the PC and check the reference bookkeeping.

        Setting the context must add exactly one reference to *ctx*, and
        reading it back must return the same object without adding more.
        """
        rcnt = getrefcount(ctx)
        self.pc.setContext(ctx)
        self.assertEqual(getrefcount(ctx), rcnt+1)
        self.assertTrue(self.pc.getContext() is ctx)
        self.assertEqual(getrefcount(ctx), rcnt+1)

    def _testApplyNone(self):
        """With the pass-through context the output equals the input."""
        ctx = PCShCtx_NONE()
        self._setContext(ctx)
        A, y, x = super(TestPCSHELL, self).testApply()
        self.assertTrue(y.equal(x))

    def _testApplyJacobi(self):
        """With the Jacobi context the output is input / diag(operator)."""
        ctx = PCShCtx_JACOBI()
        self._setContext(ctx)
        A, y, x = super(TestPCSHELL, self).testApply()
        diag = A.getVecLeft()
        A.getDiagonal(diag)
        # Recompute the expected result independently of the context and
        # compare it with what the preconditioner actually produced.
        expected = A.getVecLeft()
        expected.pointwiseDivide(y, diag)
        self.assertTrue(x.equal(expected))

    def testApply(self):
        """Apply with each context and check the context is released.

        The ``getrefcount(ctx)-1`` checks are kept inline on purpose:
        routing ``ctx`` through a helper would add references of its own
        and change the numbers being asserted.  The ``-1`` discounts the
        temporary reference held by the ``getrefcount`` call itself; the
        remaining two are the local name and the PC.
        """
        #
        self._testApplyNone()
        ctx = self.pc.getContext()
        self.assertEqual(getrefcount(ctx)-1, 2)
        self.pc.setContext(None)
        self.assertEqual(getrefcount(ctx)-1, 1)
        # Clearing an already-cleared context must be harmless.
        self.pc.setContext(None)
        #
        self._testApplyJacobi()
        ctx = self.pc.getContext()
        self.assertEqual(getrefcount(ctx)-1, 2)
        self.pc.setContext(None)
        self.assertEqual(getrefcount(ctx)-1, 1)
        self.pc.setContext(None)
        #
        # Destroying the PC must release the context as well.
        self._testApplyNone()
        ctx = self.pc.getContext()
        self.assertEqual(getrefcount(ctx)-1, 2)
        self.pc.destroy()
        self.assertEqual(getrefcount(ctx)-1, 1)
        
# --------------------------------------------------------------------

# Run the test cases defined in this module when executed as a script.
if __name__ == '__main__':
    unittest.main()
