# Licensed under a 3-clause BSD style license - see LICENSE.rst
# -*- coding: utf-8 -*-
"""Test the functions in pydl.pydlutils.rgbcolor.
"""
import pytest
import numpy as np
from .. import PydlutilsUserWarning
from ..rgbcolor import (nw_arcsinh, nw_cut_to_box, nw_float_to_byte,
                        nw_scale_rgb)


def test_nw_arcsinh():
    """Check nw_arcsinh input validation and two known outputs.

    Covers three things: arrays of shape (10, 10) and (10, 10, 5) must
    raise ValueError, ``nonlinearity=0`` must return the input values
    unchanged, and an all-ones image must be scaled by a single constant.
    """
    # Neither a 2-D array nor a 3-D array with 5 planes is accepted.
    for bad_shape in ((10, 10), (10, 10, 5)):
        image = np.random.random(bad_shape)
        with pytest.raises(ValueError):
            nw_arcsinh(image)

    # With nonlinearity=0 the output must equal the input exactly.
    image = np.random.random((10, 10, 3))
    unchanged = nw_arcsinh(image, nonlinearity=0)
    assert (unchanged == image).all()

    # NOTE(review): 9.0 presumably equals the default nonlinearity times
    # the summed value of an all-ones pixel -- confirm against nw_arcsinh.
    unit_image = np.ones((2, 2, 3))
    expected_factor = np.arcsinh(9.0) / 9.0
    assert np.allclose(nw_arcsinh(unit_image), expected_factor)


def test_nw_cut_to_box():
    """Check nw_cut_to_box input validation and the no-op case.

    Arrays of shape (10, 10) and (10, 10, 5) must raise ValueError, as
    must a two-element ``origin``.  Random values drawn from [0, 1) must
    come back numerically unchanged when the defaults are used.
    """
    # Neither a 2-D array nor a 3-D array with 5 planes is accepted.
    for bad_shape in ((10, 10), (10, 10, 5)):
        image = np.random.random(bad_shape)
        with pytest.raises(ValueError):
            nw_cut_to_box(image)

    image = np.random.random((10, 10, 3))

    # A two-element origin is rejected even for a (10, 10, 3) image.
    with pytest.raises(ValueError):
        nw_cut_to_box(image, origin=(1.0, 1.0))

    # With default arguments these values must pass through unmodified.
    assert np.allclose(nw_cut_to_box(image), image)


def test_nw_float_to_byte():
    """Check nw_float_to_byte at both extremes and with ``bits=16``.

    All-zero float32 input must map to 0 and all-one input to 255 with
    the default arguments.  Requesting ``bits=16`` must emit at least one
    PydlutilsUserWarning.
    """
    dark = np.zeros((10, 10, 3), dtype=np.float32)
    assert (nw_float_to_byte(dark) == 0).all()

    bright = np.ones((10, 10, 3), dtype=np.float32)
    assert (nw_float_to_byte(bright) == 255).all()

    # Only the warning is checked here; the converted values are not.
    with pytest.warns(PydlutilsUserWarning) as caught:
        nw_float_to_byte(bright, bits=16)
    assert len(caught) >= 1


def test_nw_scale_rgb():
    """Check nw_scale_rgb input validation and a uniform scaling.

    Arrays of shape (10, 10) and (10, 10, 5) must raise ValueError, as
    must a two-element ``scales``.  Scaling an all-ones image by
    (2.0, 2.0, 2.0) must give 2.0 everywhere.
    """
    # Neither a 2-D array nor a 3-D array with 5 planes is accepted.
    for bad_shape in ((10, 10), (10, 10, 5)):
        image = np.random.random(bad_shape)
        with pytest.raises(ValueError):
            nw_scale_rgb(image)

    # A two-element scales is rejected even for a (10, 10, 3) image.
    image = np.random.random((10, 10, 3))
    with pytest.raises(ValueError):
        nw_scale_rgb(image, scales=(1.0, 1.0))

    unit_image = np.ones((2, 2, 3))
    doubled = nw_scale_rgb(unit_image, scales=(2.0, 2.0, 2.0))
    assert np.allclose(doubled, 2.0)
