# laura maria engist, 2025
# script to create a colored grid based on the values of an array

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import ListedColormap, BoundaryNorm
import os

class CreateColorfulGridForArray:
    """Render 2-D arrays as colour-coded grids and save them as PNG and SVG.

    Every cell is coloured by the bin of ``bounds`` its value falls into and,
    unless its rounded value is one of ``_UNLABELED_VALUES``, annotated with
    that value rounded to three decimals.
    """

    # Cell values (after rounding to 3 decimals) that get no text annotation.
    # 400 is the placeholder written by ``_mask_values``.
    _UNLABELED_VALUES = (400.0, 200.0, 100.0, 1.0)

    # Placeholder written over masked cells. With the bounds used by the
    # create_plots_* helpers it falls into the last bin [1.0, 600.0), which is
    # coloured '#D3D3D3' (light grey).
    _MASK_VALUE = 400

    def __init__(self):
        pass

    @classmethod
    def _mask_values(cls, data, values_to_mask):
        """Return a list-of-lists copy of 2-D ``data`` with masked entries replaced.

        Args:
            data: 2-D array-like (iterated as rows, then cells).
            values_to_mask: Values that are replaced by ``_MASK_VALUE``.

        Returns:
            A new list of lists; ``data`` itself is not modified.
        """
        return [
            [cls._MASK_VALUE if num in values_to_mask else num for num in row]
            for row in data
        ]

    def create_plots(self, bounds, colors, path_png, path_svg, data, range_upper, start, ticks_range):
        """Draw ``data`` as a colour-coded grid, save it as PNG and SVG, then show it.

        Args:
            bounds: Bin edges for ``BoundaryNorm``. The callers in this class
                pass one more edge than there are colours (one colour per bin).
            colors: One colour per bin, used to build the ``ListedColormap``.
            path_png: Output path of the PNG file.
            path_svg: Output path of the SVG file.
            data: 2-D array-like of cell values; row 0 is drawn at the bottom
                (``origin='lower'``).
            range_upper: Exclusive upper bound of the axis tick labels, which
                run ``start .. range_upper - 1``.
            start: First axis tick label.
            ticks_range: Number of tick positions (``0 .. ticks_range - 1``) on
                each axis. It must equal the number of labels.
        """
        # Start a fresh figure so successive calls never draw on top of the
        # previous grid or stack extra colorbars. This matters when plt.show()
        # does not close the figure, e.g. with a non-interactive backend.
        plt.figure()

        cmap = ListedColormap(colors)
        norm = BoundaryNorm(bounds, len(colors))

        axis_labels = list(range(start, range_upper))
        plt.imshow(data, cmap=cmap, norm=norm, origin='lower')
        colorbar = plt.colorbar(boundaries=bounds)
        # Put one tick at the lower edge of every bin; the final upper edge gets none.
        tick_bounds = bounds[:-1]
        colorbar.set_ticks(tick_bounds)
        colorbar.set_ticklabels([f"{b:.3f}" for b in tick_bounds])
        plt.grid(False)

        plt.xticks(ticks=np.arange(ticks_range), labels=axis_labels)
        plt.yticks(ticks=np.arange(ticks_range), labels=axis_labels)
        # NOTE(review): presumably "ge"/"go" = gap extension / gap open penalty — confirm.
        plt.xlabel("ge")
        plt.ylabel("go")

        # Annotate each cell: column index -> x, row index -> y.
        for row_index, row in enumerate(data):
            for col_index, cell in enumerate(row):
                value = round(cell, 3)
                if value not in self._UNLABELED_VALUES:
                    plt.text(col_index, row_index, str(value),
                             ha='center', va='center', color='white', fontsize=4)

        plt.savefig(path_png, dpi=300, bbox_inches='tight')
        plt.savefig(path_svg, dpi=300, bbox_inches='tight')
        plt.show()

    def create_plots_cost(self, path_png, path_svg, data, range_upper):
        """Plot a cost grid in which both 0.0 and 1.0 cells are greyed out.

        Cells equal to 0.0 or 1.0 are replaced by ``_MASK_VALUE`` (grey, no
        label). The remaining values are binned at 0.2 steps, with
        '#dc267f' for the lowest bin through '#648fff' for [0.8, 1.0).

        Args:
            path_png: Output path of the PNG file.
            path_svg: Output path of the SVG file.
            data: 2-D array-like of cell values.
            range_upper: Exclusive upper bound of the axis labels (labels run
                ``1 .. range_upper - 1``).
        """
        bounds_cost = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0, 600.0]
        colors_cost = ['#dc267f', '#fe6100', '#ffb000', '#785ef0', '#648fff', '#D3D3D3']
        data_cost = self._mask_values(data, (0.0, 1.0))
        ticks_range = range_upper - 1
        start = 1
        self.create_plots(bounds_cost, colors_cost, path_png, path_svg, data_cost, range_upper, start, ticks_range)

    def create_plots_quality(self, path_png, path_svg, data, range_upper):
        """Plot a quality grid in which 0.0 cells are greyed out.

        Cells equal to 0.0 are replaced by ``_MASK_VALUE`` (grey, no label).
        The bins are the same as for cost, but the colour order is reversed
        ('#648fff' for the lowest bin through '#dc267f' for [0.8, 1.0)).

        Args:
            path_png: Output path of the PNG file.
            path_svg: Output path of the SVG file.
            data: 2-D array-like of cell values.
            range_upper: Exclusive upper bound of the axis labels (labels run
                ``1 .. range_upper - 1``).
        """
        bounds_quality = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0, 600.0]
        colors_quality = ['#648fff', '#785ef0', '#ffb000', '#fe6100', '#dc267f', '#D3D3D3']
        data_quality = self._mask_values(data, (0.0,))
        ticks_range = range_upper - 1
        start = 1
        self.create_plots(bounds_quality, colors_quality, path_png, path_svg, data_quality, range_upper, start, ticks_range)

