# Source code for mmagic.datasets.transforms.gcp_unprocess
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import torch
import torch.distributions as tdist
def random_ccm():
    """Generates random RGB -> Camera color correction matrices.

    Four fixed XYZ -> Camera matrices are blended with random positive
    weights (a random convex combination). The blend is composed with the
    linear sRGB -> XYZ matrix, and each row of the product is normalised to
    sum to one.

    Returns:
        torch.Tensor: A ``(3, 3)`` float32 RGB -> Camera matrix whose rows
        each sum to one.
    """
    # Stack of candidate XYZ -> Camera matrices, shape (4, 3, 3).
    xyz_to_cam = torch.FloatTensor([
        [[1.0234, -0.2969, -0.2266], [-0.5625, 1.6328, -0.0469],
         [-0.0703, 0.2188, 0.6406]],
        [[0.4913, -0.0541, -0.0202], [-0.613, 1.3513, 0.2906],
         [-0.1564, 0.2151, 0.7183]],
        [[0.838, -0.263, -0.0639], [-0.2887, 1.0725, 0.2496],
         [-0.0627, 0.1427, 0.5438]],
        [[0.6596, -0.2079, -0.0562], [-0.4782, 1.3016, 0.1933],
         [-0.097, 0.1581, 0.5181]],
    ])
    # One positive weight per candidate, shaped to broadcast over (3, 3).
    mix = torch.FloatTensor(xyz_to_cam.shape[0], 1, 1).uniform_(1e-8, 1e8)
    blended = (xyz_to_cam * mix).sum(dim=0) / mix.sum(dim=0)

    srgb_to_xyz = torch.FloatTensor([[0.4124564, 0.3575761, 0.1804375],
                                     [0.2126729, 0.7151522, 0.0721750],
                                     [0.0193339, 0.1191920, 0.9503041]])
    rgb_to_cam = blended.mm(srgb_to_xyz)
    # Row-normalise so each output channel's weights sum to one.
    return rgb_to_cam / rgb_to_cam.sum(dim=-1, keepdim=True)
def random_gains(rgb_gain_ratio=1.0,
                 red_gain_range=(1.9, 2.4),
                 blue_gain_range=(1.5, 1.9)):
    """Generates random gains for brightening and white balance.

    Args:
        rgb_gain_ratio (float): Multiplier applied to the sampled global
            brightening gain. Default: 1.0.
        red_gain_range (Sequence[float]): ``(low, high)`` bounds of the
            uniformly sampled red white-balance gain. Default: (1.9, 2.4).
        blue_gain_range (Sequence[float]): ``(low, high)`` bounds of the
            uniformly sampled blue white-balance gain. Default: (1.5, 1.9).

    Returns:
        tuple[torch.Tensor, torch.Tensor, torch.Tensor]: ``(rgb_gain,
        red_gain, blue_gain)``, each a float tensor of shape ``(1,)``.
    """
    # Tuple defaults (instead of lists) avoid shared mutable default args.
    # RGB gain represents brightening: reciprocal of a N(0.8, 0.1) draw,
    # scaled by ``rgb_gain_ratio``.
    n = tdist.Normal(loc=torch.tensor([0.8]), scale=torch.tensor([0.1]))
    rgb_gain = rgb_gain_ratio * (1.0 / n.sample())
    # Red and blue gains represent white balance.
    red_gain = torch.FloatTensor(1).uniform_(red_gain_range[0],
                                             red_gain_range[1])
    blue_gain = torch.FloatTensor(1).uniform_(blue_gain_range[0],
                                              blue_gain_range[1])
    return rgb_gain, red_gain, blue_gain
def inverse_smoothstep(image):
    """Approximately inverts a global tone mapping curve.

    Uses the closed-form inverse of the smoothstep curve
    ``y = 3x^2 - 2x^3``, namely ``x = 0.5 - sin(asin(1 - 2y) / 3)``. The
    operation is purely element-wise, so no channel permutation is needed
    and any tensor shape (e.g. ``(C, H, W)`` or ``(N, C, H, W)``) works.

    Args:
        image (torch.Tensor): Tone-mapped image. Values are clamped to
            ``[0, 1]`` first so the ``asin`` argument stays in ``[-1, 1]``.

    Returns:
        torch.Tensor: Image with the tone curve approximately removed, same
        shape as ``image``.
    """
    image = torch.clamp(image, min=0.0, max=1.0)
    return 0.5 - torch.sin(torch.asin(1.0 - 2.0 * image) / 3.0)
