Multi-class semantic segmentation in VR

This week I mostly worked on improving my model for multi-class semantic segmentation.


Dataset

For this task I used the CholecSeg8k segmentation dataset. This dataset is interesting because it contains segmentation masks not only for organs but also for surgical tools.

The dataset is organized by video. There are 17 videos in total, so I decided to split the data by video ID: I used 13 videos for training and 4 videos for validation and testing.
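A video-level split like this can be sketched in Python as follows. The `(video_id, frame_path)` record format, the synthetic video names, and the `split_by_video` helper are hypothetical, chosen only for illustration. The point is that whole videos, never individual frames, go to one side of the split, so near-identical neighbouring frames cannot leak from training into testing.

```python
import random
from collections import defaultdict


def split_by_video(frames, n_train_videos=13, seed=0):
    """Assign whole videos (never single frames) to either train or test.

    `frames` is a list of (video_id, frame_path) pairs; this record
    format is hypothetical and only used for illustration.
    """
    by_video = defaultdict(list)
    for video_id, path in frames:
        by_video[video_id].append(path)

    # Shuffle the video IDs reproducibly, then cut the list in two.
    video_ids = sorted(by_video)
    random.Random(seed).shuffle(video_ids)
    train_ids = set(video_ids[:n_train_videos])
    test_ids = set(video_ids[n_train_videos:])

    train = [p for v in sorted(train_ids) for p in by_video[v]]
    test = [p for v in sorted(test_ids) for p in by_video[v]]
    return train, test, train_ids, test_ids


# Synthetic example: 17 hypothetical videos with 3 frames each.
frames = [(f"video{v:02d}", f"video{v:02d}/frame_{i}.png")
          for v in range(1, 18) for i in range(3)]
train, test, train_ids, test_ids = split_by_video(frames)
print(len(train_ids), len(test_ids))  # 13 4
print(train_ids.isdisjoint(test_ids))  # True
```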


Model

Last week I experimented with different models, and by the beginning of this week I was sure that DeepLabV3+ was the right choice. DeepLabV3+ adds a decoder to DeepLabV3 to better segment object boundaries. This matters in medical image segmentation, since we want to know exactly where the border between organs and surgical tools lies, which was one more reason for me to use DeepLabV3+.
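To illustrate the decoder idea, here is a minimal PyTorch sketch of a DeepLabV3+-style decoder head. It assumes the backbone and the ASPP module already exist (they are omitted), and it follows the channel sizes described in the DeepLabV3+ paper (a 48-channel reduction of low-level features, then two 3x3 convolutions with 256 channels). `num_classes=13` is an assumption about the label set. Complete implementations exist in libraries such as segmentation_models_pytorch (`DeepLabV3Plus`); this sketch is not necessarily what was used in this project.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class DeepLabV3PlusDecoder(nn.Module):
    """Simplified DeepLabV3+-style decoder head (illustrative sketch).

    Coarse high-level features (e.g. ASPP output at 1/16 resolution) are
    upsampled and fused with low-level backbone features (1/4 resolution),
    which carry the fine spatial detail that sharpens object boundaries.
    """

    def __init__(self, high_ch=256, low_ch=256, num_classes=13):
        super().__init__()
        # num_classes=13 is an assumption about the label set.
        # 1x1 conv shrinks low-level channels so they don't dominate the fusion.
        self.reduce_low = nn.Sequential(
            nn.Conv2d(low_ch, 48, kernel_size=1, bias=False),
            nn.BatchNorm2d(48),
            nn.ReLU(inplace=True),
        )
        # Two 3x3 convs refine the fused features.
        self.fuse = nn.Sequential(
            nn.Conv2d(high_ch + 48, 256, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
        )
        self.classifier = nn.Conv2d(256, num_classes, kernel_size=1)

    def forward(self, high, low, out_size):
        # Bring coarse features up to the low-level resolution, then concatenate.
        high = F.interpolate(high, size=low.shape[-2:], mode="bilinear", align_corners=False)
        x = torch.cat([high, self.reduce_low(low)], dim=1)
        x = self.classifier(self.fuse(x))
        # Final upsampling back to the input image resolution.
        return F.interpolate(x, size=out_size, mode="bilinear", align_corners=False)


decoder = DeepLabV3PlusDecoder()
high = torch.randn(2, 256, 16, 16)  # stand-in for ASPP output (1/16 of a 256x256 input)
low = torch.randn(2, 256, 64, 64)   # stand-in for low-level backbone features (1/4)
out = decoder(high, low, out_size=(256, 256))
print(out.shape)  # torch.Size([2, 13, 256, 256])
```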


For training I used two loss functions: cross-entropy (CE) and Dice loss. CE loss is widely used for multi-class semantic segmentation, and the original model (trained on undistorted images) performed well on the data with it.
Example of segmentation results: original model, CE loss
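For reference, this is how CE loss is typically applied to segmentation outputs in PyTorch: `nn.CrossEntropyLoss` takes raw per-pixel logits of shape (N, C, H, W) and integer class labels of shape (N, H, W). The random tensors and `num_classes = 13` below are placeholders, not the project's actual data or configuration.

```python
import torch
import torch.nn as nn

num_classes = 13  # assumption: size of the label set
criterion = nn.CrossEntropyLoss()

# Placeholder tensors: logits are raw per-pixel class scores (N, C, H, W),
# target holds one integer class index per pixel (N, H, W).
torch.manual_seed(0)
logits = torch.randn(2, num_classes, 64, 64, requires_grad=True)
target = torch.randint(0, num_classes, (2, 64, 64))

loss = criterion(logits, target)  # averaged over all pixels in the batch
loss.backward()
```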

However, on distorted data the results were less encouraging: my models struggled to segment the image borders (exactly what I had predicted at the very beginning of the project).
Distorted model, CE loss

To address this issue, I simply changed the training loss function to Dice loss, which is closely related to the IoU score (the main metric in this work). The original model trained with Dice loss performed even better than the CE model: where the CE model reached an IoU of 0.65, the Dice-loss model reached an IoU of about 0.7.
Original model, trained with Dice Loss
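The Dice loss can be sketched as a soft, differentiable version of the Dice coefficient. Its link to IoU is direct: for a single binary mask, Dice = 2·IoU / (1 + IoU), which is monotonic, so improving Dice also improves IoU. The `dice_loss` function below is a hypothetical minimal variant (per-class Dice computed over the whole batch, then averaged); it is not necessarily the exact variant used here, and libraries offer others, e.g. with class weights or per-image averaging.

```python
import torch
import torch.nn.functional as F


def dice_loss(logits, target, eps=1e-6):
    """Soft multi-class Dice loss: 1 - mean over classes of the Dice coefficient.

    logits: (N, C, H, W) raw scores; target: (N, H, W) integer class labels.
    """
    num_classes = logits.shape[1]
    probs = logits.softmax(dim=1)
    # One-hot encode the labels to (N, C, H, W) so they line up with probs.
    one_hot = F.one_hot(target, num_classes).permute(0, 3, 1, 2).float()
    dims = (0, 2, 3)  # sum over batch and pixels, keep one value per class
    intersection = (probs * one_hot).sum(dims)
    cardinality = probs.sum(dims) + one_hot.sum(dims)
    dice = (2 * intersection + eps) / (cardinality + eps)
    return 1 - dice.mean()


torch.manual_seed(0)
target = torch.randint(0, 13, (2, 64, 64))  # placeholder labels, 13 classes assumed
random_logits = torch.randn(2, 13, 64, 64)
# Very confident logits that agree with the labels everywhere.
perfect_logits = 50.0 * F.one_hot(target, 13).permute(0, 3, 1, 2).float()

loss_random = dice_loss(random_logits, target)
loss_perfect = dice_loss(perfect_logits, target)
print(loss_random.item())   # between 0 and 1
print(loss_perfect.item())  # close to 0
```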

That is why I expected the distorted model to perform well with Dice loss too, and it did: the distorted model trained with Dice loss reached an IoU close to that of the original model. Also, since the training and test videos do not overlap, these results are unlikely to be caused by overfitting to frames seen during training.



Distorted model, Dice Loss
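Comparing the original and distorted models comes down to computing IoU on the held-out videos. A common approach, sketched below under the assumption that "IoU" here means mean IoU over classes, is to accumulate a confusion matrix over all test frames and derive per-class IoU = TP / (TP + FP + FN) from it. The helper names and the toy 2x3 masks are hypothetical.

```python
import numpy as np


def confusion_matrix(pred, target, num_classes):
    """(C, C) pixel counts; rows are true classes, columns predicted classes."""
    idx = target.astype(np.int64).ravel() * num_classes + pred.astype(np.int64).ravel()
    counts = np.bincount(idx, minlength=num_classes * num_classes)
    return counts.reshape(num_classes, num_classes)


def mean_iou(conf):
    """Per-class IoU = TP / (TP + FP + FN); classes absent from both
    predictions and ground truth are ignored in the mean."""
    tp = np.diag(conf).astype(float)
    fp = conf.sum(axis=0) - tp
    fn = conf.sum(axis=1) - tp
    denom = tp + fp + fn
    iou = np.where(denom > 0, tp / np.maximum(denom, 1), np.nan)
    return np.nanmean(iou), iou


# Toy 2x3 masks with 3 classes (hypothetical values).
target = np.array([[0, 0, 1], [1, 2, 2]])
pred = np.array([[0, 1, 1], [1, 2, 0]])

conf = np.zeros((3, 3), dtype=np.int64)
conf += confusion_matrix(pred, target, 3)  # in practice: sum over all test frames
miou, per_class = mean_iou(conf)
print(per_class)  # per-class IoU: 1/3, 2/3, 1/2
print(miou)       # 0.5
```

Summing confusion matrices over every frame before computing IoU gives a dataset-level score; averaging per-image IoU instead can give noticeably different numbers, so the same choice should be used when comparing the CE and Dice models.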

Conclusion

From this work I conclude that VR can be simulated for training multi-class segmentation models: with Dice loss, the model trained on distorted data reaches an IoU close to that of the model trained on the original images.
