import os
import pydicom
import numpy as np
import nibabel as nib
from collections import defaultdict
import json
import tkinter as tk
from tkinter import filedialog, messagebox

def load_dicom_series(directory):
    """Load every DICOM file in *directory* (non-recursive).

    A file is considered DICOM when its name ends in ``.dcm`` in any letter
    case (``.dcm``, ``.DCM``, ...). Entries are read in sorted filename order
    because ``os.listdir`` returns names in arbitrary order and callers use
    ``dicoms[0]`` as a reference dataset. Directories are skipped even if their
    name ends in ``.dcm``.

    Args:
        directory: path of the folder to scan.

    Returns:
        List of datasets returned by ``pydicom.dcmread``; empty if no file matches.
    """
    dicom_paths = []
    for name in sorted(os.listdir(directory)):
        path = os.path.join(directory, name)
        if name.lower().endswith('.dcm') and os.path.isfile(path):
            dicom_paths.append(path)
    return [pydicom.dcmread(path) for path in dicom_paths]

def group_dicoms_by_echo_time(dicoms):
    """Bucket DICOM datasets by their EchoTime value.

    Returns:
        A ``defaultdict(list)`` keyed by ``float(EchoTime)``. Keys appear in the
        order they are first seen, and each bucket keeps the input order of its
        datasets.
    """
    by_echo = defaultdict(list)
    for ds in dicoms:
        te_key = float(ds.EchoTime)
        bucket = by_echo[te_key]
        bucket.append(ds)
    return by_echo

def sort_slices_by_position(slices):
    """Sort slices by their position in space.

    When every slice carries ``SliceLocation``, slices are ordered by it
    (ascending). ``SliceLocation`` is optional in DICOM, so if any slice lacks
    it, slices are instead ordered by the projection of
    ``ImagePositionPatient`` onto the slice normal. The normal is the cross
    product of the row and column cosines in ``ImageOrientationPatient``.

    Args:
        slices: iterable of DICOM datasets.

    Returns:
        A new sorted list.
    """
    slices = list(slices)  # iterated more than once below
    if all(hasattr(s, 'SliceLocation') for s in slices):
        return sorted(slices, key=lambda x: float(x.SliceLocation))

    def _position_along_normal(ds):
        # Signed distance of the slice origin along the slice normal.
        orientation = np.asarray(ds.ImageOrientationPatient, dtype=float).reshape(2, 3)
        normal = np.cross(orientation[0], orientation[1])
        origin = np.asarray(ds.ImagePositionPatient, dtype=float)
        return float(np.dot(normal, origin))

    return sorted(slices, key=_position_along_normal)

def create_3d_volume(slices):
    """Build a 3D array from 2D DICOM slices.

    Slices are ordered with ``sort_slices_by_position`` and their
    ``pixel_array`` images are stacked along a new last axis, so the result is
    indexed ``[row, column, slice]``.
    """
    ordered = sort_slices_by_position(slices)
    planes = [ds.pixel_array for ds in ordered]
    return np.stack(planes, axis=-1)

def create_4d_volume(groups):
    """Create a 4D volume by stacking one 3D volume per echo time.

    Echoes are stacked in ascending echo-time order. Index ``t`` on the last
    axis therefore lines up with ``extract_metadata(...)["TE"][t]``, which is
    also sorted ascending. Stacking in dict insertion order would follow file
    order instead.

    Args:
        groups: mapping ``echo_time -> list of slices``, as returned by
            ``group_dicoms_by_echo_time``.

    Returns:
        Array indexed ``[row, column, slice, echo]``.

    Raises:
        ValueError: if *groups* is empty.
    """
    if not groups:
        raise ValueError("No echo-time groups to stack: no DICOM slices were provided.")

    volumes = [create_3d_volume(groups[echo_time]) for echo_time in sorted(groups)]

    # Stack 3D volumes along the echo dimension (4th axis)
    return np.stack(volumes, axis=-1)

