Shortcuts

Source code for mmagic.datasets.transforms.gcp_process

# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn


def apply_gains_jdd(bayer_images, red_gains, blue_gains):
    """Applies white balance gains to a batch of 3-channel images.

    Channel 0 is multiplied by ``red_gains``, channel 1 by 1 (green is the
    reference) and channel 2 by ``blue_gains``. Any channels beyond the
    third are copied through unchanged. The input tensor is not modified.

    Args:
        bayer_images (Tensor): Batch in BxCxHxW layout.
        red_gains (Tensor | float): Red gain(s). A tensor whose first
            dimension equals the batch size is used as per-image gains.
            These are concatenated along the last dim, so presumably shaped
            (B, 1) — confirm with callers. Anything else is treated as a
            single gain shared by the whole batch.
        blue_gains (Tensor | float): Blue gain(s), same convention as
            ``red_gains``. Only ``red_gains`` decides which convention is
            used.

    Returns:
        Tensor: White-balanced images in BxCxHxW layout.
    """
    batch_size = bayer_images.shape[0]
    # Inspect the shape directly instead of cloning first. This also lets
    # Python floats and 0-d tensors reach the shared-gain branch instead of
    # failing on ``.clone()`` / ``.shape[0]``.
    per_image = (
        torch.is_tensor(red_gains) and red_gains.dim() > 0
        and red_gains.shape[0] == batch_size)
    if per_image:
        red_gains = red_gains.clone().detach()
        blue_gains = blue_gains.clone().detach()
    else:
        red_gains = torch.tensor([[red_gains]])
        blue_gains = torch.tensor([[blue_gains]])

    # Permute the image tensor to BxHxWxC format from BxCxHxW format
    bayer_images = bayer_images.permute(0, 2, 3, 1)
    green_gains = torch.ones_like(red_gains)
    gains = torch.cat([red_gains, green_gains, blue_gains], dim=-1)
    gains = gains[:, None, None, :]

    # ``permute`` returns a view of the caller's tensor. Writing into it
    # directly would overwrite the caller's data, and it raises under
    # autograd for leaf tensors that require grad. So operate on a copy.
    # Per-channel assignment (instead of ``bayer_images * gains``) keeps the
    # input dtype and leaves any extra channels untouched.
    outs = bayer_images.clone()
    for channel in range(3):
        outs[..., channel] = outs[..., channel] * gains[..., channel]

    # Re-Permute the tensor back to BxCxHxW format
    return outs.permute(0, 3, 1, 2)
def apply_ccms(images, ccms):
    """Multiplies every pixel of a batch by a colour correction matrix.

    Args:
        images (Tensor): Batch in BxCxHxW layout.
        ccms (Tensor): Either one ``(3, 3)`` matrix shared by the whole
            batch, or a stack of matrices indexed by batch along dim 0
            (presumably ``(B, 3, 3)`` — confirm against callers).

    Returns:
        Tensor: Corrected batch in BxCxHxW layout. Output channel ``i`` of
        each pixel is ``sum_j ccm[i, j] * pixel[j]``.
    """
    # BxCxHxW -> BxHxWx1xC: each pixel becomes a row vector that broadcasts
    # against every row of its matrix.
    pixels = images.permute(0, 2, 3, 1).unsqueeze(-2)
    if ccms.shape == (3, 3):
        # One matrix for the whole batch.
        matrices = ccms[None, None, None, :, :]
    else:
        # One matrix per image; broadcast over H and W.
        matrices = ccms[:, None, None, :, :]
    corrected = (pixels * matrices).sum(dim=-1)
    # BxHxWxC -> BxCxHxW
    return corrected.permute(0, 3, 1, 2)
def gamma_compression(images, gamma=2.2):
    """Converts from linear to gamma space.

    Computes ``max(x, 1e-8) ** (1 / gamma)`` element-wise. The operation is
    purely element-wise, so no layout change is needed. Tensors of any
    shape are accepted, and BxCxHxW inputs give the same values as before.

    Args:
        images (Tensor): Linear-space values.
        gamma (float): Gamma exponent; the output is ``x ** (1 / gamma)``.

    Returns:
        Tensor: Gamma-compressed values, same shape as ``images``.
    """
    # Clamps to prevent numerical instability of gradients near zero.
    return torch.clamp(images, min=1e-8)**(1.0 / gamma)
def process_train(bayer_images, red_gains, blue_gains, cam2rgbs):
    """Runs the training-time ISP: white balance, colour correction, gamma.

    No demosaicing step is performed. The input goes through
    ``apply_gains_jdd``, which builds three gains (R, G, B), so it is
    presumably already a 3-channel BxCxHxW image — confirm with callers.

    Args:
        bayer_images (Tensor): Input batch in BxCxHxW layout.
        red_gains: Red white-balance gain(s), forwarded to apply_gains_jdd.
        blue_gains: Blue white-balance gain(s), forwarded to
            apply_gains_jdd.
        cam2rgbs (Tensor): Colour correction matrix / matrices, forwarded
            to apply_ccms.

    Returns:
        Tensor: Gamma-compressed images in BxCxHxW layout.
    """
    # White balance, then clamp into the valid [0, 1] range.
    balanced = torch.clamp(
        apply_gains_jdd(bayer_images, red_gains, blue_gains),
        min=0.0,
        max=1.0)
    # Camera space -> RGB, clamped again before the power curve.
    corrected = torch.clamp(apply_ccms(balanced, cam2rgbs), min=0.0, max=1.0)
    return gamma_compression(corrected)
