1
00:00:00,229 --> 00:00:04,000
In this talk, I'll introduce Abstract
Neural Networks and describe some fundamental
2
00:00:04,000 --> 00:00:07,710
theoretical results about them that are presented
in our paper.
3
00:00:07,710 --> 00:00:12,209
Nowadays, Deep Neural Networks, or DNNs, are
commonly used in consumer software. DNNs are
4
00:00:12,209 --> 00:00:17,009
a special type of straight-line program that
operates on vectors of real numbers by applying
5
00:00:17,009 --> 00:00:19,949
alternating ‘layers’ of linear and non-linear
transformations.
6
00:00:19,949 --> 00:00:23,994
Two of the most well-known applications of
DNNs are image recognition and speech recognition,
7
00:00:23,994 --> 00:00:30,230
where DNNs have been responsible for significant
accuracy improvements over the last ten years.
8
00:00:30,230 --> 00:00:36,090
DNNs are also quickly being applied to safety-critical
applications. The ACAS Xu DNN, for example,
9
00:00:36,090 --> 00:00:41,460
takes as input a description of an airplane
and its surroundings, then provides a direction
10
00:00:41,460 --> 00:00:42,780
to the plane, such as whether to turn left
or right.
11
00:00:42,780 --> 00:00:48,629
Of course, we want to ensure that these DNNs
are safe to use. However, the large
12
00:00:48,629 --> 00:00:53,559
size and extreme non-linearity of such DNNs
causes standard verification approaches to
13
00:00:53,559 --> 00:00:56,949
fail. This has motivated research in scalable
DNN verification.
14
00:00:56,949 --> 00:01:00,379
To get a better understanding, let’s consider
this DNN which has one input and produces
15
00:01:00,379 --> 00:01:01,379
three outputs.
16
00:01:01,379 --> 00:01:04,850
Suppose we want to compute the output of this
DNN given an input value of ‘1.’ We will
17
00:01:04,850 --> 00:01:11,220
assign this input to the entry node on the
left, then multiply its value by the edge
18
00:01:11,220 --> 00:01:15,590
weights to get the inputs to the nodes in
the second layer. We then apply some activation
19
00:01:15,590 --> 00:01:20,270
function to these values, which is a non-linear
function that can differ between different
20
00:01:20,270 --> 00:01:24,260
networks. In this case, we’ll output zero
if the node’s input is negative, and just
21
00:01:24,260 --> 00:01:29,200
pass the input through otherwise. That produces
an output of 1 for the first node in this layer
22
00:01:29,200 --> 00:01:34,310
and zero for the second. We repeat this process
to compute inputs for the final layer of nodes, and
23
00:01:34,310 --> 00:01:37,710
for simplicity, we will use the identity activation
function for the output layer to produce the
24
00:01:37,710 --> 00:01:38,710
values shown.
25
00:01:38,710 --> 00:01:44,770
So if we look just at the input-output behavior
of this process, we can see that this DNN
26
00:01:44,770 --> 00:01:48,940
defines a function which outputs the vector 1/1/0 on
an input of 1.
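To make this concrete, here is a minimal sketch of that forward pass in Python. The weights are hypothetical, chosen only to be consistent with the values shown in the example (one input, two hidden nodes with weights 1 and -1, and three outputs):

```python
# Hypothetical weights consistent with the talk's example figure.
W1 = [[1], [-1]]               # first layer: 2 hidden nodes x 1 input
W2 = [[1, 1], [1, 0], [0, 1]]  # second layer: 3 outputs x 2 hidden nodes

def relu(v):
    # The activation described: 0 on negative inputs, identity otherwise.
    return [max(0, x) for x in v]

def layer(W, x):
    # Multiply each node's incoming edge weights by the previous values.
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]

def dnn(x):
    hidden = relu(layer(W1, [x]))  # activation on the hidden layer
    return layer(W2, hidden)       # identity activation on the output layer

print(dnn(1))  # -> [1, 1, 0]
```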
27
00:01:48,940 --> 00:01:54,010
When working with DNNs, it's often nice to
think of them not as graphs but rather in
28
00:01:54,010 --> 00:01:55,010
terms of matrices.
29
00:01:55,010 --> 00:02:00,200
Specifically, we can organize each layer of
edge weights into a layer weight matrix. For example,
30
00:02:00,200 --> 00:02:04,230
the first layer has two edge weights which
we organize into the matrix W^1. We can do
31
00:02:04,230 --> 00:02:11,090
a similar thing with the second layer, producing
weight matrix W^2. The DNN is then uniquely
32
00:02:11,090 --> 00:02:15,379
defined by these weight matrices as well as the associated
activation functions sigma^1 and sigma^2.
33
00:02:15,379 --> 00:02:18,949
Currently, to verify a DNN you would input it directly
into a verification tool, which either states
34
00:02:18,949 --> 00:02:25,910
that the DNN is safe, provides a counterexample,
or states that it is not able to prove either.
35
00:02:25,910 --> 00:02:29,160
However, such tools can be very slow for large
DNNs.
36
00:02:29,160 --> 00:02:35,620
We propose to augment this process for large
DNNs via the use of abstraction.
37
00:02:35,620 --> 00:02:39,640
In our proposal, we start with three inputs:
a DNN, an abstract domain, and a
38
00:02:39,640 --> 00:02:45,298
partitioning of the DNN's nodes. The partitioning and abstract
domain will determine the precision of our abstraction.
39
00:02:45,298 --> 00:02:50,360
From these we produce an Abstract Neural Network
(or ANN). This ANN is smaller than the original
40
00:02:50,360 --> 00:02:52,769
DNN, but soundly over-approximates its behavior.
41
00:02:52,769 --> 00:02:57,599
We then verify this ANN using a verification
tool, and our soundness result guarantees
42
00:02:57,599 --> 00:02:59,860
that if the ANN is verified then the DNN will
be as well.
43
00:02:59,860 --> 00:03:03,209
For this talk, I will focus on the first part
of this workflow, namely, how we soundly abstract
44
00:03:03,209 --> 00:03:04,450
a given DNN into a smaller ANN.
45
00:03:04,450 --> 00:03:10,030
In order to do this, we'll have to answer
two questions.
46
00:03:10,030 --> 00:03:13,719
First, we need to define what exactly an
ANN is.
47
00:03:13,719 --> 00:03:19,000
Then, we'll need to explain exactly how to over-approximate
a large DNN with a smaller ANN while guaranteeing
48
00:03:19,000 --> 00:03:20,000
soundness.
49
00:03:20,000 --> 00:03:26,220
In the rest of this talk, I will define ANNs,
describe our algorithm, and discuss its soundness
50
00:03:26,220 --> 00:03:27,220
proof.
51
00:03:27,220 --> 00:03:32,200
Whereas a DNN is a function that associates
with each input vector an output vector,
52
00:03:32,200 --> 00:03:36,109
an ANN will be a function that associates
with each input vector a set of output vectors.
53
00:03:36,109 --> 00:03:42,040
We’ll start by describing a specific instantiation
of ANNs, called Interval Neural Networks, or INNs, which were
54
00:03:42,040 --> 00:03:45,560
actually introduced by Prabhakar et al. last
year.
55
00:03:45,560 --> 00:03:50,799
Here is an example of an INN. It's like a
DNN, except the edge weights are replaced
56
00:03:50,799 --> 00:03:51,810
with intervals.
57
00:03:51,810 --> 00:03:55,159
Given an INN, we can form an instantiation
of it by replacing each interval with a single
58
00:03:55,159 --> 00:03:59,189
scalar from that interval. Here
is one example instantiation.
59
00:03:59,189 --> 00:04:05,510
There can be infinitely many such instantiations,
so we'll let Gamma_{INN} be a function mapping
60
00:04:05,510 --> 00:04:06,760
any INN to the set of all of its instantiations.
61
00:04:06,760 --> 00:04:12,230
We now need to define the semantics of an
INN. Recall that we want an INN to output
62
00:04:12,230 --> 00:04:15,870
a set of vectors, while each instantiated
DNN returns a single vector. So we'll define
63
00:04:15,870 --> 00:04:22,170
the output of the INN on an input ‘X’
to be the set of outputs of each of its instantiations
64
00:04:22,170 --> 00:04:24,149
on the same input 'X', as shown here.
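As a rough illustration, here is a sketch of these semantics for a tiny hypothetical INN with one hidden node. Gamma_INN is infinite, so the sketch only samples a grid of instantiations to approximate the output set; the interval values here are assumptions for illustration:

```python
import itertools

# Hypothetical interval weights for a tiny INN with one hidden node.
interval_in = (-1.0, 1.0)                             # edge into the node
intervals_out = [(2.0, 2.0), (0.0, 2.0), (0.0, 2.0)]  # edges out of it

relu = lambda v: max(v, 0.0)

def instantiation_output(x, a, out_weights):
    # Evaluate one DNN instantiation (one scalar picked per interval).
    h = relu(a * x)
    return tuple(w * h for w in out_weights)

def sampled_output_set(x, steps=5):
    # Gamma_INN is infinite, so we only sample a grid of instantiations
    # to approximate the INN's output set on input x.
    grid = lambda lo, hi: [lo + (hi - lo) * i / (steps - 1) for i in range(steps)]
    return {
        instantiation_output(x, a, ws)
        for a in grid(*interval_in)
        for ws in itertools.product(*(grid(lo, hi) for lo, hi in intervals_out))
    }

# A concrete DNN output, (1, 1, 0) on input 1, shows up among them:
print((1.0, 1.0, 0.0) in sampled_output_set(1.0))  # -> True
```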
65
00:04:24,149 --> 00:04:28,780
The first major contribution of our paper
is generalizing this definition of an Interval Neural Network
66
00:04:28,780 --> 00:04:33,030
to arbitrary numerical domains, such as octagons
and polyhedra.
67
00:04:33,030 --> 00:04:38,580
To define such Abstract Neural Networks, or ANNs, we associate with each
layer some element A^i of an abstract domain.
68
00:04:38,580 --> 00:04:43,500
In particular, we want the concretization
of each A^i to be a set of matrices.
69
00:04:43,500 --> 00:04:48,330
We then define a DNN instantiation of this
ANN to be any DNN where the layer weight matrices
70
00:04:48,330 --> 00:04:53,140
belong to the concretization set of the corresponding
ANN layer’s abstract weight matrix.
71
00:04:53,140 --> 00:04:57,620
We'll similarly define Gamma_ANN to map an
ANN to the set of all such instantiations,
72
00:04:57,620 --> 00:05:02,230
and define the output of the ANN to be the
set of outputs of all the instantiations.
73
00:05:02,230 --> 00:05:07,120
With this definition, you can directly generalize
the notion of an INN to use any numerical
74
00:05:07,120 --> 00:05:11,290
domain, including intervals, difference-bound
matrices, polyhedra, and octagons.
75
00:05:11,290 --> 00:05:15,630
So far, we’ve defined ANNs.
Next, I’ll describe our algorithm for over-approximating
76
00:05:15,630 --> 00:05:20,530
large DNNs with smaller ANNs, and after that
discuss the soundness of our algorithm.
77
00:05:20,530 --> 00:05:27,160
The goal of our algorithm is to abstract a
DNN N into a smaller ANN N# which soundly over-approximates
78
00:05:27,160 --> 00:05:28,160
N.
79
00:05:28,160 --> 00:05:34,640
Specifically, we will say that the DNN N is
soundly over-approximated by the ANN N# when
80
00:05:34,640 --> 00:05:39,670
for every input, the output of N is in the
output set of N#.
81
00:05:39,670 --> 00:05:44,530
Recalling the ANN semantics, what this means
is that for every input x_1 to N, there is
82
00:05:44,530 --> 00:05:48,690
an instantiation of N# with the same output
on x_1 as N.
83
00:05:48,690 --> 00:05:52,780
We’ll start by describing the abstraction
algorithm for interval neural networks introduced
84
00:05:52,780 --> 00:05:54,560
last year by Prabhakar et al.
85
00:05:54,560 --> 00:05:58,580
The key idea is that we want to merge nodes
in the DNN together to create abstract nodes
86
00:05:58,580 --> 00:06:02,000
in the corresponding INN.
In this example, we have merged both nodes
87
00:06:02,000 --> 00:06:06,130
in the second layer together to create an
INN with only one node in the second layer.
88
00:06:06,130 --> 00:06:10,070
We then take the weights in the INN to be
the weighted convex hull of the corresponding
89
00:06:10,070 --> 00:06:11,210
weights in the DNN.
90
00:06:11,210 --> 00:06:15,260
For example, the first green weight in the
INN corresponds to the two red weights in
91
00:06:15,260 --> 00:06:20,440
the DNN, so we assign its edge interval to
be the convex hull, [-1, 1], of the two corresponding
92
00:06:20,440 --> 00:06:21,440
DNN weights.
93
00:06:21,440 --> 00:06:25,950
We apply a similar process for all other weights.
The one important caveat is that we'll scale
94
00:06:25,950 --> 00:06:30,841
each weight up by the number of incoming nodes
which were merged together --- for this weight,
95
00:06:30,841 --> 00:06:34,950
we use the interval [2, 2] instead of [1, 1]
because two nodes on the left of the weight
96
00:06:34,950 --> 00:06:39,930
were merged, so we should treat a single input
from that node as counting for two inputs,
97
00:06:39,930 --> 00:06:43,250
one from each of the original DNN nodes merged
together.
98
00:06:43,250 --> 00:06:47,880
We repeat this weighted-convex-hull process
to get all of the other weights in the INN.
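In code, this weighted-convex-hull construction might look like the following sketch. The DNN weights are hypothetical, matching the running example, and each hull is scaled by the number of merged nodes feeding the edge:

```python
# Hypothetical DNN weights matching the running example: two hidden
# nodes (merged together) and three output nodes.
W1 = [[1.0], [-1.0]]                       # edges into the two hidden nodes
W2 = [[1.0, 1.0], [1.0, 0.0], [0.0, 1.0]]  # edges out of them, per output

def hull(ws, scale=1):
    # Convex hull of a set of scalar weights, scaled up by `scale`.
    return (scale * min(ws), scale * max(ws))

# Edges into the merged node: plain hull, since the input node is unmerged.
merged_in = hull([row[0] for row in W1])

# Edges out of the merged node: hull of each output's two weights, scaled
# by 2 because two nodes on the incoming side were merged together.
merged_out = [hull(row, scale=2) for row in W2]

print(merged_in)   # -> (-1.0, 1.0)
print(merged_out)  # -> [(2.0, 2.0), (0.0, 2.0), (0.0, 2.0)]
```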
99
00:06:47,880 --> 00:06:51,750
Our paper generalizes this approach to work
with any abstract domain A.
100
00:06:51,750 --> 00:06:56,130
Suppose we want to abstract this three-layer
DNN by merging the middle two nodes to produce
101
00:06:56,130 --> 00:06:57,720
the ANN on the right.
102
00:06:57,720 --> 00:07:01,520
Our algorithm will construct abstract elements
A1 and A2.
103
00:07:01,520 --> 00:07:07,150
Specifically, it'll construct A1 such that
all matrices of this type belong to the concretization
104
00:07:07,150 --> 00:07:12,000
set of A1. Each of these matrices is formed
by taking the original layer matrix from the
105
00:07:12,000 --> 00:07:16,340
first DNN layer, then applying some convex
combination to the outputs (because we merged
106
00:07:16,340 --> 00:07:21,470
the two output nodes of that first layer). The algorithm
will produce an element A1 that includes
107
00:07:21,470 --> 00:07:24,620
all such combined matrices in its concretization
set.
108
00:07:24,620 --> 00:07:29,640
Similarly, the algorithm constructs A2 such
that its concretization set includes all of
109
00:07:29,640 --> 00:07:34,960
these matrices. Each one starts with the initial
second-layer weight matrix from the DNN, then
110
00:07:34,960 --> 00:07:39,450
takes a convex combination of the two input
dimensions. Finally, we weight the convex
111
00:07:39,450 --> 00:07:42,880
combination by the number of inputs merged,
exactly like with the INN algorithm.
112
00:07:42,880 --> 00:07:47,390
So here we have defined the abstract elements
A1 and A2 in terms of their concretization sets,
113
00:07:47,390 --> 00:07:50,500
but we can turn this into a constructive definition
by applying alpha to both sides.
114
00:07:50,500 --> 00:07:56,551
The one caveat is that these sets are infinite,
so we cannot just compute each member of the
115
00:07:56,551 --> 00:07:59,880
set and then apply the alpha function to compute
A1 and A2.
116
00:07:59,880 --> 00:08:04,020
Thankfully, one of the key theorems in our
paper says that, as long as your abstract
117
00:08:04,020 --> 00:08:09,110
domain A is convex, you get the exact same
A1, A2 if you only use the extreme points.
118
00:08:09,110 --> 00:08:13,711
Most common numerical domains, such as polyhedra,
octagons, and intervals, are convex.
119
00:08:13,711 --> 00:08:16,948
In other words, we can replace the intervals
in this definition with their sets of extreme points.
120
00:08:16,948 --> 00:08:22,848
So to compute A1 we only need to compute the two matrices
in this set and abstract with alpha, and similarly for
121
00:08:22,848 --> 00:08:24,080
computing A2.
122
00:08:24,080 --> 00:08:28,750
So this gives a constructive algorithm for
computing this corresponding ANN.
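For the interval domain, that constructive step is tiny: by the convexity theorem we only need the two extreme matrices (convex-combination coefficients 0 and 1) and then apply alpha, which for intervals is the elementwise hull. The weight matrix here is a hypothetical example:

```python
# Hypothetical 2x2 first-layer weight matrix whose two output rows
# are being merged into one abstract node.
W1 = [[1.0, 2.0], [3.0, -1.0]]

# The extreme convex combinations (alpha = 1 and alpha = 0) are just the
# two original rows; alpha for intervals is the elementwise hull of them.
A1 = [(min(a, b), max(a, b)) for a, b in zip(*W1)]

print(A1)  # -> [(1.0, 3.0), (-1.0, 2.0)]
```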
123
00:08:28,750 --> 00:08:32,250
We've so far defined ANNs
and described an algorithm for merging nodes
124
00:08:32,250 --> 00:08:37,389
in a large DNN to construct a smaller ANN.
For the remainder of the talk, we'll describe
125
00:08:37,389 --> 00:08:41,580
when this smaller ANN soundly over-approximates
the corresponding DNN.
126
00:08:41,580 --> 00:08:45,100
We’ll first gain some intuition by focusing
on the special case of the interval neural
127
00:08:45,100 --> 00:08:48,589
network we saw before, then expand this into
a more general argument.
128
00:08:48,589 --> 00:08:52,879
Let’s walk through a small example to see
this in action. On the left we have the DNN
129
00:08:52,879 --> 00:08:56,709
we saw earlier, and on the right we have an
interval ANN formed by applying the algorithm
130
00:08:56,709 --> 00:09:00,870
in our paper to this DNN. Let’s assume we
used the activation function sigma^h shown
131
00:09:00,870 --> 00:09:05,519
here, which, again, outputs 0 on a negative input
and acts as the identity on all other inputs.
132
00:09:05,519 --> 00:09:09,850
Note that this activation function does
indeed satisfy the weakened intermediate value
133
00:09:09,850 --> 00:09:15,279
property and always has non-negative outputs,
so our sufficiency theorem states the corresponding
134
00:09:15,279 --> 00:09:17,740
ANN should over-approximate the DNN.
135
00:09:17,740 --> 00:09:24,010
To understand why, consider any positive input
X to this DNN. I will claim that this valid
136
00:09:24,010 --> 00:09:28,110
instantiation of the ANN has the same output
as the DNN on ‘X’.
137
00:09:28,110 --> 00:09:32,160
We can evaluate the DNN, finding that the
input to the first node in the second layer
138
00:09:32,160 --> 00:09:36,740
is positive, so it gets passed through, while
the input to the second node is negative so
139
00:09:36,740 --> 00:09:38,130
it outputs zero.
140
00:09:38,130 --> 00:09:42,980
Then the final output of the DNN is the vector
x/x/0.
141
00:09:42,980 --> 00:09:47,660
Similarly, on this instantiation we have that
the output of the middle node is X over 2,
142
00:09:47,660 --> 00:09:51,040
and the output is exactly the same vector
x/x/0 as the DNN.
143
00:09:51,040 --> 00:09:57,269
The key here is that, in the instantiation,
the merged node here, corresponding to these
144
00:09:57,269 --> 00:10:03,200
two second-layer nodes, takes the average value
of both of those nodes, which is then scaled
145
00:10:03,200 --> 00:10:06,170
back up by the final layer.
146
00:10:06,170 --> 00:10:11,860
On the other hand, if x is negative or zero,
we can use this instantiation.
147
00:10:11,860 --> 00:10:20,660
We can perform the computation to see that
the output of the DNN on a negative input is -x/0/-x and that
148
00:10:20,660 --> 00:10:24,110
we get exactly the same output from this instantiation.
149
00:10:24,110 --> 00:10:29,100
So, for any input, there is an instantiation
of the ANN which has the same output as the
150
00:10:29,100 --> 00:10:33,470
original DNN. Therefore, we say the ANN over-approximates
the DNN.
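We can replay that case split in a few lines of Python. The concrete weights are assumptions reconstructed from the talk's figures; the point is only that for each input we can pick instantiation weights, within the intervals, that realize the merged node's average and then scale it back up:

```python
# A sketch of the case split above, using hypothetical concrete weights
# reconstructed from the talk's figures.
relu = lambda v: max(v, 0.0)

def dnn(x):
    # The running example DNN: hidden nodes compute relu(x) and relu(-x).
    h = [relu(x), relu(-x)]
    return [h[0] + h[1], h[0], h[1]]

def inn_instantiation(x, a, b, c):
    # One instantiation of the INN: a in [-1, 1], b and c in [0, 2],
    # and the first output weight is pinned by the interval [2, 2].
    h = relu(a * x)
    return [2 * h, b * h, c * h]

for x in [3.0, -3.0, 0.0]:
    if x > 0:
        inst = inn_instantiation(x, 0.5, 2.0, 0.0)   # merged node holds x/2
    else:
        inst = inn_instantiation(x, -0.5, 0.0, 2.0)  # merged node holds -x/2
    assert inst == dnn(x), (x, inst, dnn(x))
print("each sampled input has a matching instantiation")
```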
151
00:10:33,470 --> 00:10:37,939
Let’s now recall the more general ANN abstraction,
which over-approximated the DNN
152
00:10:37,939 --> 00:10:43,800
on the top-left here with the ANN on the top-right,
where A1 and A2 were defined according to the
153
00:10:43,800 --> 00:10:45,709
two inclusions shown here.
154
00:10:45,709 --> 00:10:48,420
Suppose we have an input X1, X2 to the DNN.
155
00:10:48,420 --> 00:10:52,350
Following the DNN semantics, this input vector
is multiplied by the first layer’s weight
156
00:10:52,350 --> 00:10:58,399
matrix to get the input vector u1/u2 for the
second layer. Then we apply the activation
157
00:10:58,399 --> 00:11:04,899
function sigma to get the output vector v1/v2
of the second layer. Finally, we apply the
158
00:11:04,899 --> 00:11:09,850
output layer’s weight matrix to get the
output y1/y2 of the DNN. For simplicity here
159
00:11:09,850 --> 00:11:14,940
we have assumed the output activation function
is just the identity.
160
00:11:14,940 --> 00:11:17,970
The ANN produced by our algorithm replaces
the two weight matrices with elements A1,
161
00:11:17,970 --> 00:11:20,190
A2 of an abstract domain.
162
00:11:20,190 --> 00:11:26,230
To show that the ANN over-approximates the
DNN on this input, we need to show that y1/y2
163
00:11:26,230 --> 00:11:31,149
is in the output set of the ANN on an input
of x1/x2.
164
00:11:31,149 --> 00:11:36,559
This in turn is equivalent to showing that
there exists some instantiation of the ANN which also
165
00:11:36,559 --> 00:11:38,350
maps x1/x2 to y1/y2.
166
00:11:38,350 --> 00:11:46,480
To do this, we'll first show that, for any
scalar u-sharp between u1 and u2, we can instantiate
167
00:11:46,480 --> 00:11:50,579
the first layer, A1, such that x1/x2 maps
to u-sharp.
168
00:11:50,579 --> 00:11:56,369
We'll then show that there is at least one such
u-sharp which maps to the average of v1 and
169
00:11:56,369 --> 00:11:57,369
v2 under sigma.
170
00:11:57,369 --> 00:12:04,100
Finally, we will show that there is an instantiation
of A2 which maps this average to the final
171
00:12:04,100 --> 00:12:05,100
desired output y1/y2.
172
00:12:05,100 --> 00:12:09,369
The key idea here is that after each layer
in the instantiation we want to realize the
173
00:12:09,369 --> 00:12:17,139
average of the nodes that are merged together.
Then, because the input and output nodes aren’t
174
00:12:17,139 --> 00:12:21,930
merged, their averages are just the original
inputs and outputs to the network, giving
175
00:12:21,930 --> 00:12:22,930
us the desired soundness property.
176
00:12:22,930 --> 00:12:23,930
Let’s take a closer look at each of these
arguments.
177
00:12:23,930 --> 00:12:27,680
The first thing we need to show is that, for
any u-sharp between u1 and u2 we can instantiate
178
00:12:27,680 --> 00:12:31,880
A1 to map x1/x2 to u-sharp.
179
00:12:31,880 --> 00:12:36,700
Recall that we constructed A1 to include all
matrices of this form, so it suffices to show
180
00:12:36,700 --> 00:12:40,019
one of those matrices maps x1/x2 to u-sharp.
181
00:12:40,019 --> 00:12:42,379
We can write the output of the instantiation
like this.
182
00:12:42,379 --> 00:12:46,330
And, because multiplication is associative,
the rightmost multiplication just becomes
183
00:12:46,330 --> 00:12:49,149
the output of the DNN u1/u2 itself.
184
00:12:49,149 --> 00:12:54,790
Expanding out this matrix multiplication, we
see that varying alpha allows us to hit any
185
00:12:54,790 --> 00:12:56,379
point on the line between u1 and u2.
186
00:12:56,379 --> 00:13:02,160
But u-sharp is assumed to be between u1 and
u2, so there is always some alpha for which
187
00:13:02,160 --> 00:13:03,269
the output is u-sharp.
188
00:13:03,269 --> 00:13:09,220
Hence, for any u-sharp you pick between u1 and u2,
there will be some instantiation layer W^1
189
00:13:09,220 --> 00:13:12,350
of A1 that maps x1/x2 to u-sharp.
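Here is a small numeric check of that interpolation argument, using a hypothetical first-layer matrix: an instantiation applies a convex combination [alpha, 1 - alpha] to the two merged rows, so solving for alpha lets us hit any u-sharp between u1 and u2:

```python
# Hypothetical 2x2 first-layer matrix whose two output nodes are merged.
W1 = [[1.0, 2.0], [3.0, -1.0]]
x = [1.0, 1.0]

dot = lambda row, v: sum(r * vi for r, vi in zip(row, v))
u1, u2 = dot(W1[0], x), dot(W1[1], x)  # values of the two merged nodes

def alpha_for(u_sharp):
    # Solve alpha*u1 + (1 - alpha)*u2 = u_sharp (assumes u1 != u2).
    return (u_sharp - u2) / (u1 - u2)

u_sharp = (u1 + u2) / 2                # any target between u1 and u2 works
alpha = alpha_for(u_sharp)

# Instantiate the merged row as the alpha-combination of the two rows;
# by construction it maps x exactly to u_sharp.
inst_row = [alpha * a + (1 - alpha) * b for a, b in zip(W1[0], W1[1])]
print(dot(inst_row, x) == u_sharp)  # -> True
```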
190
00:13:12,350 --> 00:13:16,449
The next thing we need to show is that at
least one of those u-sharps gets mapped to
191
00:13:16,449 --> 00:13:19,950
the average of v1 and v2 after applying the
activation function sigma.
192
00:13:19,950 --> 00:13:25,420
As we'll see later, this is not, in fact,
generally true for all possible activation
193
00:13:25,420 --> 00:13:26,420
functions.
194
00:13:26,420 --> 00:13:29,740
Instead, we'll assume for now that sigma
satisfies the Weakened Intermediate Value
195
00:13:29,740 --> 00:13:34,430
Property, which is defined in detail in our
paper, but is effectively equivalent to the
196
00:13:34,430 --> 00:13:36,170
desired fact.
197
00:13:36,170 --> 00:13:40,529
Thankfully, all continuous functions satisfy
the Weakened Intermediate Value Property and
198
00:13:40,529 --> 00:13:44,170
so pretty much all commonly-used activation functions
work with our argument.
199
00:13:44,170 --> 00:13:49,779
So now we’ve shown that we can get from
x1/x2 to the average of v1 and v2 in the instantiation.
200
00:13:49,779 --> 00:13:54,410
For the very final layer, we want to show
that we can instantiate A2 with a matrix that
201
00:13:54,410 --> 00:13:58,160
maps this average of v1 and v2 to the original
output y1/y2.
202
00:13:58,160 --> 00:14:02,620
Again, we know that A2 was constructed to
have all matrices of this type in its concretization
203
00:14:02,620 --> 00:14:09,140
set, so we just need to show that one of them
maps the mean of v1/v2 to the vector y1/y2.
204
00:14:09,140 --> 00:14:12,879
We can write the output of any one of those
concretizations like this, and cancelling
205
00:14:12,879 --> 00:14:14,829
the 2s gives us a slightly simpler expression.
206
00:14:14,829 --> 00:14:19,760
From here, we use one of the key lemmas in
the paper. Namely, if v1 and v2 are non-negative,
207
00:14:19,760 --> 00:14:24,869
then there is some beta between 0 and 1 for
which the rightmost multiplication becomes
208
00:14:24,869 --> 00:14:29,829
the vector v1/v2. Essentially, from the mean
we can reconstruct each individual element
209
00:14:29,829 --> 00:14:32,000
using beta between 0 and 1.
210
00:14:32,000 --> 00:14:35,829
From there, we see that this gives us back
exactly the same computation from the original
211
00:14:35,829 --> 00:14:37,790
DNN layer, hence outputting y1/y2 as desired.
212
00:14:37,790 --> 00:14:43,290
To give some intuition behind where this beta
comes from, in this scenario you can take
213
00:14:43,290 --> 00:14:47,980
beta to be v1 over the sum of v1 and v2 to see
that the equality holds and that beta is between
214
00:14:47,980 --> 00:14:49,069
zero and one.
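This witness is easy to check directly. Assuming non-negative v1 and v2 that are not both zero, beta = v1 / (v1 + v2) lies in [0, 1] and splits the doubled mean back into the original pair:

```python
def reconstruct(v1, v2):
    # The lemma's witness: for non-negative v1, v2 (not both zero),
    # beta = v1 / (v1 + v2) recovers both values from their mean, since
    # 2*mean*beta = v1 and 2*mean*(1 - beta) = v2.
    mean = (v1 + v2) / 2
    beta = v1 / (v1 + v2)
    assert 0.0 <= beta <= 1.0
    return (2 * mean * beta, 2 * mean * (1 - beta))

print(reconstruct(3.0, 1.0))  # -> (3.0, 1.0)
```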
215
00:14:49,069 --> 00:14:53,820
So what we’ve shown is that, under reasonable
conditions, you can instantiate the last layer’s
216
00:14:53,820 --> 00:15:00,720
abstract weights A2 to some matrix W2 which maps
the mean of v1/v2 to the desired output y1/y2,
217
00:15:00,720 --> 00:15:07,639
and then we've also shown that there's some u-sharp
between u1 and u2 which maps to that mean under sigma, and finally
218
00:15:07,639 --> 00:15:13,079
there is some instantiation W1 of the first layer’s
abstract weights A1 that actually produces
219
00:15:13,079 --> 00:15:15,240
that u-sharp on an input of x1/x2.
220
00:15:15,240 --> 00:15:21,309
Hence, overall we have shown that for any
input x1/x2, there exists an instantiation of
221
00:15:21,309 --> 00:15:25,120
the ANN which produces exactly the same output
as the original DNN.
222
00:15:25,120 --> 00:15:29,309
With some subtleties handled in the paper,
what this argument shows is that:
223
00:15:29,309 --> 00:15:35,220
If the activation function satisfies the Weakened
Intermediate Value Property, and if all post-activation intermediate
224
00:15:35,220 --> 00:15:40,750
values while computing the DNN’s output are
non-negative, then the ANN constructed by
225
00:15:40,750 --> 00:15:43,649
our algorithm soundly over-approximates the
original DNN.
226
00:15:43,649 --> 00:15:47,899
We may wonder if both of these conditions
are in fact necessary, and it turns out they
227
00:15:47,899 --> 00:15:48,899
are.
228
00:15:48,899 --> 00:15:52,540
Specifically, if the activation functions
do not satisfy the Weakened Intermediate Value
229
00:15:52,540 --> 00:15:56,689
Property, or if some of the intermediate values are
positive while others are negative, then we
230
00:15:56,689 --> 00:16:00,860
can construct a DNN using those activation
functions for which the corresponding ANN
231
00:16:00,860 --> 00:16:04,660
does not over-approximate it. While this may
seem worrying, in the paper we go over some
232
00:16:04,660 --> 00:16:09,130
extensions to the algorithm which can support
even DNNs that violate these necessary conditions.
233
00:16:09,130 --> 00:16:16,095
Let’s first take a look at an example where
the intermediate values do not all have the same sign.
234
00:16:16,095 --> 00:16:21,499
Here we have the same DNN/INN pair we saw before,
except now we're using the identity activation function.
235
00:16:21,499 --> 00:16:26,309
Consider what happens when we give this
DNN an input of 1.
236
00:16:26,309 --> 00:16:29,869
Now, the first node in the second layer outputs
1, while the second node in that layer outputs
237
00:16:29,869 --> 00:16:32,640
-1, giving us mixed signs which violates the
conditions.
238
00:16:32,640 --> 00:16:38,249
Let’s look at any instantiation of the ANN,
where here the weight ‘A’ is between -1
239
00:16:38,249 --> 00:16:41,829
and 1 and ‘b’ and ‘c’ are between
0 and 2.
240
00:16:41,829 --> 00:16:46,389
Then the output of the instantiation will
be 2a/ba/ca. But we need this output to be
241
00:16:46,389 --> 00:16:51,120
0/1/-1, which then implies we must take a
to be 0.
242
00:16:51,120 --> 00:16:56,149
But then all the other outputs must be zero
as well, so there’s no way we can instantiate
243
00:16:56,149 --> 00:16:58,420
this ANN to have the same output as the DNN.
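A brute-force search over instantiations makes the failure visible. Every instantiation here outputs (2a, b*a, c*a), so matching the target (0, 1, -1) would force a = 0 and hence all-zero outputs; sampling confirms no instantiation gets anywhere close:

```python
def closest_gap(steps=21):
    # Grid-sample instantiations: a in [-1, 1], b and c in [0, 2].
    grid = [i / (steps - 1) for i in range(steps)]
    target = (0.0, 1.0, -1.0)   # the DNN's output on input 1
    best = float("inf")
    for a in (-1 + 2 * g for g in grid):
        for b in (2 * g for g in grid):
            for c in (2 * g for g in grid):
                out = (2 * a, b * a, c * a)
                gap = max(abs(o - t) for o, t in zip(out, target))
                best = min(best, gap)
    return best

print(closest_gap())  # -> 1.0: every sampled instantiation misses by at least 1
```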
244
00:16:58,420 --> 00:17:02,970
This shows directly how, whenever the DNN
can have mixed signs, soundness of the abstraction
245
00:17:02,970 --> 00:17:04,610
algorithm is not guaranteed.
246
00:17:04,610 --> 00:17:09,395
Thankfully, in our paper we describe an automated
process to rewrite essentially any DNN into
247
00:17:09,395 --> 00:17:14,225
an entirely equivalent one which uses only
positive values.
248
00:17:14,225 --> 00:17:18,661
Let's now look at an example using an activation function
that violates the Weakened Intermediate Value Property.
249
00:17:18,661 --> 00:17:23,528
This is the same DNN/INN pair we saw before,
except now we're using the activation function
250
00:17:23,528 --> 00:17:29,559
that outputs 1 for any input greater than or equal to 1
and zero otherwise.
251
00:17:29,559 --> 00:17:33,647
We can compute the output of this DNN on an input of 1
to be 1/1/0.
252
00:17:33,647 --> 00:17:38,549
What we find from any instantiation is that,
no matter what the input to the middle node is,
253
00:17:38,549 --> 00:17:40,930
its output will be either zero or one.
254
00:17:40,930 --> 00:17:46,419
But then the first output node will be either
zero or two --- which cannot be the desired
255
00:17:46,419 --> 00:17:47,419
value of 1.
256
00:17:47,419 --> 00:17:51,049
So the takeaway from this example is that
if your activation function has essentially
257
00:17:51,049 --> 00:17:53,571
big-enough gaps, then soundness is not guaranteed.
258
00:17:53,571 --> 00:17:58,730
The workaround we propose in our paper for
this is to effectively ‘lift’ the activation
259
00:17:58,730 --> 00:18:01,130
function to a set-valued one to fill in
those gaps.
260
00:18:01,130 --> 00:18:05,059
But it’s really important to note that this
is not a very common thing to run into,
261
00:18:05,059 --> 00:18:09,529
because any activation function with these
gaps would be discontinuous and so you would
262
00:18:09,529 --> 00:18:12,600
have trouble even applying gradient descent
to train such a DNN in the first place.
263
00:18:13,634 --> 00:18:17,861
In this talk and our associated paper, we defined
Abstract Neural Networks, or ANNs.
264
00:18:17,861 --> 00:18:22,799
We presented an algorithm for soundly over-approximating
a large DNN with a smaller ANN,
265
00:18:22,799 --> 00:18:27,600
and we described necessary and sufficient conditions
for that algorithm to be sound.
266
00:18:27,600 --> 00:18:30,377
There are two main areas of future work building on
the foundations laid out in our paper.
267
00:18:30,377 --> 00:18:35,946
First is the problem of efficiently verifying an ANN
once it has been constructed.
268
00:18:36,042 --> 00:18:41,353
One proposed approach for a special case of interval
neural networks was described and evaluated by Prabhakar et al. last year.
269
00:18:41,353 --> 00:18:46,924
Extending such an approach to octagon and polyhedral
neural networks would be an interesting direction for future research.
270
00:18:46,924 --> 00:18:49,799
Another question is what to do if the ANN cannot be verified.
271
00:18:49,799 --> 00:18:54,105
In this scenario we may want to refine either our
abstract domain or the partitioning used.
272
00:18:54,105 --> 00:18:58,253
For example, instead of merging all of the nodes in the
second layer together, we might only merge half of them.
273
00:18:58,253 --> 00:19:06,511
Our work lays strong theoretical foundations for this kind of
future experimentation and applications in the area of Abstract Neural Networks.
274
00:19:06,511 --> 00:19:09,900
All of our source code is available on GitHub at this link.