def get_affine_matrix(dicom):
    """Construct a NIfTI (RAS+) voxel-to-world affine from one slice's DICOM headers.

    Voxel axes follow the arrays built by ``create_3d_volume``:
    axis 0 = ``pixel_array`` row index, axis 1 = column index, axis 2 = slice.

    Args:
        dicom: DICOM dataset. Reads ImagePositionPatient, ImageOrientationPatient,
            PixelSpacing and SliceThickness.

    Returns:
        4x4 affine (numpy array); ``np.eye(4)`` if the position or orientation
        header is missing.
    """
    if not hasattr(dicom, 'ImagePositionPatient') or not hasattr(dicom, 'ImageOrientationPatient'):
        return np.eye(4)  # Default affine if headers are missing

    # ImageOrientationPatient = [row cosines (3 values), column cosines (3 values)],
    # expressed in DICOM patient coordinates (LPS).
    orientation = np.array(dicom.ImageOrientationPatient, dtype=float).reshape(2, 3)
    row_cosine, col_cosine = orientation[0], orientation[1]
    position = np.array(dicom.ImagePositionPatient, dtype=float)

    # PixelSpacing = [spacing between adjacent rows, spacing between adjacent columns].
    row_spacing = float(dicom.PixelSpacing[0])
    col_spacing = float(dicom.PixelSpacing[1])
    # NOTE(review): SliceThickness is not necessarily the centre-to-centre slice step
    # (gaps/overlap are possible); confirm for this data.
    slice_step = float(dicom.SliceThickness)

    # Stepping to the next row (axis 0) moves along the column cosine by the row spacing;
    # stepping to the next column (axis 1) moves along the row cosine by the column spacing.
    axis0 = col_cosine * row_spacing
    axis1 = row_cosine * col_spacing
    # NOTE(review): assumes slices are stacked in increasing order along this normal.
    axis2 = np.cross(row_cosine, col_cosine) * slice_step

    # NOTE(review): kept from the original. This flips axis 0 whenever its LPS x-component
    # is negative, but neither the origin nor the voxel data are flipped. When it triggers,
    # the affine no longer matches the stored array; confirm whether it is still wanted.
    if axis0[0] < 0:
        axis0 = -axis0

    lps_affine = np.eye(4)
    lps_affine[:3, :3] = np.column_stack([axis0, axis1, axis2])
    lps_affine[:3, 3] = position  # position of the first voxel of this slice

    # DICOM patient space is LPS, whereas NIfTI/nibabel affines are RAS: negate x and y.
    return np.diag([-1.0, -1.0, 1.0, 1.0]) @ lps_affine

def extract_metadata(dicoms):
    """Collect acquisition parameters from a DICOM series.

    Echo times are gathered from every dataset, de-duplicated and sorted. All
    other fields are read from ``dicoms[0]``.

    Returns:
        dict with keys ``B0`` (Tesla), ``TE`` (sorted unique echo times),
        ``TR`` (ms), ``voxel_size`` ([PixelSpacing[0], PixelSpacing[1],
        SliceThickness]) and ``dTE`` (differences between consecutive echo
        times; empty for a single echo).
    """
    unique_tes = sorted(set(float(ds.EchoTime) for ds in dicoms))
    echo_spacing = np.diff(unique_tes).tolist()

    ref = dicoms[0]
    return {
        "B0": float(ref.MagneticFieldStrength),
        "TE": unique_tes,
        "TR": float(ref.RepetitionTime),
        "voxel_size": [float(ref.PixelSpacing[i]) for i in (0, 1)] + [float(ref.SliceThickness)],
        "dTE": echo_spacing,
    }

def save_as_nifti(data, affine, metadata, output_path):
    """Save the 4D volume as a NIfTI file and its metadata as a JSON sidecar.

    The NIfTI-1 ``descrip`` header field is only 80 bytes long, so the JSON
    metadata usually cannot be stored there in full. The header copy is kept
    for compatibility. The complete metadata is also written to
    ``<output name without .nii/.nii.gz>.json`` next to the image.

    Args:
        data: image array.
        affine: 4x4 voxel-to-world matrix.
        metadata: JSON-serialisable dict.
        output_path: destination NIfTI path.
    """
    # Create NIfTI image
    nifti_img = nib.Nifti1Image(data, affine)

    # Best-effort copy in the (80-byte) description field; may not hold the whole string.
    nifti_img.header["descrip"] = json.dumps(metadata)

    # Save NIfTI file
    nib.save(nifti_img, output_path)

    # Full metadata goes to a sidecar: "scan.nii.gz" -> "scan.json".
    stem = os.fspath(output_path)
    for extension in (".nii.gz", ".nii"):
        if stem.lower().endswith(extension):
            stem = stem[: -len(extension)]
            break
    with open(stem + ".json", "w", encoding="utf-8") as sidecar:
        json.dump(metadata, sidecar, indent=2)

    print(f"4D NIfTI file saved to {output_path}")