def gamma_expansion(image):
    """Converts from gamma to linear space.

    Applies ``max(x, 1e-8) ** 2.2`` element-wise. Since the operation is
    element-wise, no channel permutation is needed and any tensor shape
    (e.g. ``(C, H, W)`` or ``(N, C, H, W)``) is accepted.

    Args:
        image (torch.Tensor): Gamma-compressed image.

    Returns:
        torch.Tensor: Linearised image, same shape as ``image``.
    """
    # Clamps to prevent numerical instability of gradients near zero.
    return torch.clamp(image, min=1e-8)**2.2
def apply_ccm(image, ccm):
    """Applies a color correction matrix.

    Args:
        image (torch.Tensor): Channel-first image of shape ``(3, H, W)``.
        ccm (torch.Tensor): ``(3, 3)`` matrix; every output pixel is
            ``ccm @ pixel``.

    Returns:
        torch.Tensor: Color-corrected image of shape ``(3, H, W)``.
    """
    channels_last = image.permute(1, 2, 0)
    hwc_shape = channels_last.size()
    pixels = channels_last.reshape(-1, 3)
    # Row-vector form of ``ccm @ pixel`` for all pixels at once.
    corrected = torch.matmul(pixels, ccm.t())
    return corrected.reshape(hwc_shape).permute(2, 0, 1)
def safe_invert_gains(image, rgb_gain, red_gain, blue_gain):
    """Inverts gains while safely handling saturated pixels.

    Args:
        image (torch.Tensor): RGB image of shape ``(3, H, W)`` or a batch
            of shape ``(N, 3, H, W)``.
        rgb_gain (torch.Tensor | float): Global brightening gain.
        red_gain (torch.Tensor | float): Red white-balance gain.
        blue_gain (torch.Tensor | float): Blue white-balance gain.
            Each gain must hold exactly one element: a Python float, a 0-d
            tensor or a tensor of shape ``(1,)``.

    Returns:
        torch.Tensor: Image with the gains inverted, same shape as
        ``image``.
    """

    def _as_scalar(gain):
        # Accept floats / 0-d / (1,)-shaped gains and place them on the
        # image's device so the arithmetic below never mixes devices.
        return torch.as_tensor(gain, device=image.device).reshape(())

    rgb_gain = _as_scalar(rgb_gain)
    red_gain = _as_scalar(red_gain)
    blue_gain = _as_scalar(blue_gain)
    # Per-channel inverse gains in (R, G, B) order; green only has the
    # global brightening gain undone.
    gains = torch.stack(
        (1.0 / red_gain, torch.ones_like(red_gain), 1.0 / blue_gain)) / rgb_gain
    gains = gains.reshape(3, 1, 1)  # broadcast over the (H, W) dimensions
    # Prevents dimming of saturated pixels by smoothly masking gains near
    # white: the mask is 0 for gray <= 0.9 and grows quadratically to 1 at
    # gray == 1.0, pulling the effective gain toward 1 there.
    gray = torch.mean(image, dim=-3, keepdim=True)
    inflection = 0.9
    mask = (torch.clamp(gray - inflection, min=0.0) / (1.0 - inflection))**2.0
    safe_gains = torch.max(mask + (1.0 - mask) * gains, gains)
    return image * safe_gains
