Coding Phase 1.4

Hi!

Welcome back! Picking up where I left off last time, the implementation of U-Net using deeplearning4j is coming along fairly well. Last week I was able to train on 100 images over 50 epochs using https://github.com/montardon/unetdl4j/blob/master/TrainUnetModel.java as a basis and infer an output as well. 
Following are the tasks I completed this week:

1. Forking the deeplearning4j codebase (https://github.com/eclipse/deeplearning4j) and pushing the U-Net code. 
The link to my forked repository is https://github.com/Medha-B/deeplearning4j and the template U-Net code is available at https://github.com/Medha-B/deeplearning4j/tree/master/Unet_GSoc. The initial commit, however, can be found at https://github.com/Medha-B/unetdl4j/tree/master/src in the repository https://github.com/Medha-B/unetdl4j (forked from https://github.com/funasoul/unetdl4j).
 
2. Writing code for saving the trained model and then loading it.  
For training the UNet model, we use

log.info("*****TRAIN MODEL******");
ComputationGraph model = UNet.builder().updater(new Adam(1e-4)).build().init();
// ComputationGraph model = UNet.builder().updater(new Adam(new MapSchedule(ScheduleType.ITERATION, learningScheduleMap))).build().init();
model.addListeners(new ScoreIterationListener());
model.fit(imageDataSetIterator, numEpochs);

where
DataSetIterator imageDataSetIterator = new RecordReaderDataSetIterator(imageRecordReader, batchSize, labelIndex, labelIndex, true);
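For context, the imageRecordReader fed into this iterator has to be set up beforehand. A minimal sketch of one way this could look, assuming DL4J's ImageRecordReader, a 512x512 RGB input size, and a directory path that are all illustrative rather than the project's actual values (labelMaker stands for a segmentation label generator that maps each image to its mask, which this sketch does not define):

```java
// Sketch only: constructing an ImageRecordReader for image-to-mask training.
// Height, width, channels, and the data directory are assumptions here.
int height = 512, width = 512, channels = 3;

// labelMaker would be a PathLabelGenerator-style component that pairs each
// input image with its ground-truth mask; its implementation is omitted.
ImageRecordReader imageRecordReader =
        new ImageRecordReader(height, width, channels, labelMaker);
imageRecordReader.initialize(new FileSplit(new File("data/train/images")));
```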

Similarly, for saving the trained model, we can use

log.info("*****SAVE MODEL******");
//Location for saving the model
File locationToSave = new File("C:\\Users\\--\\save.zip");
boolean saveUpdater = false;
//ModelSerializer needs the model, the location to save to, and the saveUpdater flag. 
ModelSerializer.writeModel(model, locationToSave, saveUpdater);
 
Likewise, for loading the trained model and using it, we can use

log.info("*****LOAD MODEL******");
//Location where the model is saved
File locationToLoad = new File("C:\\Users\\--\\save.zip");
ComputationGraph model = ModelSerializer.restoreComputationGraph(locationToLoad);
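Once restored, the model can be used directly for inference. A sketch of how that might look, assuming DL4J's NativeImageLoader, a 3-channel 512x512 input, and pixel scaling to [0, 1]; the file path and sizes are illustrative, not the project's actual values:

```java
// Sketch: running inference with the restored model.
// Input path, image size, and 0-1 scaling are assumptions for illustration.
NativeImageLoader loader = new NativeImageLoader(512, 512, 3);
INDArray input = loader.asMatrix(new File("C:\\data\\test\\cell.png"));
new ImagePreProcessingScaler(0, 1).transform(input); // scale pixels to [0, 1]
INDArray output = model.output(input)[0];            // per-pixel probabilities
```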

3. Splitting the code into separate classes according to functionality. 
I have split the main class https://github.com/Medha-B/unetdl4j/blob/master/src/main/java/org/sbml/spatial/segmentation/TrainUnetModel.java into UnetTrainAndTest.java, UnetTrainAndSave.java, and UnetLoadAndTest.java.
Note: To download the trained model configurations (weights) for 100 images x 50 epochs and 300 images x 100 epochs, visit https://drive.google.com/drive/folders/19Uo1q_qj6PLlhvunRoH7K9yzsvXjKA_h?usp=sharing
 
4. Model inference with more images.
The dataset used for training with 100 images can be obtained at https://drive.google.com/drive/folders/1u3SgJYb1LObpboEKkURQr3Mh7FrPrf_8?usp=sharing
The dataset used for training with 300 images can be obtained at https://drive.google.com/drive/folders/1UUq6W-3P7Mg-eSE6_UJSCQaC8Xazc3zH?usp=sharing
The raw cell images used for testing can be obtained at https://drive.google.com/drive/folders/1lNphWDWUDq6U4K25kP-zHL8U4ETawuDE?usp=sharing
Inferred Image

Corresponding ground truth

Composite image

Difference of images


The inferred images, corresponding ground truth, composite image, and their difference can be obtained at 

5. Writing the pom.xml file for a separate maven project for U-Net segmentation.
The pom.xml file can be obtained at https://github.com/Medha-B/unetdl4j/blob/master/pom.xml
Note to self: important tutorials regarding Maven projects (many thanks to Dr. Akira Funahashi) -

6. Training the U-Net with 1 input channel.
The method setInputShape(int[][] inputShape) in the class UNet can be used to alter the number of channels from 3 to 1. However, it cannot be called when the model is declared directly as a ComputationGraph; the U-Net model needs to be instantiated as a ZooModel first, after which setInputShape can be used. The code then looks like this:

ZooModel unet = UNet.builder().build();
unet.setInputShape(new int[][]{{1, 128, 128}});
ComputationGraph model = (ComputationGraph) unet.init();

The code compiled and ran successfully, but the results were not satisfactory. After training the model with 44 images over 20 epochs (dataset available at https://drive.google.com/drive/folders/1Ox0fi1V9dwBXPHisgLc9kjaIfZFZ27dy?usp=sharing), the result was a dark image with all pixel values equal to 37.
To verify, I trained the model again with 100 images over 40 epochs and got a similar dark image; only this time all the pixel values were 0.
Image inferred after training with 44 images for a single input channel

Image inferred after training with 100 images for a single input channel

To download these images, visit https://drive.google.com/drive/folders/1CkQ-PtkmxtxgzkeX9CPtWpQTD43gswm2?usp=sharing

I introduced the changes in code for running this under 
TrainUnetModel.java at https://github.com/Medha-B/unetdl4j/blob/master/src/main/java/org/sbml/spatial/segmentation/TrainUnetModel.java (commit d901fb0).

Note: The loss function used internally by the U-Net is XENT, which stands for cross-entropy (binary classification).
It can be seen in the final convolution layer as

.addLayer("output", new CnnLossLayer.Builder(LossFunctions.LossFunction.XENT)
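For intuition, binary cross-entropy for a single pixel with label y in {0, 1} and predicted probability p is -[y ln(p) + (1 - y) ln(1 - p)]. A small self-contained sketch of this formula (the class and method names are mine, not a DL4J API):

```java
public class XentDemo {
    // Binary cross-entropy for one pixel; 'eps' guards against log(0).
    static double xent(double y, double p) {
        double eps = 1e-12;
        p = Math.min(Math.max(p, eps), 1.0 - eps);
        return -(y * Math.log(p) + (1.0 - y) * Math.log(1.0 - p));
    }

    public static void main(String[] args) {
        // A confident correct prediction has low loss...
        System.out.println(xent(1.0, 0.9));   // ~0.105
        // ...while a confident wrong prediction is heavily penalized.
        System.out.println(xent(1.0, 0.1));   // ~2.303
    }
}
```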

I forgot to add the code to calculate the average error before training the model; this is what it looked like:
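For reference, one way such an average error could be computed is the mean absolute per-pixel difference between the inferred image and the ground truth. A minimal self-contained sketch, assuming both images are given as flat arrays of 8-bit grayscale values (the class and method names are mine):

```java
public class PixelError {
    // Mean absolute difference between two 8-bit grayscale images,
    // given as flat arrays of pixel values in [0, 255].
    static double averagePixelError(int[] inferred, int[] groundTruth) {
        if (inferred.length != groundTruth.length) {
            throw new IllegalArgumentException("Image sizes differ");
        }
        long total = 0;
        for (int i = 0; i < inferred.length; i++) {
            total += Math.abs(inferred[i] - groundTruth[i]);
        }
        return (double) total / inferred.length;
    }

    public static void main(String[] args) {
        int[] a = {0, 128, 255};
        int[] b = {0, 128, 255};
        System.out.println(averagePixelError(a, b)); // 0.0 for identical images
    }
}
```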