def convert_to_nifti(dicom_dir, output_path):
    """Convert a multi-echo DICOM folder into one 4D NIfTI file.

    Args:
        dicom_dir: folder containing the DICOM files.
        output_path: destination NIfTI path.

    Raises:
        ValueError: if no DICOM files are found in *dicom_dir*.
    """
    # Load DICOM files
    dicoms = load_dicom_series(dicom_dir)
    if not dicoms:
        raise ValueError(f"No DICOM (.dcm) files found in {dicom_dir}")

    # Group DICOMs by Echo Time
    groups = group_dicoms_by_echo_time(dicoms)

    # Create 4D volume
    data_4d = create_4d_volume(groups)

    # The affine translation must be the position of voxel (0, 0, 0), i.e. the first
    # slice after position sorting, not whichever file happened to be loaded first.
    first_echo_slices = groups[min(groups)]
    reference_slice = sort_slices_by_position(first_echo_slices)[0]
    affine = get_affine_matrix(reference_slice)

    # Acquisition metadata (B0, TE, TR, voxel size, dTE)
    metadata = extract_metadata(dicoms)

    # Save as NIfTI (save_as_nifti reports the output path itself)
    save_as_nifti(data_4d, affine, metadata, output_path)

def rescale(nifty_file, rescale_output):
    """Linearly rescale a NIfTI image's intensities onto [-pi, pi] and save it.

    Min-max scaling: the smallest finite value maps to -pi and the largest to
    +pi. NaN voxels are ignored when finding the range and stay NaN. A
    constant image (zero range) maps to 0, the centre of the target range,
    instead of dividing by zero.

    Args:
        nifty_file: path of the NIfTI file to read.
        rescale_output: path of the NIfTI file to write.

    Returns:
        Whatever ``nib.save`` returns (kept for backward compatibility).
    """
    img = nib.load(nifty_file)
    data = img.get_fdata()

    target_min = -np.pi
    target_max = np.pi

    # nan-aware extrema so a few NaN voxels do not turn the whole output into NaN
    original_min = np.nanmin(data)
    original_max = np.nanmax(data)
    value_range = original_max - original_min

    if value_range == 0:
        # Min-max scaling is undefined (0/0); use the centre of the target range.
        midpoint = (target_min + target_max) / 2
        scaled_data = np.where(np.isnan(data), np.nan, midpoint)
    else:
        normalized_data = (data - original_min) / value_range  # [0, 1]
        scaled_data = normalized_data * (target_max - target_min) + target_min  # [-pi, pi]

    new_img = nib.Nifti1Image(scaled_data, img.affine, img.header)
    # The copied header carries the input's on-disk dtype (possibly integer); store the
    # rescaled values as floats instead.
    new_img.set_data_dtype(np.float32)

    result = nib.save(new_img, rescale_output)
    print(f"Image saved to {rescale_output}")
    return result


def select_dicom_folder():
    """Ask for the DICOM folder and put the chosen path into the folder entry.

    If the dialog returns an empty result (cancelled), the entry is left as is.
    """
    chosen_dir = filedialog.askdirectory(title="Select DICOM Folder")
    if not chosen_dir:
        return
    dicom_folder_entry.delete(0, tk.END)
    dicom_folder_entry.insert(0, chosen_dir)

def select_output_path():
    """Ask where to save the converted NIfTI file and store it in the output entry.

    An empty dialog result (cancel) leaves the entry unchanged.
    """
    dialog_options = dict(
        title="Save NIfTI File As",
        defaultextension=".nii.gz",
        filetypes=[("NIfTI files", "*.nii.gz"), ("All files", "*.*")],
    )
    chosen_path = filedialog.asksaveasfilename(**dialog_options)
    if not chosen_path:
        return
    output_path_entry.delete(0, tk.END)
    output_path_entry.insert(0, chosen_path)

