Coding Phase 1.4
Hi!
Welcome back! Picking up where I left off last time, the implementation of U-Net using deeplearning4j is coming along fairly well. Last week I was able to train on 100 images over 50 epochs, using https://github.com/montardon/unetdl4j/blob/master/TrainUnetModel.java as a basis, and to infer an output as well.
Following are the tasks I completed this week:
1. Forking the deeplearning4j codebase (https://github.com/eclipse/deeplearning4j) and pushing the U-Net code.
The link to my forked repository is https://github.com/Medha-B/deeplearning4j and the template U-Net code is available at https://github.com/Medha-B/deeplearning4j/tree/master/Unet_GSoc. The initial commit, however, can be found at https://github.com/Medha-B/unetdl4j/tree/master/src in the repository https://github.com/Medha-B/unetdl4j (forked from https://github.com/funasoul/unetdl4j).
2. Writing code for saving the trained model and then loading it.
For training the UNet model, we use
log.info("*****TRAIN MODEL******");
ComputationGraph model = UNet.builder().updater(new Adam(1e-4)).build().init();
// ComputationGraph model = UNet.builder().updater(new Adam(new MapSchedule(ScheduleType.ITERATION, learningScheduleMap))).build().init();
model.addListeners(new ScoreIterationListener());
model.fit(imageDataSetIterator,numEpochs);
where
DataSetIterator imageDataSetIterator = new RecordReaderDataSetIterator(imageRecordReader,batchSize,labelIndex,labelIndex,true);
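The snippet above assumes an `imageRecordReader` has already been set up. For reference, the following is a minimal sketch of how such a pipeline might be constructed with DataVec; the directory path `data/train` and the 128x128 input size are placeholders, not the exact values from my code.

```java
import java.io.File;

import org.datavec.api.split.FileSplit;
import org.datavec.image.recordreader.ImageRecordReader;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.preprocessor.ImagePreProcessingScaler;

public class DataPipelineSketch {

    public static void main(String[] args) throws Exception {
        int height = 128, width = 128, channels = 3;  // placeholder input size
        int batchSize = 1;
        int labelIndex = 1;

        // Reads images from disk and resizes them to height x width x channels
        ImageRecordReader imageRecordReader = new ImageRecordReader(height, width, channels);

        File trainDir = new File("data/train");       // placeholder directory
        if (trainDir.exists()) {
            imageRecordReader.initialize(new FileSplit(trainDir));

            // Regression mode (last flag true): the label is taken from the record
            // itself rather than from the parent directory name
            DataSetIterator imageDataSetIterator = new RecordReaderDataSetIterator(
                    imageRecordReader, batchSize, labelIndex, labelIndex, true);

            // Scale raw pixel values into [0, 1] before they reach the network
            imageDataSetIterator.setPreProcessor(new ImagePreProcessingScaler(0, 1));
        }
    }
}
```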
Similarly, for saving the trained model, we can use
log.info("*****SAVE MODEL******");
//Location for saving the model
File locationToSave = new File("C:\\Users\\--\\save.zip");
boolean saveUpdater = false;
//ModelSerializer needs the model, the location to save it to, and the saveUpdater flag
ModelSerializer.writeModel(model, locationToSave, saveUpdater);
Likewise, for loading the trained model and using it, we can use
log.info("*****LOAD MODEL******");
//Location where the model is saved
File locationToSave = new File("C:\\Users\\--\\save.zip");
ComputationGraph model = ModelSerializer.restoreComputationGraph(locationToSave);
3. Splitting the code into separate classes according to functionality.
I have split the original https://github.com/Medha-B/unetdl4j/blob/master/src/main/java/org/sbml/spatial/segmentation/TrainUnetModel.java into UnetTrainAndTest.java, UnetTrainAndSave.java, and UnetLoadAndTest.java.
- UnetTrainAndTest.java: For training the model and then evaluating it using inference at https://github.com/Medha-B/unetdl4j/blob/master/src/main/java/org/sbml/spatial/segmentation/UnetTrainAndTest.java (commit f322c4).
- UnetTrainAndSave.java: For training the model and then saving it at https://github.com/Medha-B/unetdl4j/blob/master/src/main/java/org/sbml/spatial/segmentation/UnetTrainAndSave.java (commit 6738f7).
- UnetLoadAndTest.java: For loading the trained model and then evaluating it using inference at https://github.com/Medha-B/unetdl4j/blob/master/src/main/java/org/sbml/spatial/segmentation/UnetLoadAndTest.java (commit 95e54fc).
Note: To download the trained model configurations (weights) for 100 images x 50 epochs and 300 images x 100 epochs, visit https://drive.google.com/drive/folders/19Uo1q_qj6PLlhvunRoH7K9yzsvXjKA_h?usp=sharing
4. Model inference with more images.
The dataset used for training 100 images can be obtained at https://drive.google.com/drive/folders/1u3SgJYb1LObpboEKkURQr3Mh7FrPrf_8?usp=sharing
The dataset used for training 300 images can be obtained at https://drive.google.com/drive/folders/1UUq6W-3P7Mg-eSE6_UJSCQaC8Xazc3zH?usp=sharing
The raw cell images used for testing can be obtained at https://drive.google.com/drive/folders/1lNphWDWUDq6U4K25kP-zHL8U4ETawuDE?usp=sharing
The inferred images, corresponding ground truth, composite image, and their difference can be obtained at
- https://drive.google.com/drive/folders/1ZetKvrxNPULv_AejnTvnxEi0zlk2Zg4B?usp=sharing
- https://drive.google.com/drive/folders/1mPoAuW98VMYrqDFtz7QdX9ZA43iuUylU?usp=sharing
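Inference itself follows the load step from item 2: restore the saved graph, turn a test image into an INDArray, and take the network's single output. The sketch below is only an outline under assumed file names (save.zip, test.png) and a placeholder 128x128 input size, not my exact code.

```java
import java.io.File;

import org.datavec.image.loader.NativeImageLoader;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.preprocessor.ImagePreProcessingScaler;

public class UnetInferenceSketch {

    public static void main(String[] args) throws Exception {
        File modelFile = new File("save.zip");   // placeholder: the saved model
        File imageFile = new File("test.png");   // placeholder: a raw cell image
        if (!modelFile.exists() || !imageFile.exists()) {
            return;  // nothing to do without the artifacts
        }

        ComputationGraph model = ModelSerializer.restoreComputationGraph(modelFile);

        // Load the image as a [1, channels, height, width] tensor, scaled to [0, 1]
        NativeImageLoader loader = new NativeImageLoader(128, 128, 3);
        INDArray input = loader.asMatrix(imageFile);
        new ImagePreProcessingScaler(0, 1).transform(input);

        // output() returns one INDArray per graph output; U-Net has a single
        // sigmoid activation map, one probability per pixel
        INDArray mask = model.output(input)[0];
        System.out.println(java.util.Arrays.toString(mask.shape()));
    }
}
```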
5. Writing the pom.xml file for a separate Maven project for U-Net segmentation.
The pom.xml file can be obtained at https://github.com/Medha-B/unetdl4j/blob/master/pom.xml
Note to self: important tutorials on Maven projects (many thanks to Dr. Akira Funahashi):
- Creating a Maven project: https://asciinema.org/a/Der4bVBu4KHOGutWqibZdGuYp
- Compiling and running a Maven project: https://asciinema.org/a/IrIj4P8Vn8MjIuKn6X2ymXMe4
- Setting up a deeplearning4j project with Maven: https://asciinema.org/a/fcxAth5RLxOn25GQ7vhqXp1ws
6. Training the U-Net with 1 input channel
The method setInputShape(int[][] inputShape) in the UNet class can be used to alter the number of channels from 3 to 1. However, it cannot be called on a variable declared as ComputationGraph. The U-Net model needs to be instantiated as type ZooModel first, and then setInputShape can be used. The code looks like this:
ZooModel unet = UNet.builder().build();
unet.setInputShape(new int[][]{{1, 128, 128}});
ComputationGraph model = (ComputationGraph) unet.init();
The code compiled and ran, but the results were not satisfactory. After training the model with 44 images over 20 epochs (dataset available at https://drive.google.com/drive/folders/1Ox0fi1V9dwBXPHisgLc9kjaIfZFZ27dy?usp=sharing), the result was a dark image with every pixel value equal to 37.
To verify, I trained the model again with 100 images over 40 epochs and got a similar dark image; only this time all the pixel values were 0.
[Image: output inferred after training with 100 images for a single input channel]
To download these images, visit https://drive.google.com/drive/folders/1CkQ-PtkmxtxgzkeX9CPtWpQTD43gswm2?usp=sharing
I introduced the changes needed to run this in TrainUnetModel.java at https://github.com/Medha-B/unetdl4j/blob/master/src/main/java/org/sbml/spatial/segmentation/TrainUnetModel.java (commit d901fb0).
Note: The loss function used inherently by the U-Net is XENT, i.e. binary cross-entropy.
It can be seen in the final convolution layer as
.addLayer("output", new CnnLossLayer.Builder(LossFunctions.LossFunction.XENT)
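As a reminder of what XENT computes: for a pixel with ground-truth label y in {0, 1} and predicted probability p, the loss is -(y·log p + (1-y)·log(1-p)), averaged over all pixels. A small plain-Java illustration (not DL4J code):

```java
public class XentSketch {

    // Per-pixel binary cross-entropy: -(y*log(p) + (1-y)*log(1-p))
    static double xent(double y, double p) {
        double eps = 1e-7;                        // clamp to avoid log(0)
        p = Math.min(Math.max(p, eps), 1 - eps);
        return -(y * Math.log(p) + (1 - y) * Math.log(1 - p));
    }

    public static void main(String[] args) {
        // A confident correct prediction costs almost nothing...
        System.out.println(xent(1.0, 0.99));  // ~0.01
        // ...while a confident wrong one is heavily penalized
        System.out.println(xent(1.0, 0.01));  // ~4.6
    }
}
```

An all-zero output like the one I observed corresponds to the network predicting p near 0 everywhere, which XENT may settle on when the labels are overwhelmingly background.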
I forgot to add the code to calculate the average error before training the model.
Until the next time...
Yasou!
References:
1. https://deeplearning4j.konduit.ai/model-zoo/overview
2. https://www.bookstack.cn/read/deeplearning4j/ce79934a30930bd1.md
3. https://www.codota.com/code/java/classes/org.deeplearning4j.zoo.model.ResNet50
4. https://www.codota.com/code/java/classes/org.deeplearning4j.zoo.model.VGG16
5. http://maven.apache.org/guides/getting-started/maven-in-five-minutes.html
6. https://mvnrepository.com/search?q=deeplearning4j
7. https://github.com/eclipse/deeplearning4j-examples/blob/master/mvn-project-template/pom.xml