from types import DynamicClassAttribute
import numpy as np
from matplotlib import pyplot as plt
from Vaja_2.skripta import loadImage, displayImage

#10.11 prikazovanje 3D slik
def loadImage3D(iPath, iSize, iType): #kopirana in dodelana koda loadImage
    fid = open(iPath, 'rb') #r = read, rb = read binary, funkcija open nam vrne samo kazalec na file FILE ID
    buffer = fid.read() #naredimo buffer kjer preberemo kaj je v fid, niz znakov slike - uint8

    #buffer_len = len(np.frombuffer(buffer, dtype=type)) #numpy funkcija ki razdeli "string" slike, recimo na 8 bit uint
    #preverjamo ali se dolžina seznama ujema z napovedano velikostjo matrike - size v def funkcije
    
    oImage_shape = (iSize[1], iSize[0], iSize[2])
    
    #sekvenco znakov razseka na "vrstice" slike, order 'F' pomeni da sliko bere v stolpičnem načinu (namesto vrstic)
    oImage = np.ndarray(oImage_shape, dtype=iType, buffer=buffer, order='F')

    fid.close()

    return oImage

def displayImage3D(iImage, iTitle='', iGridX=None, iGridY=None): #prazne navednice da je lahko brez naslova
    fig = plt.figure()
    plt.title(iTitle)

    if iGridX is None or iGridY is None:
        extend = (-0.5, iImage.shape[1]-0.5, iImage.shape[0]-0.5, -0.5)
    else:
        stepX = iGridX[1] - iGridX[0]
        stepY = iGridY[1] - iGridY[0]
        extend = (
            iGridX[0] - stepX/2,
            iGridX[-1] + stepX/2,
            iGridY[-1] + stepY/2,
            iGridY[0] + stepY/2
        )

    plt.imshow(
        iImage,
        cmap="gray", # plt.cm.gray, kliče sivo iz možnih colormapov
        vmin=0,
        vmax=255, #definiramo razpon brav, da ne prilagodi sam (8bit uint)
        extent=extend, #od kje do kje prikažemo sliko, image.shape 1 so stolpci
    )

def getPlanarCrossSection(iImage, iDim, iNormVec, iLoc):
    Y, X, Z = iImage.shape
    dx, dy, dz = iDim

    #stranski prerez (glej predavanja)m np.array == np.array vrne array true/false, array == array pa samo true/false
    #ko je nx = (1,0,0), rezina iz X osi
    if iNormVec == [1, 0, 0]:
        oCS = iImage[:, iLoc, :].T #tu lahko dodamo .T za transponiranje (obračanje matrike)
        oH = np.arange(Y) * dy #vertikalna os * velikost voxlov 
        oV = np.arange(Z) * dz #velikost z osi * korak vzorčenja, isto kot gori
    elif iNormVec == [0,1,0]:
        oCS = iImage[iLoc, :, :].T
        oH = np.arange(X) * dx
        oV = np.arange(Z) * dz
    elif iNormVec == [0,0,1]:
        oCS = iImage[:, :, iLoc]
        oH = np.arange(X) * dx
        oV = np.arange(Y) * dy

    return oCS, oH, oV 

def getPlanarProjection(iImage, iDim, iNormVec, iFunc):
    Y, X, Z = iImage.shape
    dx, dy, dz = iDim

    # stranska projekcija
    if iNormVec == [1, 0, 0]:
        #dolgi način
        #oP = np.zeros((Z, Y))
        #for z in range(Z):
        #    for y in range(Y):
                #oP[z,y] = iFunc(iImage[y, :, z]) #np.max funkcija ima sposobnost iskanja po vrsticah, np.max(a, axis = 0) bo npr iskala max vrednosti po stolpcih, np.max(a, axis = 1) pa po vrsticah
        oP = iFunc(iImage, axis = 1).T
        oH = np.arange(Y) * dy
        oV = np.arange(Z) * dz
    #celna projekcija, ny = 0,1,0
    elif iNormVec == [0,1,0]:
        oP = iFunc(iImage, axis = 1).T
        oH = np.arange(X) * dx
        oV = np.arange(Z) * dz
    #prečna projekcija, nz = 0,0,1
    elif iNormVec == [0,0,1]:
        oP = iFunc(iImage, axis = 2)
        oH = np.arange(X) * dx
        oV = np.arange(Y) * dy
    return oP, oH, oV

if __name__ == "__main__":
    imSize = [512, 58, 907]
    pxDim = [0.597656, 3, 0.597656] #v mm, iz navodil
    I = loadImage3D('Vaja_4\spine-512x058x907-08bit.raw', imSize, np.uint8)
    print(I.shape)
    xc = 290
    sagCS, sagH, sagV = getPlanarCrossSection(I, pxDim, [1,0,0], xc)
    title = f'Stranski pravokotni ravninski prerez pri xc = {xc}'
    displayImage3D(sagCS, title, sagH, sagV) #sag so vektorji koordinat 
    xc = 500
    axCS, axH, axV = getPlanarCrossSection(I, pxDim, [0,0,1], xc)
    title = f'Prečni pravokotni ravninski prerez pri xc = {xc}'
    displayImage3D(axCS, title, axH, axV) #sag so vektorji koordinat
    
    func = np.max
    sagP, sagPH, sagPV = getPlanarProjection(I, pxDim, [1,0,0], func)
    title = f'Stranska pravokotna ravninska projekcija za func = {func.__name__}'
    displayImage3D(sagP, title, sagPH, sagPV)

    sagP2, sagPH2, sagPV2 = getPlanarProjection(I, pxDim, [0,0,1], func)
    title = f'Stranska pravokotna ravninska projekcija za func = {func.__name__}'
    displayImage3D(sagP2, title, sagPH2, sagPV2)