MathJax 3

Friday, 30 July 2021

RGB channel decomposition - Python implementation

In the previous post, we saw histogram equalization for grayscale images. Before learning about colour image histograms, we must know how to separate the channels of an image so that we can process them separately. We will learn how to do that in this post.
 
For this, we will use an RGB image with a bit depth of 8 bits per channel. I wrote a simple function that returns a copy of the input image containing only the data of a single channel. The required channel is selected with a key ('R', 'G' or 'B') passed as an argument to the function; if the input image is not in 'RGB' mode, the function returns None:

 

def get_channel(input_image, channel_key):
    """Return a copy of an RGB image that keeps only one colour channel.

    Every output pixel holds the selected channel's value in the same
    position and 0 in the other two channels.

    Args:
        input_image: a PIL image; only mode 'RGB' is accepted (any other
            mode, e.g. 'RGBA' or 'L', makes the function return None).
        channel_key: 'R', 'G' or 'B' - the channel to keep.

    Returns:
        A new 'RGB' image with the same size as input_image, or None if
        input_image is not in 'RGB' mode.

    Raises:
        KeyError: if channel_key is not one of 'R', 'G' or 'B'.
    """
    if input_image.mode != 'RGB':
        return None

    channels_indices = {'R': 0, 'G': 1, 'B': 2}
    # Resolve the key first so an invalid key fails before any allocation.
    ch_idx = channels_indices[channel_key]
    output_image = Image.new('RGB', (input_image.width, input_image.height))
    # NOTE: per-pixel getpixel/putpixel keeps the example easy to follow but
    # is slow on large images; Image.split() + Image.merge() is much faster.
    for x in range(input_image.width):
        for y in range(input_image.height):
            pix = input_image.getpixel((x, y))
            pix = pix[ch_idx]
            # Put the value in the selected channel; the other two are 0.
            if ch_idx == 0:
                output_image.putpixel((x, y), (pix, 0, 0))
            elif ch_idx == 1:
                output_image.putpixel((x, y), (0, pix, 0))
            elif ch_idx == 2:
                output_image.putpixel((x, y), (0, 0, pix))

    return output_image

 

The function above takes an input image and a key as arguments. The key selects which element of each pixel's (R, G, B) tuple is read. Let's go through the most relevant lines:

channels_indices is a dict that maps keys to indices:

channels_indices = {'R':0 ,'G':1 ,'B':2} 

ch_idx is set to 0, 1 or 2 depending on the value of channel_key:

ch_idx = channels_indices[channel_key] 

First we get the pixel's tuple, then we extract the channel's value from it:

pix = input_image.getpixel((x, y))
pix = pix[ch_idx]

We also use ch_idx to select the output channel, writing the value into that position of the output pixel and setting the other two channels to 0:

 if ch_idx == 0:
   output_image.putpixel((x,y), (pix,0,0))
 elif ch_idx == 1:
   output_image.putpixel((x,y), (0,pix,0))
 elif ch_idx == 2:
   output_image.putpixel((x,y), (0,0,pix))

In the following picture we can see the original image with the three channels separated:

 


 

And this is the full code:

 

from PIL import Image, ImageTk
import tkinter as tk


def get_channel(input_image, channel_key):
    """Return a copy of an RGB image that keeps only one colour channel.

    Every output pixel holds the selected channel's value in the same
    position and 0 in the other two channels.

    Args:
        input_image: a PIL image; only mode 'RGB' is accepted (any other
            mode, e.g. 'RGBA' or 'L', makes the function return None).
        channel_key: 'R', 'G' or 'B' - the channel to keep.

    Returns:
        A new 'RGB' image with the same size as input_image, or None if
        input_image is not in 'RGB' mode.

    Raises:
        KeyError: if channel_key is not one of 'R', 'G' or 'B'.
    """
    if input_image.mode != 'RGB':
        return None

    channels_indices = {'R': 0, 'G': 1, 'B': 2}
    # Resolve the key first so an invalid key fails before any allocation.
    ch_idx = channels_indices[channel_key]
    output_image = Image.new('RGB', (input_image.width, input_image.height))
    # NOTE: per-pixel getpixel/putpixel keeps the example easy to follow but
    # is slow on large images; Image.split() + Image.merge() is much faster.
    for x in range(input_image.width):
        for y in range(input_image.height):
            pix = input_image.getpixel((x, y))
            pix = pix[ch_idx]
            # Put the value in the selected channel; the other two are 0.
            if ch_idx == 0:
                output_image.putpixel((x, y), (pix, 0, 0))
            elif ch_idx == 1:
                output_image.putpixel((x, y), (0, pix, 0))
            elif ch_idx == 2:
                output_image.putpixel((x, y), (0, 0, pix))

    return output_image

def main():
    """Show an image and its R, G and B channels in a 2x2 Tk window.

    Layout: original (top-left), R (top-right), G (bottom-left),
    B (bottom-right). If the image is not in 'RGB' mode, a message is
    printed and no window is opened.
    """
    img = Image.open("retriver.png")

    output_R = get_channel(img, 'R')
    output_G = get_channel(img, 'G')
    output_B = get_channel(img, 'B')
    if output_R is None or output_G is None or output_B is None:
        print("Input image must be in 'RGB' colour space")
        return

    pad = 5   # margin between the images and the window border
    gap = 10  # space between neighbouring images
    width = 2 * img.width + 2 * pad + gap
    height = 2 * img.height + 2 * pad + gap

    # The Tk root must exist before any ImageTk.PhotoImage is created.
    root = tk.Tk()
    root.title("RGB channel decomposition")
    root.geometry(f'{width}x{height}')

    # Keep the PhotoImage objects referenced while the window is open;
    # tkinter does not keep them alive, and a collected image shows blank.
    photos = [ImageTk.PhotoImage(im)
              for im in (img, output_R, output_G, output_B)]

    left, right = pad, pad + img.width + gap
    top, bottom = pad, pad + img.height + gap
    positions = [(left, top), (right, top), (left, bottom), (right, bottom)]

    canvas = tk.Canvas(root, width=width, height=height, bg="#ffffff")
    for (x, y), photo in zip(positions, photos):
        # anchor="nw" places each image's top-left corner at (x, y).
        canvas.create_image(x, y, image=photo, anchor="nw")
    canvas.pack()

    root.mainloop()


if __name__ == "__main__":
    main()

 
 

Now we are ready to learn how to plot a histogram for a colour image. Follow me in the next post to see how it is done and how it can help us.

No comments:

Post a Comment