COde RepresentAtion Learning with weakly supervised transformers (CORAL) is a model for learning neural representations of data science code snippets and classifying them as stages in the data analysis process. CORAL leverages both source code abstract syntax trees (ASTs) and associated natural language annotations in markdown text (see Fig. 2).
Model contributions
CORAL contributes the following:
- CORAL jointly learns from code and surrounding natural language (Sect. 4.1), while preserving meaningful code structure through a graph-based masked attention mechanism (Sect. 4.2). We show that adding natural language improves performance by 13% on snippets that do not have associated markdown comments (Sect. 5.2).
- We address the lack of high-quality training data through an easily extensible weakly supervised objective based on five simple heuristics (Sect. 4.3).
- CORAL combines this weak supervision with an additional unsupervised training objective based on topic modeling (again avoiding costly ground-truth labels); all objectives are combined in a multi-task learning framework (Sect. 4.5).
4.1 Input representations
CORAL builds on graph neural networks [62] and masked-attention approaches [53] to encode the AST’s graph structure by first serializing the syntax tree in depth-first traversal and then using its adjacency matrix as an attention mask (Sect. 4.2).
We add additional nodes to the AST to capture surrounding natural language. For each code cell, we concatenate its most recent prior markdown as a token sequence to the AST graph sequence (yellow in Fig. 2), so long as the markdown is no more than three cells away. Concretely, we create a node for each markdown token and then connect each markdown node with each AST node. Finally, we add a virtual node [CLS] (for classification) at the head of every input sequence and connect all the other nodes to it. Similar to BERT, we take this node’s embedding as the representation of the cell [18].
Notation
Formally, let \(\mathcal{V}=\{u,v,\ldots\}\) be the set of nodes in the input, where each node v is either an AST node or markdown token. For any input sequence that has more than M nodes, we truncate it and keep only the first M nodes (a modeling choice which we evaluate in Sect. 5.3). We use A to represent the graph adjacency matrix that encodes the relationship between nodes as described above. All input nodes are converted to embedding vectors of dimension \(d_{\mathrm{model}}\). We assemble these embeddings into a matrix X.
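To make this input representation concrete, the following is a minimal sketch of how a cell's graph might be assembled from its code and preceding markdown, assuming Python's built-in ast module. The node labeling, markdown tokenization, and the cutoff M = 160 are illustrative placeholders rather than the paper's exact choices.

```python
import ast
import numpy as np

def build_input_graph(code, markdown_tokens, M=160):
    """Sketch: serialize a cell's AST depth-first, append markdown token nodes,
    prepend a [CLS] node, and build the adjacency matrix used as an attention mask."""
    tree = ast.parse(code)

    # Depth-first serialization of AST nodes; record parent-child edges.
    nodes, edges = ["[CLS]"], []
    def visit(node, parent_idx):
        idx = len(nodes)
        nodes.append(type(node).__name__)          # node label, e.g. "Call", "Name"
        if parent_idx is not None:
            edges.append((parent_idx, idx))        # AST parent-child edge
        for child in ast.iter_child_nodes(node):
            visit(child, idx)
    visit(tree, None)

    ast_indices = list(range(1, len(nodes)))       # all AST nodes added so far

    # One node per markdown token, connected to every AST node.
    for tok in markdown_tokens:
        idx = len(nodes)
        nodes.append(tok)
        edges.extend((idx, a) for a in ast_indices)

    # The virtual [CLS] node is connected to all other nodes.
    edges.extend((0, i) for i in range(1, len(nodes)))

    # Truncate to the first M nodes, then build a symmetric adjacency matrix A.
    nodes = nodes[:M]
    A = np.zeros((len(nodes), len(nodes)), dtype=np.float32)
    for i, j in edges:
        if i < len(nodes) and j < len(nodes):
            A[i, j] = A[j, i] = 1.0
    return nodes, A
```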
4.2 Encoding code cells with attention
We extend the popular BERT model [18] by adding masked multi-head attention to capture the graphical structure of ASTs. We evaluate the impact of this addition in Sect. 5.1.
CORAL feeds the input code and natural language representations to an encoder, which is composed of a stack of \(N = 4\) identical layers (Fig. 2). Similar to Transformers [17], we equip each layer with a multi-head self-attention sublayer and a feed-forward sublayer. The graph structure is captured through masked attention (Eq. (2) below).
Masked multi-head attention
We use \(\mathit{Aggregate}_{k}^{i}\) to represent the self-attention function of \(\mathit{head}_{i}\) in \(\operatorname{layer}_{k}\). Let \((q,k,v)\) be the query, key, and value decomposition of the input to \(\mathit{Aggregate}_{k}^{i}\). Queries and keys are vectors of dimension \(d_{k}\), and values are vectors of dimension \(d_{v}\). For a given node u, let \((q_{u}, k_{u}, v_{u})\) be the triple of query, key and value, and let \(N(u)\) be the set of all its neighbours. Formally, the parameters \(q_{u}\), \(k_{u}\), \(v_{u}\) vary across each \(\mathit{head}_{i}\) and \(\operatorname{layer}_{k}\), but we drop additional notation for simplicity here. Then we compute aggregate results as:
$$ \mathit{Aggregate}_{k}^{i}(u)=\sum _{v\in N(u)} \mathit{Softmax}\biggl( \frac{q_{u}\cdot k_{v}}{\sqrt{d_{k}}}\biggr)\cdot v_{v}. $$
(1)
We adopt the scaling factor \(\frac{1}{\sqrt{d_{k}}}\) from Vaswani et al. [17] to mitigate the dot product’s growth in magnitude with \(d_{k}\). In practice, the queries, keys, and values are assembled into matrices Q, K, V. We compute the output in matrix form as:
$$ \mathit{Aggregate}_{k}^{i}(Q,K,V)= \mathit{Softmax}\biggl( \frac{\tilde{A}\odot QK^{T}}{\sqrt{d_{k}}}\biggr)V, $$
(2)
where \(\tilde{A}=A+I\) is the adjacency matrix with self-loops added to implement the masked attention approach, where each node only attends to its neighbours (described in Sect. 4.1) and itself.
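To make Eqs. (1)–(2) concrete, here is a minimal sketch of a single masked attention head, assuming a PyTorch implementation (the paper does not prescribe a framework). Rather than multiplying Ã into the logits directly, it applies the mask by setting non-neighbour scores to −inf before the softmax, a standard numerical realization of letting each node attend only to N(u) ∪ {u}.

```python
import math
import torch

def masked_attention(Q, K, V, A_tilde):
    """Scaled dot-product attention restricted to graph neighbours (Eqs. (1)-(2)).
    Q, K: [n, d_k]; V: [n, d_v]; A_tilde: [n, n] adjacency with self-loops (A + I)."""
    d_k = Q.size(-1)
    scores = Q @ K.transpose(-2, -1) / math.sqrt(d_k)        # [n, n] attention logits
    # Each node may only attend to its neighbours and itself:
    scores = scores.masked_fill(A_tilde == 0, float("-inf"))
    weights = torch.softmax(scores, dim=-1)                  # rows sum to 1 over N(u) and u
    return weights @ V                                        # [n, d_v]
```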
Since we adopt multi-head attention, we concatenate h heads within the same layer:
$$\begin{aligned}& \mathit{MultiHead}(Q,K,V) \\& \quad =\mathit{Concat}(\mathit{head}_{1},\ldots, \mathit{head}_{h})W_{O}, \end{aligned}$$
(3)
$$\begin{aligned}& \mathit{head}_{i}=\mathit{Aggregate}_{k}^{i} \bigl(XW^{i}_{Q}, XW^{i}_{K}, XW^{i}_{V}\bigr), \end{aligned}$$
(4)
where \(\mathit{head}_{i}\in \mathbb{R}^{d_{v}}\) and \(W^{i}_{Q}\in \mathbb{R}^{d_{\mathrm{model}}\times d_{k}}\), \(W^{i}_{K}\in \mathbb{R}^{d_{\mathrm{model}}\times d_{k}}\), \(W^{i}_{V}\in \mathbb{R}^{d_{\mathrm{model}}\times d_{v}}\), and \(W_{O}\in \mathbb{R}^{h*d_{v}\times d_{\mathrm{model}}}\) are projection matrices that map the node embeddings X to queries, keys, values, and multi-head output, respectively.
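A sketch of the multi-head wrapper of Eqs. (3)–(4), reusing the masked_attention helper above and assuming the common choice \(d_k = d_v = d_{\mathrm{model}}/h\); the paper's exact head dimensions may differ.

```python
import torch
from torch import nn

class MaskedMultiHeadAttention(nn.Module):
    """Sketch of Eqs. (3)-(4): h masked attention heads, concatenated and projected by W_O."""
    def __init__(self, d_model, h):
        super().__init__()
        assert d_model % h == 0
        self.h, self.d_k = h, d_model // h                   # assume d_k = d_v = d_model / h
        self.W_Q = nn.Linear(d_model, d_model, bias=False)
        self.W_K = nn.Linear(d_model, d_model, bias=False)
        self.W_V = nn.Linear(d_model, d_model, bias=False)
        self.W_O = nn.Linear(d_model, d_model, bias=False)

    def forward(self, X, A_tilde):
        # X: [n, d_model]; A_tilde: [n, n]
        n = X.size(0)
        def split(t):                                        # [n, d_model] -> [h, n, d_k]
            return t.view(n, self.h, self.d_k).transpose(0, 1)
        Q, K, V = split(self.W_Q(X)), split(self.W_K(X)), split(self.W_V(X))
        heads = torch.stack([masked_attention(Q[i], K[i], V[i], A_tilde)
                             for i in range(self.h)])        # [h, n, d_k]
        concat = heads.transpose(0, 1).reshape(n, self.h * self.d_k)
        return self.W_O(concat)                              # MultiHead(Q, K, V), Eq. (3)
```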
Feed forward
In each layer, we additionally apply a fully connected feed-forward sublayer. This is composed of two linear transformations with ReLU activation in between:
$$ \operatorname{FFN}(x) = W_{\mathrm{FF}2}\cdot \max (0,W_{\mathrm{FF}1}\cdot x+b_{\mathrm{FF}1})+b_{\mathrm{FF}2}, $$
(5)
where
$$\begin{aligned} W_{\mathrm{FF}1}&\in \mathbb{R}^{h*d_{\mathrm{model}}\times d_{\mathrm{model}}} , \\ W_{\mathrm{FF}2}&\in \mathbb{R}^{d_{\mathrm{model}}\times h*d_{\mathrm{model}}} \end{aligned}$$
and \(b_{\mathrm{FF}1}\) and \(b_{\mathrm{FF}2}\) are bias parameters learned during training.
Add & norm
Each sublayer is followed by layer normalization [63]. The output of each sublayer is:
$$ \mathit{LayerNorm}\bigl(x + \mathit{Sublayer}(x)\bigr), $$
(6)
where \(\mathit{Sublayer}(x)\) is either the multi-head attention sublayer or the feed-forward sublayer.
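The two sublayers compose into one encoder layer as sketched below. This reuses MaskedMultiHeadAttention from above, follows the \(W_{\mathrm{FF}1}\), \(W_{\mathrm{FF}2}\) dimensions of Eq. (5), and uses nn.LayerNorm as a stand-in for Eq. (6); it is an illustrative sketch, not the paper's reference implementation.

```python
from torch import nn

class CoralEncoderLayer(nn.Module):
    """Sketch: masked multi-head attention -> Add & Norm -> feed-forward -> Add & Norm."""
    def __init__(self, d_model, h):
        super().__init__()
        self.attn = MaskedMultiHeadAttention(d_model, h)
        self.ffn = nn.Sequential(                    # Eq. (5): two linear maps with ReLU
            nn.Linear(d_model, h * d_model),         # W_FF1, b_FF1
            nn.ReLU(),
            nn.Linear(h * d_model, d_model),         # W_FF2, b_FF2
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, X, A_tilde):
        X = self.norm1(X + self.attn(X, A_tilde))    # Eq. (6) around the attention sublayer
        X = self.norm2(X + self.ffn(X))              # Eq. (6) around the feed-forward sublayer
        return X
```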
Output
The multi-head attention sublayer and feed-forward sublayer are stacked and make up one “layer”. After stacking this layer \(N=4\) times to allow information to propagate between nodes, the encoder’s output contains contextual representations of all the nodes in the input sequence. We take the embedding of the [CLS] node as the representation of each notebook cell’s graph (Sect. 4.1), denoted as \(z\in \mathbb{R}^{d_{\mathrm{model}}}\).
We compress this cell representation z into a lower-dimensional distribution over K “topics” to capture information about the data analysis stages. Concretely:
$$ p_{\mathrm{topic}} = \mathit{Softmax}(W_{\mathrm{topic}}\cdot z+b), $$
(7)
where \(W_{\mathrm{topic}}\in \mathbb{R}^{K\times d_{\mathrm{model}}}\) is a learned weight matrix and b is the bias vector.
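A sketch of the full encoder readout: N = 4 stacked layers, the [CLS] embedding z, and the topic projection of Eq. (7). The value K = 50 is an illustrative placeholder, not necessarily the value used in the experiments; the layer class is the sketch defined above.

```python
import torch
from torch import nn

class CoralEncoder(nn.Module):
    """Sketch: stack N encoder layers, read out the [CLS] node, project to K topics (Eq. (7))."""
    def __init__(self, d_model, h, N=4, K=50):
        super().__init__()
        self.layers = nn.ModuleList([CoralEncoderLayer(d_model, h) for _ in range(N)])
        self.W_topic = nn.Linear(d_model, K)                 # W_topic and bias b

    def forward(self, X, A_tilde):
        for layer in self.layers:
            X = layer(X, A_tilde)
        z = X[0]                                             # [CLS] sits at position 0
        p_topic = torch.softmax(self.W_topic(z), dim=-1)     # Eq. (7)
        return z, p_topic
```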
4.3 Weak supervision
It is prohibitively expensive to obtain manual annotations of data analysis stages at scale, as doing so would require thousands of person-hours of work by domain experts. Therefore, we use five simple heuristics to tailor CORAL to the prediction task described in Sect. 3.1:
1. We collect a set of seed functions and assign each to a corresponding stage based on its usage. These functions are among the most commonly used in popular Python data science libraries like matplotlib and sklearn, and were selected by expert Python data scientists. Any cell that uses a seed function is weakly labeled with the corresponding stage. For example, any cell that calls “sklearn.linear_model.LinearRegression” is weakly labeled MODEL. The full set of 39 seed functions is in Online Reproducibility Appendix A.1 [55]. We demonstrate CORAL’s ability to correctly classify unseen code outside these functions in Sect. 5.4.
2. A cell with a single line of code that does not create a new variable is weakly labeled EXPLORE. This rule leverages a common pattern in Jupyter notebooks, where users often use single-line expressions to examine a variable, such as a DataFrame.
3. A cell in which more than 30% of statements are import statements is weakly labeled IMPORT.
4. A cell whose corresponding markdown is fewer than four words long and contains any of ‘logistic regression’, ‘machine learning’, or ‘random forest’ is weakly labeled MODEL.
5. A cell whose corresponding markdown is fewer than four words long and contains ‘cross validation’ is weakly labeled EVALUATE.
Note that there may be conflicts between these rules. We observe that less than one percent of cells in our corpus match more than one of these heuristics, further supporting our decision to formulate the labels as mutually exclusive. We resolve any such conflicts by assigning priority in the following order: IMPORT, MODEL, EVALUATE, EXPLORE, WRANGLE.
In this layer, we aim to compute \(p_{\mathrm{stage}}\), a probability distribution over the analysis stages, from the topic distribution computed in Eq. (7). We implement this by mapping the topic distribution \(p_{\mathrm{topic}}\) to a probability distribution \(p_{\mathrm{stage}}\) over the \(n_{\mathrm{stages}}=6\) stages. We compute the stage distribution \(p_{\mathrm{stage}}\) as follows, where \(W_{\mathrm{stage}}\in \mathbb{R}^{n_{\mathrm{stages}}\times K}\):
$$ p_{\mathrm{stage}}=\mathit{softmax}(W_{\mathrm{stage}}\cdot p_{\mathrm{topic}} + b_{\mathrm{stage}}). $$
(8)
We adopt a cross-entropy loss to minimize classification error on the weak labels. For each weakly labeled cell, the loss is computed as:
$$ L_{\mathrm{weakly}\_\mathrm{supervised}} = -\sum _{s}y_{o,s} \log (p_{s}), $$
(9)
where \(y_{o,s}\) is a binary indicator (0 or 1) of whether stage label s is the correct classification for observation o, and \(p_{s}\) is the predicted probability \(p_{\mathrm{stage}}[s]\) for stage s.
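To make the weak labeling concrete, the following sketch applies simplified versions of the five heuristics and the priority order above to a single cell. The seed-function dictionary is a small illustrative subset (the full mapping has 39 entries in the appendix), and the single-statement and markdown checks are approximations of the rules described above, not the exact production rules.

```python
import ast

# Illustrative subset of the seed-function -> stage mapping (the full list has 39 entries).
SEED_FUNCTIONS = {
    "LinearRegression": "MODEL",
    "read_csv": "WRANGLE",
    "plot": "EXPLORE",
    "accuracy_score": "EVALUATE",
}
PRIORITY = ["IMPORT", "MODEL", "EVALUATE", "EXPLORE", "WRANGLE"]

def weak_label(code, markdown):
    """Return a single weak stage label for a cell, or None if no heuristic fires."""
    tree = ast.parse(code)
    stmts = tree.body
    labels = set()

    # Heuristic 1: any call to a seed function labels the cell with that stage.
    for node in ast.walk(tree):
        if isinstance(node, ast.Call):
            name = getattr(node.func, "attr", getattr(node.func, "id", None))
            if name in SEED_FUNCTIONS:
                labels.add(SEED_FUNCTIONS[name])

    # Heuristic 2: a single statement that creates no variable -> EXPLORE.
    if len(stmts) == 1 and not isinstance(stmts[0], (ast.Assign, ast.AugAssign)):
        labels.add("EXPLORE")

    # Heuristic 3: more than 30% import statements -> IMPORT.
    n_imports = sum(isinstance(s, (ast.Import, ast.ImportFrom)) for s in stmts)
    if stmts and n_imports / len(stmts) > 0.3:
        labels.add("IMPORT")

    # Heuristics 4 & 5: short markdown mentioning modeling / cross-validation keywords.
    md = markdown.lower()
    if len(md.split()) < 4:
        if any(k in md for k in ("logistic regression", "machine learning", "random forest")):
            labels.add("MODEL")
        if "cross validation" in md:
            labels.add("EVALUATE")

    # Resolve (rare) conflicts by the fixed priority order.
    for stage in PRIORITY:
        if stage in labels:
            return stage
    return None
```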
The five weak supervision heuristics cover about 20% of the notebook cells in the training data. To reduce the model’s ambiguity on the remaining 80% of unlabeled data, and to encourage it to choose a stage for each topic, we add an additional loss term. Concretely, we penalize the entropy of \(p_{\mathrm{stage}}\), which encourages the topic distribution to map to as few stages as possible:
$$ L_{\mathrm{unique}\_\mathrm{stage}} = -\sum _{s} p_{s} \log (p_{s}), $$
(10)
where \(p_{s}\) is the predicted probability \(p_{\mathrm{stage}}[s]\) for stage s. This entropy objective is minimized when \(p_{s} = 1\) for some s and \(p_{s'} = 0\) for all other \(s'\).
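A sketch of the two stage-level objectives in PyTorch: the cross entropy of Eq. (9) over weakly labeled cells and the entropy penalty of Eq. (10) over all cells. The batch layout, the −1 convention for unlabeled cells, and the small epsilon guarding the logarithm are implementation choices for this sketch, not details from the paper.

```python
import torch

def stage_losses(p_stage, weak_labels, eps=1e-9):
    """p_stage: [batch, n_stages] stage distributions (Eq. (8));
    weak_labels: [batch] stage indices, with -1 where no heuristic fired."""
    log_p = torch.log(p_stage + eps)

    # Eq. (9): cross entropy on weakly labeled cells only.
    labeled = weak_labels >= 0
    if labeled.any():
        picked = log_p[labeled].gather(1, weak_labels[labeled].unsqueeze(1))
        ce = -picked.mean()
    else:
        ce = p_stage.new_zeros(())

    # Eq. (10): entropy penalty pushes each cell toward a single stage.
    entropy = -(p_stage * log_p).sum(dim=-1).mean()
    return ce, entropy
```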
4.4 Unsupervised learning through reconstruction
As the weak supervision heuristics only cover about 20% of the cells, we enrich the model with additional training through an unsupervised topic model. Here, the goal is to optimize the topic representation \(p_{\mathrm{topic}}\) such that we can reconstruct the intermediate cell representation z. We reconstruct z from a linear combination of its topic embeddings \(p_{\mathrm{topic}}\):
$$ r=R\cdot p_{\mathrm{topic}}, $$
(11)
where \(R \in \mathbb{R}^{d_{\mathrm{model}}\times K}\) is the learned cell embedding reconstruction matrix. This unsupervised topic model is trained to minimize the reconstruction error. We adopt the contrastive max-margin objective function using a Hinge loss formulation [64–66]. Thus, in the training process, for each cell, we randomly sample \(m = 5\) cells from our dataset as negative samples:
$$ L_{\mathrm{unsupervised}}=\sum _{c\in D}\sum _{i=1}^{m=5} \max (0,1-r_{c} z_{c}+r_{c} n_{i}), $$
(12)
where D is the training data set, \(r_{c}\) is the reconstructed vector of cell c, \(z_{c}\) is the intermediate representation of cell c, and \(n_{i}\) is the reconstructed vector of the ith negative sample. This objective function seeks to minimize the inner product between \(r_{c}\) and \(n_{i}\), while simultaneously maximizing the inner product between \(r_{c}\) and \(z_{c}\).
We also employ a regularization term from He et al. [67] to promote the uniqueness of each topic embedding in R:
$$ L_{\mathrm{unique}\_\mathrm{topic}}= \bigl\Vert R_{\mathrm{norm}}\cdot R_{\mathrm{norm}}^{T}-I \bigr\Vert , $$
(13)
where I is the identity matrix and \(R_{\mathrm{norm}}\) is the result of L2-row-normalization of R. This objective function reaches its minimum when the inner product of any two distinct topic embeddings is 0. We apply this regularization term to encourage orthogonality among the rows of the cell embedding reconstruction matrix R and to penalize redundancy between reconstruction vectors. We demonstrate in Sect. 5.2 that this additional unsupervised training improves overall classification performance.
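A sketch of the unsupervised objectives: the reconstruction of Eq. (11), the hinge loss of Eq. (12), and the regularizer of Eq. (13). Two simplifying assumptions are made here that are not from the paper: negatives are drawn by permuting the current batch and contrasted through their intermediate representations, and the K columns of R are treated as the topic embeddings to be decorrelated.

```python
import torch

def unsupervised_losses(z, p_topic, R, m=5):
    """z: [batch, d_model] cell representations; p_topic: [batch, K]; R: [d_model, K]."""
    r = p_topic @ R.T                                        # Eq. (11): r = R · p_topic

    # Eq. (12): hinge loss with m negative cells drawn from the same batch (assumption).
    batch = z.size(0)
    pos = (r * z).sum(dim=-1)                                # r_c · z_c
    hinge = torch.zeros((), device=z.device)
    for _ in range(m):
        perm = torch.randperm(batch, device=z.device)
        neg = (r * z[perm]).sum(dim=-1)                      # r_c · n_i for a sampled cell
        hinge = hinge + torch.clamp(1.0 - pos + neg, min=0.0).mean()

    # Eq. (13): encourage orthogonality between L2-normalized topic embeddings.
    topics = torch.nn.functional.normalize(R.T, dim=-1)      # K rows, one per topic (assumption)
    K = topics.size(0)
    ortho = torch.norm(topics @ topics.T - torch.eye(K, device=z.device))
    return hinge, ortho
```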
4.5 Final optimization objective
We combine the loss functions of Equations (9), (10), (12), and (13) into the final optimization objective:
$$ \begin{aligned} L ={}& \lambda _{1}L_{\mathrm{weakly}\_\mathrm{supervised}} + \lambda _{2}L_{\mathrm{unique}\_\mathrm{stage}} \\ &{}+\lambda _{3}L_{\mathrm{unsupervised}} + \lambda _{4}L_{\mathrm{unique}\_\mathrm{topic}}, \end{aligned} $$
(14)
where \(\lambda _{1} \), \(\lambda _{2}\), \(\lambda _{3}\) and \(\lambda _{4}\) are hyperparameters that control the weights of optimization objectives.
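Finally, a sketch of the weighted combination in Eq. (14), taking the four loss terms computed above; the λ values shown are placeholders rather than the tuned hyperparameters from Appendix A.2.

```python
def total_loss(ce, entropy, hinge, ortho, lambdas=(1.0, 1.0, 1.0, 1.0)):
    """Eq. (14): weighted sum of the weak-supervision, stage-uniqueness,
    reconstruction, and topic-uniqueness objectives."""
    l1, l2, l3, l4 = lambdas
    return l1 * ce + l2 * entropy + l3 * hinge + l4 * ortho
```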
We experiment with various training curricula and find that CORAL with the hyperparameters in Online Reproducibility Appendix A.2 [55] achieves the best loss (Eq. (14)) on the validation set. Importantly, this optimization and model training is based solely on the labels from the weak supervision heuristics. We do not use expert annotations (Sect. 3.3), which are reserved exclusively for the final evaluation.