def process(bayer_images, red_gains, blue_gains, cam2rgbs):
    """Processes a batch of packed RGGB Bayer images into RGB images.

    Pipeline: white balance (``apply_gains``) -> clamp to [0, 1] ->
    bilinear ``demosaic`` -> colour correction (``apply_ccms``) ->
    clamp to [0, 1] -> ``gamma_compression``.

    Args:
        bayer_images (Tensor): Packed Bayer batch in BxCxHxW layout. It has
            four channels, given the four gains built by ``apply_gains``.
        red_gains (Tensor): Per-image red gains, forwarded to apply_gains.
        blue_gains (Tensor): Per-image blue gains, forwarded to apply_gains.
        cam2rgbs (Tensor): Colour correction matrix / matrices, forwarded
            to apply_ccms.

    Returns:
        Tensor: Gamma-compressed RGB images in BxCxHxW layout.
    """
    balanced = apply_gains(bayer_images, red_gains, blue_gains)
    rgb = demosaic(balanced.clamp(min=0.0, max=1.0))
    corrected = apply_ccms(rgb, cam2rgbs).clamp(min=0.0, max=1.0)
    return gamma_compression(corrected)
def process_test(bayer_images, red_gains, blue_gains, cam2rgbs):
    """Runs the test-time ISP and also returns a tone-mapped copy for display.

    The processing is identical to ``process_train`` (white balance, colour
    correction, gamma compression, no demosaicing). The display version
    additionally applies the ``smoothstep`` curve.

    Args:
        bayer_images (Tensor): Input batch in BxCxHxW layout.
        red_gains: Red white-balance gain(s).
        blue_gains: Blue white-balance gain(s).
        cam2rgbs (Tensor): Colour correction matrix / matrices.

    Returns:
        tuple[Tensor, Tensor]: ``(images, images_show)``. ``images_show`` is
        ``smoothstep(images)``.
    """
    processed = process_train(bayer_images, red_gains, blue_gains, cam2rgbs)
    return processed, smoothstep(processed)
def apply_gains(bayer_images, red_gains, blue_gains):
    """Scales the four planes of a packed RGGB batch by white-balance gains.

    Args:
        bayer_images (Tensor): Packed Bayer batch in BxCxHxW layout. Its four
            channels are scaled by (red, 1, 1, blue) in that order.
        red_gains (Tensor): Per-image red gains. Dim 1 is squeezed away, so
            presumably shaped (B, 1) — confirm with callers.
        blue_gains (Tensor): Per-image blue gains, same layout as
            ``red_gains``.

    Returns:
        Tensor: Gain-adjusted images in BxCxHxW layout.
    """
    red = red_gains.squeeze(1)
    blue = blue_gains.squeeze(1)
    unit = torch.ones_like(red)
    # One gain per packed channel: R, Gr, Gb, B (both greens stay at 1).
    per_channel = torch.stack([red, unit, unit, blue], dim=-1)
    # BxCxHxW -> BxHxWxC so the gain vector broadcasts over the last axis.
    channels_last = bayer_images.permute(0, 2, 3, 1)
    scaled = channels_last * per_channel[:, None, None, :]
    # BxHxWxC -> BxCxHxW
    return scaled.permute(0, 3, 1, 2)
