pyosv.post.patch_extractor

  1import numpy as np
  2
  3
  4def get_patches(img : np.ndarray, kernel : tuple = (16, 16), stride : tuple = (16, 16)) -> np.ndarray:
  5    '''
  6        Split an image into patches
  7
  8        Parameters:
  9        -----------
 10            - img : np.ndarray 
 11                a WxHxB image, with width W, height H and B bands
 12            - kernel : tuple
 13                Tuple of two values used to define the size of the patch
 14            - stride :  tuple
 15                Tuple of two values representing the stride to extract patches
 16
 17        Returns:
 18        --------
 19            - patches : np.ndarray
 20                the Nx(kernel[0])x(kernel[1])xB vector containing the N patches extracted from img
 21        Usage:
 22        ------
 23        ```python
 24        import numpy as np  
 25
 26        img         = np.array(  
 27            [[
 28                [0.1, 0.2, 0.3],  
 29                [0.4, 0.5, 0.6],  
 30                [0.7, 0.8, 0.9]
 31                ],
 32            [
 33                [0.1, 0.2, 0.3],  
 34                [0.4, 0.5, 0.6],  
 35                [0.7, 0.8, 0.9]
 36                ],
 37            [
 38                [0.1, 0.2, 0.3],  
 39                [0.4, 0.5, 0.6],  
 40                [0.7, 0.8, 0.9]
 41                ]
 42            ]  
 43        ) 
 44
 45        # Making channels last
 46        img = np.moveaxis(img, 0, -1)
 47
 48        patches = get_patches(img, kernel=(2, 2), stride=(1,1))
 49        ```
 50
 51        Output:
 52        -------
 53        ```python
 54        [[[[0.1 0.1 0.1]
 55        [0.2 0.2 0.2]]
 56
 57        [[0.4 0.4 0.4]
 58        [0.5 0.5 0.5]]]
 59
 60
 61        [[[0.2 0.2 0.2]
 62        [0.3 0.3 0.3]]
 63
 64        [[0.5 0.5 0.5]
 65        [0.6 0.6 0.6]]]
 66
 67
 68        [[[0.3 0.3 0.3]
 69        [0.3 0.3 0.3]]
 70
 71        [[0.6 0.6 0.6]
 72        [0.6 0.6 0.6]]]
 73
 74
 75        [[[0.4 0.4 0.4]
 76        [0.5 0.5 0.5]]
 77
 78        [[0.7 0.7 0.7]
 79        [0.8 0.8 0.8]]]
 80
 81
 82        [[[0.5 0.5 0.5]
 83        [0.6 0.6 0.6]]
 84
 85        [[0.8 0.8 0.8]
 86        [0.9 0.9 0.9]]]
 87
 88
 89        [[[0.6 0.6 0.6]
 90        [0.6 0.6 0.6]]
 91
 92        [[0.9 0.9 0.9]
 93        [0.9 0.9 0.9]]]
 94
 95
 96        [[[0.7 0.7 0.7]
 97        [0.8 0.8 0.8]]
 98
 99        [[0.7 0.7 0.7]
100        [0.8 0.8 0.8]]]
101
102
103        [[[0.8 0.8 0.8]
104        [0.9 0.9 0.9]]
105
106        [[0.8 0.8 0.8]
107        [0.9 0.9 0.9]]]
108
109
110        [[[0.9 0.9 0.9]
111        [0.9 0.9 0.9]]
112
113        [[0.9 0.9 0.9]
114        [0.9 0.9 0.9]]]]
115        ```
116    '''
117
118    if len(img.shape) != 3:
119        raise Exception("Error: len of img.shape must be 3")
120    if kernel[0] < 1 or kernel[1] < 1:
121        raise Exception("Error: kernel must be grather than 1")
122    if stride[0] < 1 or stride[1] < 1:
123        raise Exception("Error: kernel must be grather than 1")
124    
125    w, h, c = img.shape
126
127    if kernel[0] > w or kernel[1] > h:
128        raise Exception("Error: kernel size cannot be grather than image size")
129    if stride[0] > w or stride[1] > h:
130        raise Exception("Error: stride cannot be grather than image size")
131    
132    
133    w, h, c = img.shape
134    patches = []
135    for ww in range(0, w - kernel[0] + 1, stride[0]):
136        for hh in range(0,  h - kernel[1] + 1, stride[1]):         
137            patches.append(img[ww:ww+kernel[0],hh:hh+kernel[1],...])
138
139    return np.array(patches)
140
141
142
def reverse_get_patches(patches : np.ndarray, img_shape:tuple, kernel : tuple = (16, 16), stride : tuple = (16, 16)) -> np.ndarray:
    '''
        Reconstruct an image from its patches

        Patches are written back at the window positions obtained by moving a
        `kernel`-sized window over an image of shape `img_shape` in steps of
        `stride` (second axis varying fastest). Where patches overlap, the
        result is the average of the overlapping values. Pixels not covered by
        any window are left at 0.

        Parameters:
        -----------
            - patches : np.ndarray
                an Nx(kernel[0])x(kernel[1])xB array holding N patches with B bands
            - img_shape: tuple
                a tuple defining the shape of the reconstructed image (W,H,B) with width W, height H and B bands
            - kernel : tuple
                Tuple of two values used to define the size of the patch
            - stride : tuple
                Tuple of two values representing the stride used to extract patches

        Returns:
        --------
            - img : np.ndarray
                the WxHxB reconstructed image (floating point)

        Raises:
        -------
            - ValueError
                if patches is not 4-dimensional or a kernel or stride value is
                smaller than 1
            - IndexError
                if patches holds fewer entries than there are window positions

        Notes:
        ------
            Only the first patches, one per window position, are read; any
            additional entries in `patches` are ignored.

        Usage:
        ------
        ```python
        import numpy as np

        img = np.arange(9, dtype=float).reshape(3, 3, 1)

        # 4 patches of shape (2, 2, 1)
        patches = get_patches(img, kernel=(2, 2), stride=(1,1))

        rec = reverse_get_patches(patches, img_shape=(3,3,1), kernel=(2, 2), stride=(1,1))
        ```

        Output:
        -------
        ```python
        np.allclose(rec, img)   # True
        ```
    '''

    if len(patches.shape) != 4:
        raise ValueError("Error: len of patches.shape must be 4")
    if kernel[0] < 1 or kernel[1] < 1:
        raise ValueError("Error: kernel values must be at least 1")
    if stride[0] < 1 or stride[1] < 1:
        raise ValueError("Error: stride values must be at least 1")

    w, h, c = img_shape

    reconstructed_image = np.zeros((w, h, c))
    # How many patches contributed to each pixel; the trailing axis of size 1
    # broadcasts over the bands in the final division.
    count_map = np.zeros((w, h, 1))

    patch_index = 0
    for ww in range(0, w - kernel[0] + 1, stride[0]):
        for hh in range(0, h - kernel[1] + 1, stride[1]):
            reconstructed_image[ww:ww + kernel[0], hh:hh + kernel[1], ...] += patches[patch_index, ...]
            count_map[ww:ww + kernel[0], hh:hh + kernel[1]] += 1
            patch_index += 1

    # Average overlapping contributions; clamp the count to 1 so that pixels
    # never covered by a window stay 0 instead of dividing by zero.
    reconstructed_image /= np.maximum(count_map, 1)

    return reconstructed_image