def mosaic(image):
    """Extracts RGGB Bayer planes from an RGB image.

    Args:
        image (torch.Tensor): Channel-first RGB image of shape
            ``(3, H, W)``; ``H`` and ``W`` are expected to be even.

    Returns:
        torch.Tensor: Packed Bayer planes of shape ``(4, H // 2, W // 2)``,
        ordered R, G (red rows), G (blue rows), B.
    """
    hwc = image.permute(1, 2, 0)
    height, width = hwc.shape[0], hwc.shape[1]
    planes = (
        hwc[0::2, 0::2, 0],  # R: even rows, even columns
        hwc[0::2, 1::2, 1],  # G sharing rows with R
        hwc[1::2, 0::2, 1],  # G sharing rows with B
        hwc[1::2, 1::2, 2],  # B: odd rows, odd columns
    )
    packed = torch.stack(planes, dim=-1).reshape(height // 2, width // 2, 4)
    # Back to channel-first layout.
    return packed.permute(2, 0, 1)
def random_noise_levels_kpn():
    """Samples log-uniform shot and read noise levels.

    Returns:
        tuple[torch.Tensor, torch.Tensor]: ``(sigma_shot, sigma_read)``,
        each a float32 tensor of shape ``(1,)``. ``sigma_read`` lies in
        ``[10**-3, 10**-1.5]`` and ``sigma_shot`` in ``[10**-4, 10**-2]``.
    """
    # NOTE: this draws from NumPy's global RNG (not torch's); the read
    # level is sampled before the shot level.
    log_read = np.random.uniform(-3.0, -1.5, (1, ))
    log_shot = np.random.uniform(-4.0, -2.0, (1, ))
    sigma_read = torch.from_numpy(np.power(10, log_read)).type(
        torch.FloatTensor)
    sigma_shot = torch.from_numpy(np.power(10, log_shot)).type(
        torch.FloatTensor)
    return sigma_shot, sigma_read
def add_noise(image,
              shot_noise=0.01,
              read_noise=0.0005,
              read_noise_exponent=2):
    """Adds random shot (signal-dependent) and read (independent) noise.

    Noise is zero-mean Gaussian with per-pixel variance
    ``image * shot_noise + read_noise ** read_noise_exponent``.

    Args:
        image (torch.Tensor): Channel-first image of shape ``(C, H, W)``.
        shot_noise (float | torch.Tensor): Shot noise scale. Default: 0.01.
        read_noise (float | torch.Tensor): Read noise level. Default: 0.0005.
        read_noise_exponent (int | float): Power applied to ``read_noise``
            to obtain its variance contribution. Default: 2.

    Returns:
        torch.Tensor: Noisy image of shape ``(C, H, W)``.
    """
    hwc = image.permute(1, 2, 0)
    noise_var = hwc * shot_noise + read_noise**read_noise_exponent
    sampler = tdist.Normal(
        loc=torch.zeros_like(noise_var), scale=torch.sqrt(noise_var))
    noisy = hwc + sampler.sample()
    # Back to channel-first layout.
    return noisy.permute(2, 0, 1)
def unprocess(image):
    """Unprocesses an image from sRGB to realistic raw data.

    Samples random camera metadata and inverts the processing pipeline
    (tone mapping, gamma, color correction, white balance/brightening),
    clips to ``[0, 1]``, and finally packs the result into RGGB Bayer
    planes.

    Args:
        image (torch.Tensor): Channel-first sRGB image of shape
            ``(3, H, W)``.

    Returns:
        tuple[torch.Tensor, dict]: The ``(4, H // 2, W // 2)`` Bayer image
        and a metadata dict with keys ``'cam2rgb'``, ``'rgb2cam'``,
        ``'rgb_gain'``, ``'red_gain'`` and ``'blue_gain'``.
    """
    # Same pipeline (and RNG draws) as ``unprocess_gt``, plus the mosaic.
    linear_raw, metadata = unprocess_gt(image)
    return mosaic(linear_raw), metadata
def unprocess_gt(image):
    """Unprocesses an image from sRGB to linear camera RGB (no mosaic).

    Samples random camera metadata (color correction matrix and gains) and
    then applies ``unprocess_meta_gt`` with it.

    Args:
        image (torch.Tensor): Channel-first sRGB image of shape
            ``(3, H, W)``.

    Returns:
        tuple[torch.Tensor, dict]: The ``(3, H, W)`` unprocessed image and
        a metadata dict with keys ``'cam2rgb'``, ``'rgb2cam'``,
        ``'rgb_gain'``, ``'red_gain'`` and ``'blue_gain'``.
    """
    # Randomly creates image metadata (CCM first, then the gains).
    rgb2cam = random_ccm()
    cam2rgb = torch.inverse(rgb2cam)
    rgb_gain, red_gain, blue_gain = random_gains()
    return unprocess_meta_gt(image, rgb_gain, red_gain, blue_gain, rgb2cam,
                             cam2rgb)
def unprocess_meta_gt(image, rgb_gains, red_gains, blue_gains, rgb2cam,
                      cam2rgb):
    """Unprocesses an sRGB image to linear camera RGB with given metadata.

    All camera metadata is supplied by the caller, so the same parameters
    can be reused across images. No Bayer mosaic is applied.

    Args:
        image (torch.Tensor): Channel-first sRGB image of shape
            ``(3, H, W)`` (presumably in ``[0, 1]``; the tone-curve inverse
            clamps to that range).
        rgb_gains: Global brightening gain passed to ``safe_invert_gains``.
        red_gains: Red white-balance gain passed to ``safe_invert_gains``.
        blue_gains: Blue white-balance gain passed to ``safe_invert_gains``.
        rgb2cam (torch.Tensor): ``(3, 3)`` RGB -> Camera matrix applied to
            the image.
        cam2rgb (torch.Tensor): Only stored in the returned metadata.

    Returns:
        tuple[torch.Tensor, dict]: The ``(3, H, W)`` image clipped to
        ``[0, 1]`` and a metadata dict echoing the inputs.
    """
    # Undo the tone curve, then the gamma compression.
    linear = gamma_expansion(inverse_smoothstep(image))
    # Move from linear RGB into camera space.
    camera = apply_ccm(linear, rgb2cam)
    # Undo white balance / brightening without dimming highlights.
    camera = safe_invert_gains(camera, rgb_gains, red_gains, blue_gains)
    # Clip saturated pixels.
    camera = torch.clamp(camera, min=0.0, max=1.0)
    metadata = {
        'cam2rgb': cam2rgb,
        'rgb2cam': rgb2cam,
        'rgb_gain': rgb_gains,
        'red_gain': red_gains,
        'blue_gain': blue_gains,
    }
    return camera, metadata
def random_noise_levels():
    """Generates random noise levels from a log-log linear distribution.

    The shot noise is log-uniform in ``[1e-4, 0.012]``. The log read noise
    is a linear function of the log shot noise (slope 2.18, intercept 1.20)
    plus Gaussian jitter with standard deviation 0.26.

    Returns:
        tuple[torch.Tensor, torch.Tensor]: ``(shot_noise, read_noise)``,
        each a float tensor of shape ``(1,)``.
    """
    log_shot_noise = torch.FloatTensor(1).uniform_(
        np.log(0.0001), np.log(0.012))
    shot_noise = torch.exp(log_shot_noise)
    jitter = tdist.Normal(
        loc=torch.tensor([0.0]), scale=torch.tensor([0.26])).sample()
    log_read_noise = 2.18 * log_shot_noise + 1.20 + jitter
    return shot_noise, torch.exp(log_read_noise)
def add_noise_test(image, shot_noise=0.01, read_noise=0.0005, count=0):
    """Adds deterministic shot and read noise, seeded by ``count``.

    Noise is zero-mean Gaussian with per-pixel variance
    ``image * shot_noise + read_noise ** 2``. It is drawn from a fresh CPU
    ``torch.Generator`` seeded with ``count``, so repeated calls with the
    same arguments return identical results and leave the global RNG
    untouched.

    Args:
        image (torch.Tensor): Channel-first image of shape ``(C, H, W)``.
        shot_noise (float | torch.Tensor): Shot noise scale. Default: 0.01.
        read_noise (float | torch.Tensor): Read noise level. Default: 0.0005.
        count (int): Seed for the noise generator. Default: 0.

    Returns:
        torch.Tensor: Noisy image of shape ``(C, H, W)``.
    """
    hwc = image.permute(1, 2, 0)
    noise_var = hwc * shot_noise + read_noise**2
    rng = torch.Generator()
    rng.manual_seed(count)
    noise = torch.normal(
        mean=torch.zeros_like(noise_var),
        std=torch.sqrt(noise_var),
        generator=rng)
    # Back to channel-first layout.
    return (hwc + noise).permute(2, 0, 1)