def select_nifty_file_to_rescale():
    """Ask for the NIfTI file to rescale and put its path into the input entry.

    The filter offers both compressed (``.nii.gz``) and uncompressed (``.nii``)
    files, plus an "All files" fallback. An empty dialog result (cancel) leaves
    the entry unchanged.
    """
    input_path = filedialog.askopenfilename(
        title="Rescale_nifty",
        defaultextension=".nii.gz",
        filetypes=[
            ("Nifty_files", ("*.nii.gz", "*.nii")),
            ("All files", "*.*"),
        ],
    )
    if input_path:
        input_path_entry.delete(0, tk.END)
        input_path_entry.insert(0, input_path)

def select_rescaled_image_output():
    """Ask where to save the rescaled NIfTI file and store it in the rescale-output entry.

    An empty dialog result (cancel) leaves the entry unchanged.
    """
    chosen_path = filedialog.asksaveasfilename(
        title="Save rescaled file",
        defaultextension=".nii.gz",
        filetypes=[("Nifty files", "*.nii.gz")],
    )
    if not chosen_path:
        return
    output_rescale_path_entry.delete(0, tk.END)
    output_rescale_path_entry.insert(0, chosen_path)

def rescaling():
    """Callback for the "Rescale" button: validate the two paths, then run ``rescale``.

    Missing inputs and any exception raised by ``rescale`` are reported in an
    error dialog instead of propagating into the Tk event loop.
    """
    nifty_file = input_path_entry.get()
    rescale_output = output_rescale_path_entry.get()

    # messagebox.showerror(title, message): the text must go in the second argument.
    if not nifty_file:
        messagebox.showerror("Error", "Please choose a file to rescale")
        return
    if not rescale_output:
        messagebox.showerror("Error", "Please choose an output path for the rescaled file")
        return

    try:
        rescale(nifty_file, rescale_output)
    except Exception as e:
        messagebox.showerror("Error", f"An error occurred: {str(e)}")

def start_conversion():
    """Callback for the "Convert" button.

    Reads the DICOM folder and output path from the GUI entries. If both are
    set, it runs ``convert_to_nifti``; otherwise it shows an error dialog. Any
    exception from the conversion is shown in an error dialog rather than
    propagated.
    """
    dicom_dir, output_path = dicom_folder_entry.get(), output_path_entry.get()

    if dicom_dir and output_path:
        try:
            convert_to_nifti(dicom_dir, output_path)
        except Exception as err:
            messagebox.showerror("Error", f"An error occurred: {str(err)}")
    else:
        messagebox.showerror("Error", "Please select both the DICOM folder and the output path.")


# Create the main GUI window
root = tk.Tk()
root.title("DICOM to NIfTI Converter")


def _add_path_row(row, label_text, browse_command):
    """Add a "label | entry | Browse button" row at grid *row* of ``root``; return the entry."""
    tk.Label(root, text=label_text).grid(row=row, column=0, padx=5, pady=5)
    entry = tk.Entry(root, width=50)
    entry.grid(row=row, column=1, padx=5, pady=5)
    tk.Button(root, text="Browse", command=browse_command).grid(row=row, column=2, padx=5, pady=5)
    return entry


# Conversion inputs on rows 0-1 (the Convert button is placed on row 2 below).
dicom_folder_entry = _add_path_row(0, "DICOM Folder:", select_dicom_folder)
output_path_entry = _add_path_row(1, "Output NIfTI File:", select_output_path)

# Rescaling inputs on rows 3-4 (the Rescale button is placed on row 5 below).
input_path_entry = _add_path_row(3, "Input nifty to rescale:", select_nifty_file_to_rescale)
output_rescale_path_entry = _add_path_row(4, "Rescaled nifty output:", select_rescaled_image_output)

# Action buttons: Convert (row 2) and Rescale (row 5).
tk.Button(root, text="Convert", command=start_conversion).grid(row=2, column=1, pady=10)
tk.Button(root, text="Rescale", command=rescaling).grid(row=5, column=1, pady=10)

# Run the GUI event loop
root.mainloop()