def get_patches( img: numpy.ndarray, kernel: tuple = (16, 16), stride: tuple = (16, 16)) -> numpy.ndarray:
  5def get_patches(img : np.ndarray, kernel : tuple = (16, 16), stride : tuple = (16, 16)) -> np.ndarray:
  6    '''
  7        Split an image into patches
  8
  9        Parameters:
 10        -----------
 11            - img : np.ndarray 
 12                a WxHxB image, with width W, height H and B bands
 13            - kernel : tuple
 14                Tuple of two values used to define the size of the patch
 15            - stride :  tuple
 16                Tuple of two values representing the stride to extract patches
 17
 18        Returns:
 19        --------
 20            - patches : np.ndarray
 21                the Nx(kernel[0])x(kernel[1])xB vector containing the N patches extracted from img
 22        Usage:
 23        ------
 24        ```python
 25        import numpy as np  
 26
 27        img         = np.array(  
 28            [[
 29                [0.1, 0.2, 0.3],  
 30                [0.4, 0.5, 0.6],  
 31                [0.7, 0.8, 0.9]
 32                ],
 33            [
 34                [0.1, 0.2, 0.3],  
 35                [0.4, 0.5, 0.6],  
 36                [0.7, 0.8, 0.9]
 37                ],
 38            [
 39                [0.1, 0.2, 0.3],  
 40                [0.4, 0.5, 0.6],  
 41                [0.7, 0.8, 0.9]
 42                ]
 43            ]  
 44        ) 
 45
 46        # Making channels last
 47        img = np.moveaxis(img, 0, -1)
 48
 49        patches = get_patches(img, kernel=(2, 2), stride=(1,1))
 50        ```
 51
 52        Output:
 53        -------
 54        ```python
 55        [[[[0.1 0.1 0.1]
 56        [0.2 0.2 0.2]]
 57
 58        [[0.4 0.4 0.4]
 59        [0.5 0.5 0.5]]]
 60
 61
 62        [[[0.2 0.2 0.2]
 63        [0.3 0.3 0.3]]
 64
 65        [[0.5 0.5 0.5]
 66        [0.6 0.6 0.6]]]
 67
 68
 69        [[[0.3 0.3 0.3]
 70        [0.3 0.3 0.3]]
 71
 72        [[0.6 0.6 0.6]
 73        [0.6 0.6 0.6]]]
 74
 75
 76        [[[0.4 0.4 0.4]
 77        [0.5 0.5 0.5]]
 78
 79        [[0.7 0.7 0.7]
 80        [0.8 0.8 0.8]]]
 81
 82
 83        [[[0.5 0.5 0.5]
 84        [0.6 0.6 0.6]]
 85
 86        [[0.8 0.8 0.8]
 87        [0.9 0.9 0.9]]]
 88
 89
 90        [[[0.6 0.6 0.6]
 91        [0.6 0.6 0.6]]
 92
 93        [[0.9 0.9 0.9]
 94        [0.9 0.9 0.9]]]
 95
 96
 97        [[[0.7 0.7 0.7]
 98        [0.8 0.8 0.8]]
 99
100        [[0.7 0.7 0.7]
101        [0.8 0.8 0.8]]]
102
103
104        [[[0.8 0.8 0.8]
105        [0.9 0.9 0.9]]
106
107        [[0.8 0.8 0.8]
108        [0.9 0.9 0.9]]]
109
110
111        [[[0.9 0.9 0.9]
112        [0.9 0.9 0.9]]
113
114        [[0.9 0.9 0.9]
115        [0.9 0.9 0.9]]]]
116        ```
117    '''
118
119    if len(img.shape) != 3:
120        raise Exception("Error: len of img.shape must be 3")
121    if kernel[0] < 1 or kernel[1] < 1:
122        raise Exception("Error: kernel must be grather than 1")
123    if stride[0] < 1 or stride[1] < 1:
124        raise Exception("Error: kernel must be grather than 1")
125    
126    w, h, c = img.shape
127
128    if kernel[0] > w or kernel[1] > h:
129        raise Exception("Error: kernel size cannot be grather than image size")
130    if stride[0] > w or stride[1] > h:
131        raise Exception("Error: stride cannot be grather than image size")
132    
133    
134    w, h, c = img.shape
135    patches = []
136    for ww in range(0, w - kernel[0] + 1, stride[0]):
137        for hh in range(0,  h - kernel[1] + 1, stride[1]):         
138            patches.append(img[ww:ww+kernel[0],hh:hh+kernel[1],...])
139
140    return np.array(patches)

