Set dtype of rasterio data upon reading

python rasterio

I have the following rasterio Python code, which reads a GeoTIFF with discrete integer classes and masks it immediately on loading. I need fine-grained control over the nodata value, which must be the same regardless of the input data type.

import rasterio
from rasterio.mask import mask

#[...]

    with rasterio.open(Geotiff_path) as src:
        cropped_image, _ = mask(src, 
                                POL, 
                                nodata=128,
                                crop=True,
                                all_touched=True
                               )
    return cropped_image.squeeze()

The problem is that if src.dtypes reports int8, nodata=128 overflows and wraps around to -128. Of course, one fix is simply to use nodata=127. However, I would prefer to get cropped_image back with a different dtype, such as int16, in which the nodata value does not overflow.
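The wraparound can be reproduced with NumPy alone. The following is a minimal sketch, independent of rasterio, and it assumes the nodata value ends up being cast to the dataset's dtype in a similar way:

```python
import numpy as np

# int8 covers -128..127, so 128 is one past the largest representable value.
print(np.iinfo(np.int8).min, np.iinfo(np.int8).max)

# Casting an array that contains 128 down to int8 wraps around to -128.
wrapped = np.array([128], dtype=np.int64).astype(np.int8)
print(wrapped)

# Widening afterwards does not recover the value: -128 stays -128 in int16.
late_cast = wrapped.astype(np.int16)
print(late_cast)
```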

Note that cropped_image.squeeze().astype(np.uint16) does the conversion too late: by then the nodata value has already overflowed to -128. So the question is: how can I read the raster (or its masked version) with a different dtype than the one the original file reports?

I'm using rasterio==1.2.10.

Best Answer

Don't specify nodata=128. With filled=False, mask returns a np.ma.MaskedArray. Cast its data to np.uint16, then set every pixel where the mask is True to 128:

import fiona
import rasterio
from rasterio.mask import mask

with fiona.open("/tmp/mask.shp", "r") as shapefile:
    shapes = [feature["geometry"] for feature in shapefile]

with rasterio.open("/tmp/test.tif") as src:
    out_image, out_transform = mask(src, shapes, crop=True, filled=False, all_touched=True)
    data = out_image.data.astype('uint16')
    data[out_image.mask] = 128
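The cast-then-fill step can be tried without any files. This is a sketch only: the small array below is a made-up stand-in for what mask(..., filled=False) returns, and it uses int16 rather than uint16 on the assumption that an int8 source might contain negative class values:

```python
import numpy as np

# Hypothetical stand-in for the MaskedArray returned by mask(..., filled=False):
# int8 class values, with True in the mask marking pixels outside the shapes.
out_image = np.ma.MaskedArray(
    data=np.array([[[1, 2], [3, 4]]], dtype=np.int8),
    mask=np.array([[[False, True], [False, True]]]),
)

# Widen the data first, then write the nodata value; 128 fits in int16.
data = out_image.data.astype(np.int16)
data[out_image.mask] = 128
print(data)
```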

Alternatively, use a WarpedVRT to alter dtype on the fly:

import fiona
import rasterio
from rasterio.mask import mask
from rasterio.vrt import WarpedVRT

with fiona.open("/tmp/mask.shp", "r") as shapefile:
    shapes = [feature["geometry"] for feature in shapefile]

with rasterio.open("/tmp/test.tif") as src:
    # Open the source in its own context manager so it is closed as well
    with WarpedVRT(src, dtype='uint16') as vrt:
        out_image, out_transform = mask(vrt, shapes, nodata=128, crop=True, all_touched=True)
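If the target dtype should not be hard-coded, one possible way to choose it is to ask NumPy for the smallest dtype that can hold both the source values and the nodata value. This is a sketch, not part of rasterio; dtype_for_nodata is a hypothetical helper name:

```python
import numpy as np

def dtype_for_nodata(src_dtype, nodata):
    """Return the smallest NumPy dtype that can hold src_dtype values and nodata."""
    # min_scalar_type gives the smallest dtype for the nodata value alone;
    # promote_types then widens it so the source values fit as well.
    return np.promote_types(src_dtype, np.min_scalar_type(nodata))

print(dtype_for_nodata("int8", 128))   # int16
print(dtype_for_nodata("uint8", 128))  # uint8
print(dtype_for_nodata("int16", 128))  # int16
```

The resulting dtype could then be used in place of the literal 'uint16' in either approach, for example as the argument to astype.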