classdef matFilesDatastore < matlab.io.Datastore & ...
        matlab.io.datastore.Shuffleable & ...
        matlab.io.datastore.Partitionable
    % matFilesDatastore  Custom datastore pairing files in a folder with labels.
    %
    %   ds = matFilesDatastore(folder, labels) builds a fileDatastore over
    %   FOLDER (including subfolders) that reads each file with the local
    %   function readData, and associates the k-th file with labels(k).
    %
    %   Each call to read returns a table with up to ReadSize rows and two
    %   cell-array variables, "predictors" and "responses".

    properties
        % Underlying fileDatastore that performs the actual file reads.
        Datastore
        % Labels supplied by the caller; Labels(k) is returned as the
        % response for the k-th file of Datastore.
        Labels
        % Maximum number of observations returned by one call to read.
        ReadSize
    end

    properties(SetAccess = protected)
        % Number of files in Datastore.
        NumObservations
    end

    properties(Access = private)
        % 1-based index of the next observation that read will return.
        CurrentFileIndex
    end

    methods
        function ds = matFilesDatastore(folder, labels)
            % Construct the datastore over all files found under FOLDER.
            % NOTE(review): labels is assumed to have one entry per file,
            % in the same order as fds.Files - this is not checked here.
            fds = fileDatastore(folder, ...
                'ReadFcn',@readData, ...
                'IncludeSubfolders',true);
            ds.Datastore = fds;
            ds.Labels = labels;
            ds.ReadSize = 1;
            ds.NumObservations = numel(fds.Files);
            ds.CurrentFileIndex = 1;
        end

        function tf = hasdata(ds)
            % Return true while at least one observation is still unread.
            % The last batch may be shorter than ReadSize; requiring a full
            % batch here would silently drop the trailing observations.
            tf = ds.CurrentFileIndex <= ds.NumObservations;
        end

        function [data,info] = read(ds)
            % Read the next batch of observations.
            %
            %   data - table with variables "predictors" and "responses",
            %          both N-by-1 cell arrays, where N is the smaller of
            %          ReadSize and the number of unread observations.
            %   info - empty scalar struct.
            %
            % Errors if no unread observations remain or if ReadSize is
            % not a positive integer scalar.
            validateattributes(ds.ReadSize, {'numeric'}, ...
                {'scalar','integer','positive'}, ...
                'matFilesDatastore.read', 'ReadSize');
            if ~hasdata(ds)
                error('matFilesDatastore:noMoreData', ...
                    'No more data to read. Call reset to start over.');
            end

            % Clamp the batch so the final read returns the remainder.
            numRemaining = ds.NumObservations - ds.CurrentFileIndex + 1;
            miniBatchSize = min(ds.ReadSize, numRemaining);

            info = struct;
            predictors = cell(miniBatchSize, 1);
            responses = cell(miniBatchSize, 1);
            for i = 1:miniBatchSize
                predictors{i,1} = read(ds.Datastore);
                responses{i,1} = ds.Labels(ds.CurrentFileIndex);
                ds.CurrentFileIndex = ds.CurrentFileIndex + 1;
            end
            data = table(predictors,responses);
        end

        function reset(ds)
            % Rewind both the underlying datastore and the label cursor.
            reset(ds.Datastore);
            ds.CurrentFileIndex = 1;
        end

        function dsNew = shuffle(ds)
            % Return a copy whose files and labels are permuted together.
            dsNew = copy(ds);
            % copy(ds) leaves the handle in Datastore shared with ds, so
            % copy it as well before reordering its files.
            dsNew.Datastore = copy(ds.Datastore);
            fds = dsNew.Datastore;
            idx = randperm(dsNew.NumObservations);
            fds.Files = fds.Files(idx);
            dsNew.Labels = dsNew.Labels(idx);
            % Start the shuffled copy from its first observation so the
            % label cursor and the file datastore cannot be out of step
            % when ds had already been partly read.
            reset(dsNew);
        end

        function subds = partition(ds, numPartitions, idx)
            % Return the idx-th of numPartitions partitions of ds.
            subds = copy(ds);
            subds.Datastore = partition(ds.Datastore, numPartitions, idx);
            subds.NumObservations = numel(subds.Datastore.Files);
            % Select the labels belonging to this partition.
            % NOTE(review): this assumes partition(ds.Datastore, ...) splits
            % the files the same way pigeonHole splits the indices - confirm.
            indices = pigeonHole(idx, numPartitions, ds.NumObservations);
            subds.Labels = ds.Labels(indices);
            reset(subds);
        end
    end

    methods(Access = protected)
        function n = maxpartitions(ds)
            % One partition per observation at most.
            n = ds.NumObservations;
        end
    end

    methods (Hidden = true)
        function frac = progress(ds)
            % Fraction of observations already read, in the range [0, 1].
            if ds.NumObservations == 0
                % Nothing to read: report complete instead of 0/0 = NaN.
                frac = 1;
            else
                frac = (ds.CurrentFileIndex - 1) / ds.NumObservations;
            end
        end
    end
end
function data = readData(filename)
% readData  Load FILENAME and return the variable named "image" from it.
%
%   Every variable stored in the file is loaded into a struct; only the
%   "image" field is handed back to the caller.
contents = load(filename);
data = contents.image;
end
function observationIndices = pigeonHole(partitionIndex, numPartitions, numObservations)
% pigeonHole  Indices of the observations that fall into one partition.
%
%   Observation k (1-based) is assigned to partition
%   floor((k-1)*numPartitions/numObservations) + 1. The result lists every
%   k assigned to partitionIndex, or a 0-by-1 empty double when there is
%   none.
zeroBased = 0:numObservations - 1;
assignedPartition = 1 + floor(zeroBased * numPartitions / numObservations);
observationIndices = find(assignedPartition == partitionIndex);
if isempty(observationIndices)
    % Normalise the shape of the empty result to 0-by-1.
    observationIndices = double.empty(0, 1);
end
end
% Best Answer