Split an image into patches

Parameters:

- img : np.ndarray 
    a WxHxB image, with width W, height H and B bands
- kernel : tuple
    Tuple of two values used to define the size of the patch
- stride :  tuple
    Tuple of two values representing the stride to extract patches

Returns:

- patches : np.ndarray
    the Nx(kernel[0])x(kernel[1])xB array containing the N patches extracted from img

Usage:

import numpy as np  

img         = np.array(  
    [[
        [0.1, 0.2, 0.3],  
        [0.4, 0.5, 0.6],  
        [0.7, 0.8, 0.9]
        ],
    [
        [0.1, 0.2, 0.3],  
        [0.4, 0.5, 0.6],  
        [0.7, 0.8, 0.9]
        ],
    [
        [0.1, 0.2, 0.3],  
        [0.4, 0.5, 0.6],  
        [0.7, 0.8, 0.9]
        ]
    ]  
) 

# Making channels last
img = np.moveaxis(img, 0, -1)

patches = get_patches(img, kernel=(2, 2), stride=(1,1))

Output:

[[[[0.1 0.1 0.1]
[0.2 0.2 0.2]]

[[0.4 0.4 0.4]
[0.5 0.5 0.5]]]


[[[0.2 0.2 0.2]
[0.3 0.3 0.3]]

[[0.5 0.5 0.5]
[0.6 0.6 0.6]]]


[[[0.3 0.3 0.3]
[0.3 0.3 0.3]]

[[0.6 0.6 0.6]
[0.6 0.6 0.6]]]


[[[0.4 0.4 0.4]
[0.5 0.5 0.5]]

[[0.7 0.7 0.7]
[0.8 0.8 0.8]]]


[[[0.5 0.5 0.5]
[0.6 0.6 0.6]]

[[0.8 0.8 0.8]
[0.9 0.9 0.9]]]


[[[0.6 0.6 0.6]
[0.6 0.6 0.6]]

[[0.9 0.9 0.9]
[0.9 0.9 0.9]]]


[[[0.7 0.7 0.7]
[0.8 0.8 0.8]]

[[0.7 0.7 0.7]
[0.8 0.8 0.8]]]


[[[0.8 0.8 0.8]
[0.9 0.9 0.9]]

[[0.8 0.8 0.8]
[0.9 0.9 0.9]]]


[[[0.9 0.9 0.9]
[0.9 0.9 0.9]]

[[0.9 0.9 0.9]
[0.9 0.9 0.9]]]]
def reverse_get_patches( patches: numpy.ndarray, img_shape: tuple, kernel: tuple = (16, 16), stride: tuple = (16, 16)) -> numpy.ndarray:
def reverse_get_patches(patches : np.ndarray, img_shape:tuple, kernel : tuple = (16, 16), stride : tuple = (16, 16)) -> np.ndarray:
    '''
        Reconstruct an image from its patches

        Patches are written back at the window positions obtained by moving a
        `kernel`-sized window over an image of shape `img_shape` in steps of
        `stride` (second axis varying fastest). Where patches overlap, the
        result is the average of the overlapping values. Pixels not covered by
        any window are left at 0.

        Parameters:
        -----------
            - patches : np.ndarray
                an Nx(kernel[0])x(kernel[1])xB array holding N patches with B bands
            - img_shape: tuple
                a tuple defining the shape of the reconstructed image (W,H,B) with width W, height H and B bands
            - kernel : tuple
                Tuple of two values used to define the size of the patch
            - stride : tuple
                Tuple of two values representing the stride used to extract patches

        Returns:
        --------
            - img : np.ndarray
                the WxHxB reconstructed image (floating point)

        Raises:
        -------
            - ValueError
                if patches is not 4-dimensional or a kernel or stride value is
                smaller than 1
            - IndexError
                if patches holds fewer entries than there are window positions

        Notes:
        ------
            Only the first patches, one per window position, are read; any
            additional entries in `patches` are ignored.

        Usage:
        ------
        ```python
        import numpy as np

        img = np.arange(9, dtype=float).reshape(3, 3, 1)

        # 4 patches of shape (2, 2, 1)
        patches = get_patches(img, kernel=(2, 2), stride=(1,1))

        rec = reverse_get_patches(patches, img_shape=(3,3,1), kernel=(2, 2), stride=(1,1))
        ```

        Output:
        -------
        ```python
        np.allclose(rec, img)   # True
        ```
    '''

    if len(patches.shape) != 4:
        raise ValueError("Error: len of patches.shape must be 4")
    if kernel[0] < 1 or kernel[1] < 1:
        raise ValueError("Error: kernel values must be at least 1")
    if stride[0] < 1 or stride[1] < 1:
        raise ValueError("Error: stride values must be at least 1")

    w, h, c = img_shape

    reconstructed_image = np.zeros((w, h, c))
    # How many patches contributed to each pixel; the trailing axis of size 1
    # broadcasts over the bands in the final division.
    count_map = np.zeros((w, h, 1))

    patch_index = 0
    for ww in range(0, w - kernel[0] + 1, stride[0]):
        for hh in range(0, h - kernel[1] + 1, stride[1]):
            reconstructed_image[ww:ww + kernel[0], hh:hh + kernel[1], ...] += patches[patch_index, ...]
            count_map[ww:ww + kernel[0], hh:hh + kernel[1]] += 1
            patch_index += 1

    # Average overlapping contributions; clamp the count to 1 so that pixels
    # never covered by a window stay 0 instead of dividing by zero.
    reconstructed_image /= np.maximum(count_map, 1)

    return reconstructed_image