# TODO for one experiment, set all parameters below and run the method(s) for the grid(s) of interest.
diretory_experiment = "" # path to the experiment directory (not used by the plotting calls below)
path_png_qualAln_quality = "" # path to save the png file for qualityAln quality
path_svg_qualAln_quality = "" # path to save the svg file for qualityAln quality
path_png_qualIdent_quality = "" # path to save the png file for qualityIdent quality
path_svg_qualIdent_quality = "" # path to save the svg file for qualityIdent quality
path_png_qualAln_cost = "" # path to save the png file for qualityAln cost
path_svg_qualAln_cost = "" # path to save the svg file for qualityAln cost
path_png_qualIdent_cost = "" # path to save the png file for qualityIdent cost
path_svg_qualIdent_cost = "" # path to save the svg file for qualityIdent cost
# Exclusive upper bound of the axis labels: both axes are labelled 1 .. max_value_matrix - 1.
# NOTE(review): presumably the grid is (max_value_matrix - 1) x (max_value_matrix - 1) — confirm.
max_value_matrix = 0
grid_data_qualAln = ""  # 2-D array with the values for qualityAln - computed with /hypopthd/evaluation/gap_penalties_alignment_quality_or_identification/gap_penalties_qualAln_qualIdent.py
grid_data_qualIdent = "" # 2-D array with the values for qualityIdent - computed with /hypopthd/evaluation/gap_penalties_alignment_quality_or_identification/gap_penalties_qualAln_qualIdent.py

if __name__ == "__main__":
    # Plot only when run as a script, never on import. The empty placeholders
    # above are not valid grid data until they are filled in.
    create_colorful_grid = CreateColorfulGridForArray()
    create_colorful_grid.create_plots_quality(path_png_qualAln_quality, path_svg_qualAln_quality, grid_data_qualAln, max_value_matrix)
    create_colorful_grid.create_plots_cost(path_png_qualAln_cost, path_svg_qualAln_cost, grid_data_qualAln, max_value_matrix)
    create_colorful_grid.create_plots_quality(path_png_qualIdent_quality, path_svg_qualIdent_quality, grid_data_qualIdent, max_value_matrix)
    create_colorful_grid.create_plots_cost(path_png_qualIdent_cost, path_svg_qualIdent_cost, grid_data_qualIdent, max_value_matrix)