def demosaic(bayer_images):
    """Bilinearly demosaics a batch of RGGB Bayer images.

    The packed planes are channel 0 = red, 1 = green_red, 2 = green_blue and
    3 = blue, following the variable names below. Each plane is upsampled 2x
    with ``nn.Upsample(mode='bilinear', align_corners=False)``. The two green
    planes are then interleaved into a single full-resolution green channel.

    Args:
        bayer_images (Tensor): Packed Bayer batch in BxCxHxW layout. C is
            presumably 4 — confirm. H and W are half the output resolution.

    Returns:
        Tensor: RGB batch in Bx3x(2H)x(2W) layout.
    """

    def SpaceToDepth_fact2(x):
        # Moves each 2x2 spatial block into channels:
        # (N, C, H, W) -> (N, 4*C, H/2, W/2). The block offset (row, col)
        # lands in channel group row * 2 + col.
        bs = 2
        N, C, H, W = x.size()
        # (N, C, H//bs, bs, W//bs, bs)
        x = x.view(N, C, H // bs, bs, W // bs, bs)
        # (N, bs, bs, C, H//bs, W//bs)
        x = x.permute(0, 3, 5, 1, 2, 4).contiguous()
        # (N, C*bs^2, H//bs, W//bs)
        x = x.view(N, C * (bs**2), H // bs, W // bs)
        return x

    def DepthToSpace_fact2(x):
        # Inverse of SpaceToDepth_fact2: (N, C, H, W) -> (N, C/4, 2H, 2W).
        bs = 2
        N, C, H, W = x.size()
        # (N, bs, bs, C//bs^2, H, W)
        x = x.view(N, bs, bs, C // (bs**2), H, W)
        # (N, C//bs^2, H, bs, W, bs)
        x = x.permute(0, 3, 4, 1, 5, 2).contiguous()
        # (N, C//bs^2, H * bs, W * bs)
        x = x.view(N, C // (bs**2), H * bs, W * bs)
        return x

    # Permute the image tensor to BxHxWxC format from BxCxHxW format
    bayer_images = bayer_images.permute(0, 2, 3, 1)
    shape = bayer_images.size()
    # Output spatial size: twice the packed height and width.
    shape = [shape[1] * 2, shape[2] * 2]
    red = bayer_images[Ellipsis, 0:1]
    # Shared 2x bilinear upsampler; it expects NCHW, hence the permutes
    # around each call below.
    upsamplebyX = nn.Upsample(size=shape, mode='bilinear', align_corners=False)
    red = upsamplebyX(red.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
    # NOTE(review): in this BxHxWxC layout dims=[1] is the height axis, so
    # these flips are up-down, not left-right as the inline comments say.
    # A left-right flip would be dims=[2]. Also, with align_corners=False
    # the resize appears mirror-symmetric, so flip -> upsample -> flip
    # looks like a no-op up to rounding — confirm before relying on it.
    green_red = bayer_images[Ellipsis, 1:2]
    green_red = torch.flip(green_red, dims=[1])  # Flip left-right
    green_red = upsamplebyX(green_red.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
    green_red = torch.flip(green_red, dims=[1])  # Flip left-right
    green_red = SpaceToDepth_fact2(green_red.permute(0, 3, 1, 2)).permute(
        0, 2, 3, 1)
    # NOTE(review): dims=[0] is the batch axis here, not up-down (that would
    # be dims=[1]). Upsampling is per-sample, so flipping the batch before
    # and after it cancels out; this pair of flips has no effect.
    green_blue = bayer_images[Ellipsis, 2:3]
    green_blue = torch.flip(green_blue, dims=[0])  # Flip up-down
    green_blue = upsamplebyX(green_blue.permute(0, 3, 1, 2)).permute(
        0, 2, 3, 1)
    green_blue = torch.flip(green_blue, dims=[0])  # Flip up-down
    green_blue = SpaceToDepth_fact2(green_blue.permute(0, 3, 1, 2)).permute(
        0, 2, 3, 1)
    # Rebuild full-resolution green per 2x2 block. At the red and blue sites
    # (groups 0 and 3), average the two green estimates. At the green_red
    # (group 1) and green_blue (group 2) sites, take that plane's own values.
    green_at_red = (green_red[Ellipsis, 0] + green_blue[Ellipsis, 0]) / 2
    green_at_green_red = green_red[Ellipsis, 1]
    green_at_green_blue = green_blue[Ellipsis, 2]
    green_at_blue = (green_red[Ellipsis, 3] + green_blue[Ellipsis, 3]) / 2
    green_planes = [
        green_at_red, green_at_green_red, green_at_green_blue, green_at_blue
    ]
    green = DepthToSpace_fact2(
        torch.stack(green_planes, dim=-1).permute(0, 3, 1, 2)).permute(
            0, 2, 3, 1)
    # NOTE(review): dims=[1] flips H and dims=[0] flips the batch (which
    # cancels), so only an up-down flip is applied here. A left-right flip
    # would be dims=[2].
    blue = bayer_images[Ellipsis, 3:4]
    blue = torch.flip(torch.flip(blue, dims=[1]), dims=[0])
    blue = upsamplebyX(blue.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
    blue = torch.flip(torch.flip(blue, dims=[1]), dims=[0])
    rgb_images = torch.cat([red, green, blue], dim=-1)
    # Re-Permute the tensor back to BxCxHxW format
    rgb_images = rgb_images.permute(0, 3, 1, 2)
    return rgb_images
def smoothstep(image):
    """Applies the smoothstep tone curve ``3 * x**2 - 2 * x**3`` element-wise.

    This is the forward global tone-mapping curve, not its inverse. It maps
    0 -> 0, 0.5 -> 0.5 and 1 -> 1, and has zero slope at both ends of [0, 1].

    Args:
        image (Tensor): Input values, typically in [0, 1].

    Returns:
        Tensor: Tone-mapped values, same shape as ``image``.
    """
    squared = torch.mul(image, image)
    out = 3.0 * squared - 2.0 * torch.mul(squared, image)
    return out