
Two-stage Training of Graph Neural Networks for Graph Classification

Abstract

Graph Neural Networks (GNNs) have received massive attention in the field of machine learning on graphs. Inspired by the success of neural networks, a line of research has trained GNNs for various tasks, such as node classification, graph classification, and link prediction. In this work, our task of interest is graph classification. Several GNN models have been proposed and have shown strong performance on this task. However, it is unclear whether the standard training setting fully exploits the power of these architectures. We propose a two-stage training framework based on triplet loss. After the first stage, graphs of the same class are mapped close together while those of different classes are mapped far apart. Once graphs are well separated by label, they become easier to classify. The framework is generic in that it is compatible with any GNN architecture. By adapting 5 GNNs to the triplet framework, together with additional fine-tuning, we demonstrate consistent improvements over each model's original setting of up to 5.4% across 12 datasets.
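The first stage described above rests on a standard hinge-style triplet loss over graph embeddings. A minimal sketch of that loss is below; the Euclidean distance, the fixed margin of 1.0, and the function names are illustrative assumptions, not the paper's exact formulation.

```python
# Sketch of the triplet loss underlying stage one: pull same-class
# graph embeddings together, push different-class ones apart.
# Distance metric and margin are assumed, not taken from the paper.
import math


def euclidean(u, v):
    """Euclidean distance between two embedding vectors."""
    return math.sqrt(sum((a - b) ** 2 for a, b in zip(u, v)))


def triplet_loss(anchor, positive, negative, margin=1.0):
    """Hinge triplet loss: the positive (same-class graph) should be
    at least `margin` closer to the anchor than the negative
    (different-class graph); otherwise the triplet incurs a loss."""
    return max(euclidean(anchor, positive)
               - euclidean(anchor, negative) + margin, 0.0)


# A triplet that is already well separated contributes zero loss:
a, p, n = [0.0, 0.0], [0.1, 0.0], [3.0, 4.0]
print(triplet_loss(a, p, n))  # 0.0, since 0.1 - 5.0 + 1.0 < 0
```

Minimizing this loss over many sampled triplets drives the class-wise separation that the second (classification) stage then exploits.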