Reconstruct an image from its patches

Parameters:

- patches : np.ndarray 
    an NxWxHxB array of N patches, each with width W, height H and B bands
- img_shape: tuple
    a tuple defining the shape of the reconstructed image (W,H,B) with width W, height H and B bands
- kernel : tuple
    Tuple of two values used to define the size of the patch
- stride :  tuple
    Tuple of two values representing the stride to extract patches

Returns:

- img : np.ndarray
    the WxHxB reconstructed image

Usage:

import numpy as np  

patches = np.array(
    [[[[0.1 0.1 0.1]
    [0.2 0.2 0.2]]

    [[0.4 0.4 0.4]
    [0.5 0.5 0.5]]]


    [[[0.2 0.2 0.2]
    [0.3 0.3 0.3]]

    [[0.5 0.5 0.5]
    [0.6 0.6 0.6]]]


    [[[0.3 0.3 0.3]
    [0.3 0.3 0.3]]

    [[0.6 0.6 0.6]
    [0.6 0.6 0.6]]]


    [[[0.4 0.4 0.4]
    [0.5 0.5 0.5]]

    [[0.7 0.7 0.7]
    [0.8 0.8 0.8]]]


    [[[0.5 0.5 0.5]
    [0.6 0.6 0.6]]

    [[0.8 0.8 0.8]
    [0.9 0.9 0.9]]]


    [[[0.6 0.6 0.6]
    [0.6 0.6 0.6]]

    [[0.9 0.9 0.9]
    [0.9 0.9 0.9]]]


    [[[0.7 0.7 0.7]
    [0.8 0.8 0.8]]

    [[0.7 0.7 0.7]
    [0.8 0.8 0.8]]]


    [[[0.8 0.8 0.8]
    [0.9 0.9 0.9]]

    [[0.8 0.8 0.8]
    [0.9 0.9 0.9]]]


    [[[0.9 0.9 0.9]
    [0.9 0.9 0.9]]

    [[0.9 0.9 0.9]
    [0.9 0.9 0.9]]]]
)



img = reverse_get_patches(patches, img_shape=(3,3,3), kernel=(2, 2), stride=(1,1))

Output:

[[
        [0.1, 0.2, 0.3],  
        [0.4, 0.5, 0.6],  
        [0.7, 0.8, 0.9]
        ],
    [
        [0.1, 0.2, 0.3],  
        [0.4, 0.5, 0.6],  
        [0.7, 0.8, 0.9]
        ],
    [
        [0.1, 0.2, 0.3],  
        [0.4, 0.5, 0.6],  
        [0.7, 0.8, 0.9]
        ]
]