In [1]:
import jax.numpy as jnp
import plotly.express as px
from plotly.subplots import make_subplots
import jax
import numpy as np
from datasets import mnist
import plotly.graph_objects as go

import copy
In [2]:
train_images, train_labels, test_images, test_labels = mnist()

train_images = train_images.astype(jnp.float32)
test_images = test_images.astype(jnp.float32)

train_labels = jnp.asarray(train_labels, dtype=jnp.int32)
test_labels = jnp.asarray(test_labels, dtype=jnp.int32)
In [3]:
def visualize_images(images_tensor, w=28, h=28, col_wrap=5):
    img = images_tensor.reshape(-1, h, w)
    fig = px.imshow(img, binary_string=False, facet_col=0, facet_col_wrap=col_wrap)
    # strip the automatic "facet_col=n" annotations above each panel
    fig.for_each_annotation(lambda a: a.update(text=""))
    fig.show()
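A quick sanity check (a minimal usage sketch; any small slice of the training set works):

visualize_images(train_images[:10])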
In [5]:
net_parameters = {
    'w0' : np.random.randn(256, 784) * 0.1,
    'w1' : np.random.randn(256, 256) * 0.1,
    'w2' : np.random.randn(256, 256) * 0.1,
    'w3' : np.random.randn(10, 256) * 0.1,
}
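As an aside, the NumPy initializer above is unseeded, so runs are not reproducible. A reproducible alternative using JAX's own PRNG would look like this (a sketch with an arbitrary seed, not the initialization used for the runs below):

key = jax.random.PRNGKey(0)
k0, k1, k2, k3 = jax.random.split(key, 4)
net_parameters = {
    'w0': jax.random.normal(k0, (256, 784)) * 0.1,
    'w1': jax.random.normal(k1, (256, 256)) * 0.1,
    'w2': jax.random.normal(k2, (256, 256)) * 0.1,
    'w3': jax.random.normal(k3, (10, 256)) * 0.1,
}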
In [6]:
def ReLU(x):
    return jnp.maximum(0,x)

def forward(parameters, x):
    x = x.T
    x = parameters['w0'] @ x
    x = ReLU(x)
    x = parameters['w1'] @ x
    x = ReLU(x)
    x = parameters['w2'] @ x
    x = ReLU(x)
    x = parameters['w3'] @ x
    x = x.T
    return x
In [7]:
def loss(parameters, x, y):
    # cross-entropy between the softmax outputs and the one-hot labels
    out = forward(parameters, x)
    out = jax.nn.softmax(out)
    _loss = -(y * jnp.log(out)).sum(axis=-1).mean()
    return _loss

loss(net_parameters, test_images, test_labels)
Out[7]:
Array(2.731128, dtype=float32)
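One caveat with the formulation above: applying log after softmax can underflow to log(0) for very confident predictions. A numerically stabler equivalent (a sketch, not what was run here) fuses the two steps with jax.nn.log_softmax:

def stable_loss(parameters, x, y):
    # log-softmax computes log(softmax(x)) without materializing tiny probabilities
    log_probs = jax.nn.log_softmax(forward(parameters, x))
    return -(y * log_probs).sum(axis=-1).mean()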
In [8]:
# accuracy of the untrained network: roughly chance level for 10 classes
(forward(net_parameters, train_images).argmax(axis=-1) == train_labels.argmax(axis=-1)).mean()
Out[8]:
Array(0.06178333, dtype=float32)
In [9]:
grad_loss = jax.grad(loss)
lr = 0.1

# keep track of all the previous gradients
grad_history = []

# keep track of the change applied to the parameters at each step
parameter_deltas = []

for epoch in range(100):
    
    p_grad = grad_loss(net_parameters, train_images, train_labels)
    grad_history.append(p_grad)
    
    net_parameters['w0'] -= lr * p_grad['w0']
    net_parameters['w1'] -= lr * p_grad['w1']
    net_parameters['w2'] -= lr * p_grad['w2']
    net_parameters['w3'] -= lr * p_grad['w3']

    # record the changes that were made to the parameters for analysis
    parameter_delta_vector = np.concatenate((
        -lr * p_grad['w0'].flatten(),
        -lr * p_grad['w1'].flatten(),
        -lr * p_grad['w2'].flatten(),
        -lr * p_grad['w3'].flatten()
    ))
    parameter_deltas.append(parameter_delta_vector)



    print(f"epoch {epoch}")
    print(f"validation loss: {loss(net_parameters, test_images, test_labels)}")
    print(f"train loss: {loss(net_parameters, train_images, train_labels)}")
    acc = (forward(net_parameters, train_images).argmax(axis=-1) == train_labels.argmax(axis=-1)).mean()
    print(f"accuracy: {acc}")
    print("\n")
epoch 0
validation loss: 2.277862548828125
train loss: 2.276801347732544
accuracy: 0.20038333535194397


epoch 1
validation loss: 1.9986860752105713
train loss: 2.0062105655670166
accuracy: 0.28600001335144043


epoch 2
validation loss: 1.8096214532852173
train loss: 1.8151549100875854
accuracy: 0.3808833360671997


epoch 3
validation loss: 1.7291929721832275
train loss: 1.7464510202407837
accuracy: 0.4345833361148834


epoch 4
validation loss: 1.49941086769104
train loss: 1.5121557712554932
accuracy: 0.5228833556175232


epoch 5
validation loss: 1.4020278453826904
train loss: 1.424275279045105
accuracy: 0.5604666471481323


epoch 6
validation loss: 1.3016060590744019
train loss: 1.3177744150161743
accuracy: 0.5866666436195374


epoch 7
validation loss: 1.2220079898834229
train loss: 1.247183084487915
accuracy: 0.609499990940094


epoch 8
validation loss: 1.157178282737732
train loss: 1.175179362297058
accuracy: 0.6294500231742859


epoch 9
validation loss: 1.0745911598205566
train loss: 1.1006921529769897
accuracy: 0.6504666805267334


epoch 10
validation loss: 1.0312285423278809
train loss: 1.0506861209869385
accuracy: 0.6718666553497314


epoch 11
validation loss: 0.9537557363510132
train loss: 0.9796228408813477
accuracy: 0.6926500201225281


epoch 12
validation loss: 0.9217555522918701
train loss: 0.9424632787704468
accuracy: 0.7120500206947327


epoch 13
validation loss: 0.8608860373497009
train loss: 0.8858886957168579
accuracy: 0.726983368396759


epoch 14
validation loss: 0.8320112824440002
train loss: 0.8535038232803345
accuracy: 0.7462166547775269


epoch 15
validation loss: 0.7871823906898499
train loss: 0.8111382126808167
accuracy: 0.7528499960899353


epoch 16
validation loss: 0.7618271708488464
train loss: 0.7834699153900146
accuracy: 0.7703499794006348


epoch 17
validation loss: 0.7334457635879517
train loss: 0.7562983632087708
accuracy: 0.76746666431427


epoch 18
validation loss: 0.7144235372543335
train loss: 0.735580325126648
accuracy: 0.7815499901771545


epoch 19
validation loss: 0.7055370211601257
train loss: 0.7271606922149658
accuracy: 0.7684000134468079


epoch 20
validation loss: 0.6890653967857361
train loss: 0.7091876268386841
accuracy: 0.7810333371162415


epoch 21
validation loss: 0.6934828758239746
train loss: 0.7136143445968628
accuracy: 0.7663333415985107


epoch 22
validation loss: 0.6664817929267883
train loss: 0.6853572726249695
accuracy: 0.7836333513259888


epoch 23
validation loss: 0.6628146171569824
train loss: 0.681693434715271
accuracy: 0.7766333222389221


epoch 24
validation loss: 0.6299082636833191
train loss: 0.6479393839836121
accuracy: 0.7970499992370605


epoch 25
validation loss: 0.6206050515174866
train loss: 0.6385985612869263
accuracy: 0.7936999797821045


epoch 26
validation loss: 0.5956608057022095
train loss: 0.612824559211731
accuracy: 0.810283362865448


epoch 27
validation loss: 0.5865873694419861
train loss: 0.6037764549255371
accuracy: 0.807033360004425


epoch 28
validation loss: 0.568708598613739
train loss: 0.5849623680114746
accuracy: 0.8195500373840332


epoch 29
validation loss: 0.5596689581871033
train loss: 0.5761591196060181
accuracy: 0.8181333541870117


epoch 30
validation loss: 0.5462689995765686
train loss: 0.5616114735603333
accuracy: 0.82833331823349


epoch 31
validation loss: 0.5368631482124329
train loss: 0.5527257919311523
accuracy: 0.8262166976928711


epoch 32
validation loss: 0.5265526175498962
train loss: 0.541017472743988
accuracy: 0.8357666730880737


epoch 33
validation loss: 0.5168095827102661
train loss: 0.5320915579795837
accuracy: 0.8339666724205017


epoch 34
validation loss: 0.5087394118309021
train loss: 0.5223720669746399
accuracy: 0.8425499796867371


epoch 35
validation loss: 0.49895179271698
train loss: 0.513683557510376
accuracy: 0.8413166999816895


epoch 36
validation loss: 0.49255242943763733
train loss: 0.505401074886322
accuracy: 0.8484833240509033


epoch 37
validation loss: 0.4828387498855591
train loss: 0.497036337852478
accuracy: 0.8472999930381775


epoch 38
validation loss: 0.47769641876220703
train loss: 0.4898192882537842
accuracy: 0.8543333411216736


epoch 39
validation loss: 0.4683235287666321
train loss: 0.48197564482688904
accuracy: 0.8528000116348267


epoch 40
validation loss: 0.4641515016555786
train loss: 0.47560182213783264
accuracy: 0.8591166734695435


epoch 41
validation loss: 0.45532241463661194
train loss: 0.4684279263019562
accuracy: 0.8571500182151794


epoch 42
validation loss: 0.4518907964229584
train loss: 0.4627283215522766
accuracy: 0.8634333610534668


epoch 43
validation loss: 0.44364094734191895
train loss: 0.4562109708786011
accuracy: 0.8617166876792908


epoch 44
validation loss: 0.4408518373966217
train loss: 0.4511384963989258
accuracy: 0.8675333261489868


epoch 45
validation loss: 0.43332454562187195
train loss: 0.4453541040420532
accuracy: 0.8657333254814148


epoch 46
validation loss: 0.43103447556495667
train loss: 0.4408362805843353
accuracy: 0.8708666563034058


epoch 47
validation loss: 0.4243217706680298
train loss: 0.4358236789703369
accuracy: 0.8690000176429749


epoch 48
validation loss: 0.42252859473228455
train loss: 0.43190935254096985
accuracy: 0.8734166622161865


epoch 49
validation loss: 0.4167732298374176
train loss: 0.4277592599391937
accuracy: 0.8722833395004272


epoch 50
validation loss: 0.41566863656044006
train loss: 0.42469915747642517
accuracy: 0.875249981880188


epoch 51
validation loss: 0.4111364185810089
train loss: 0.42160138487815857
accuracy: 0.8744833469390869


epoch 52
validation loss: 0.41114169359207153
train loss: 0.4198891520500183
accuracy: 0.8761000037193298


epoch 53
validation loss: 0.40817520022392273
train loss: 0.418091744184494
accuracy: 0.8753499984741211


epoch 54
validation loss: 0.41018104553222656
train loss: 0.4187377095222473
accuracy: 0.8759333491325378


epoch 55
validation loss: 0.4091891646385193
train loss: 0.41852423548698425
accuracy: 0.8740666508674622


epoch 56
validation loss: 0.41472214460372925
train loss: 0.42317917943000793
accuracy: 0.87336665391922


epoch 57
validation loss: 0.41550925374031067
train loss: 0.4242120385169983
accuracy: 0.8715167045593262


epoch 58
validation loss: 0.42603418231010437
train loss: 0.4344639480113983
accuracy: 0.8686167001724243


epoch 59
validation loss: 0.4259662926197052
train loss: 0.43401408195495605
accuracy: 0.8682000041007996


epoch 60
validation loss: 0.43949228525161743
train loss: 0.4479300081729889
accuracy: 0.8633833527565002


epoch 61
validation loss: 0.43219423294067383
train loss: 0.4395970106124878
accuracy: 0.8658833503723145


epoch 62
validation loss: 0.44001853466033936
train loss: 0.4486149549484253
accuracy: 0.8625333309173584


epoch 63
validation loss: 0.4229450225830078
train loss: 0.42973002791404724
accuracy: 0.8697500228881836


epoch 64
validation loss: 0.4187626242637634
train loss: 0.42750319838523865
accuracy: 0.8693333268165588


epoch 65
validation loss: 0.40092435479164124
train loss: 0.40716466307640076
accuracy: 0.8785333633422852


epoch 66
validation loss: 0.390180379152298
train loss: 0.39872950315475464
accuracy: 0.8792499899864197


epoch 67
validation loss: 0.3786790370941162
train loss: 0.38462117314338684
accuracy: 0.8875499963760376


epoch 68
validation loss: 0.3687390983104706
train loss: 0.37683331966400146
accuracy: 0.887583315372467


epoch 69
validation loss: 0.3628421425819397
train loss: 0.3686543107032776
accuracy: 0.8930166959762573


epoch 70
validation loss: 0.355516642332077
train loss: 0.36311420798301697
accuracy: 0.8926500082015991


epoch 71
validation loss: 0.35248467326164246
train loss: 0.3581932783126831
accuracy: 0.8963333368301392


epoch 72
validation loss: 0.3471384048461914
train loss: 0.3542581796646118
accuracy: 0.895633339881897


epoch 73
validation loss: 0.345338374376297
train loss: 0.35090339183807373
accuracy: 0.8987833261489868


epoch 74
validation loss: 0.3412402868270874
train loss: 0.3479164242744446
accuracy: 0.897683322429657


epoch 75
validation loss: 0.33992841839790344
train loss: 0.3453098237514496
accuracy: 0.9003166556358337


epoch 76
validation loss: 0.336586058139801
train loss: 0.34285449981689453
accuracy: 0.899566650390625


epoch 77
validation loss: 0.3354648947715759
train loss: 0.34064599871635437
accuracy: 0.9019666910171509


epoch 78
validation loss: 0.3326025903224945
train loss: 0.3385034203529358
accuracy: 0.9012333154678345


epoch 79
validation loss: 0.3315693438053131
train loss: 0.33653268218040466
accuracy: 0.9030333161354065


epoch 80
validation loss: 0.3290373980998993
train loss: 0.33460214734077454
accuracy: 0.9025499820709229


epoch 81
validation loss: 0.328059583902359
train loss: 0.33279159665107727
accuracy: 0.9039833545684814


epoch 82
validation loss: 0.32575860619544983
train loss: 0.3310023546218872
accuracy: 0.9035166501998901


epoch 83
validation loss: 0.3248150944709778
train loss: 0.3293081820011139
accuracy: 0.9050166606903076


epoch 84
validation loss: 0.3226830065250397
train loss: 0.32762911915779114
accuracy: 0.9047166705131531


epoch 85
validation loss: 0.3217836916446686
train loss: 0.3260313868522644
accuracy: 0.9059000015258789


epoch 86
validation loss: 0.3197770416736603
train loss: 0.3244391679763794
accuracy: 0.9055666923522949


epoch 87
validation loss: 0.31892064213752747
train loss: 0.3229207694530487
accuracy: 0.9067833423614502


epoch 88
validation loss: 0.31702181696891785
train loss: 0.3214121162891388
accuracy: 0.9064666628837585


epoch 89
validation loss: 0.31621429324150085
train loss: 0.31996384263038635
accuracy: 0.907633364200592


epoch 90
validation loss: 0.3143947422504425
train loss: 0.31851959228515625
accuracy: 0.9072999954223633


epoch 91
validation loss: 0.31363949179649353
train loss: 0.3171364665031433
accuracy: 0.90829998254776


epoch 92
validation loss: 0.31189385056495667
train loss: 0.3157590627670288
accuracy: 0.9079999923706055


epoch 93
validation loss: 0.3111976981163025
train loss: 0.31443601846694946
accuracy: 0.9089666604995728


epoch 94
validation loss: 0.3095097243785858
train loss: 0.31311991810798645
accuracy: 0.9086000323295593


epoch 95
validation loss: 0.3088778257369995
train loss: 0.3118544816970825
accuracy: 0.9096166491508484


epoch 96
validation loss: 0.30723172426223755
train loss: 0.3105928599834442
accuracy: 0.9092666506767273


epoch 97
validation loss: 0.30667465925216675
train loss: 0.3093916177749634
accuracy: 0.9103333353996277


epoch 98
validation loss: 0.30506303906440735
train loss: 0.3081871271133423
accuracy: 0.9098833203315735


epoch 99
validation loss: 0.3045886754989624
train loss: 0.3070499300956726
accuracy: 0.9108999967575073
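
A side note on speed: the loop above re-derives and re-traces the gradient on every iteration. Since the update is a pure function of the parameters, the whole step can be compiled once with jax.jit. A sketch of an equivalent update (assuming the same parameter dict; not what produced the results above):

@jax.jit
def sgd_step(parameters, x, y):
    grads = jax.grad(loss)(parameters, x, y)
    # apply the SGD rule to every weight matrix in the pytree
    new_parameters = jax.tree_util.tree_map(lambda p, g: p - lr * g, parameters, grads)
    return new_parameters, grads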


In [10]:
im = 0
visualize_images(test_images[im])
forward(net_parameters, test_images[im])
Out[10]:
Array([ 0.34763533, -3.1584437 ,  1.6465373 ,  0.7349794 , -0.6372952 ,
       -0.7725915 , -3.224408  ,  9.081954  , -0.8517007 ,  2.2727547 ],      dtype=float32)
In [11]:
# the magnitude of the gradient at each training step

grad_norms = {
    'w0':[],
    'w1':[],
    'w2':[],
    'w3':[]
}
for grad_vector in grad_history:
    grad_norms['w0'].append(np.linalg.norm(grad_vector['w0'].flatten()))
    grad_norms['w1'].append(np.linalg.norm(grad_vector['w1'].flatten()))
    grad_norms['w2'].append(np.linalg.norm(grad_vector['w2'].flatten()))
    grad_norms['w3'].append(np.linalg.norm(grad_vector['w3'].flatten()))

fig = px.line(grad_norms)
fig.show()
In [12]:
# for each training step, calculate the angle between the current and previous gradient vector

grad_cosines = {
    'w0':[],
    'w1':[],
    'w2':[],
    'w3':[]
}

grad_angles = {
    'w0':[],
    'w1':[],
    'w2':[],
    'w3':[]
}

for i in range(1,len(grad_history)):
    for key in ['w0','w1','w2','w3']:
        g_i = grad_history[i][key].flatten()
        g_i_norm = g_i / np.linalg.norm(g_i)
        g_im1 = grad_history[i-1][key].flatten()
        g_im1_norm = g_im1 / np.linalg.norm(g_im1)
        cos = g_i_norm @ g_im1_norm 
        grad_cosines[key].append(cos)
        angle = np.degrees(np.arccos(cos))
        grad_angles[key].append(angle)

fig_0 = px.line(grad_cosines, title="Cosine Between Consecutive Gradients")
fig_0.show()

fig_1 = px.line(grad_angles, title="Angle Between Consecutive Gradients")
fig_1.show()
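For reference, the quantity plotted here is the cosine similarity between consecutive gradient vectors, $\cos\theta_t = \frac{g_t \cdot g_{t-1}}{\lVert g_t \rVert \, \lVert g_{t-1} \rVert}$, with the angle recovered as $\theta_t = \arccos(\cos\theta_t)$.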
In [13]:
# Find the similarity between every pair of gradients (per weight matrix)
for key in ['w0','w1','w2','w3']:
    # The history of every gradient for this parameter during the training process
    history = [gradient_dict[key] for gradient_dict in grad_history]
    # convert from a list to a NumPy array
    history = np.array(history)
    #print(history.shape) # should have a shape: (training_epochs, output_dim, input_dim)
    
    training_epochs, output_dim, input_dim = history.shape
    history = history.reshape(training_epochs, output_dim * input_dim)

    # normalize the gradient vector for each time step
    magnitudes = np.linalg.norm(history, axis=-1)
    history = (history.T / magnitudes).T

    # find the cosine of the angle between the gradient of each step, and each other step
    similarity_matrix = history @ history.T
    fig = px.imshow(similarity_matrix, title=f"Similarity Matrix for {key} Gradients")
    fig.update_layout(
        autosize=False,
        width=800,
        height=800,
        margin=dict(
            l=50,
            r=50,
            b=100,
            t=100,
            pad=4
        ),
    )
    fig.show()
    

Visualizing the weight changes

We will use a process similar to stochastic neighbor embedding: we have $T$ vectors representing the change the weights undergo at each training step, we randomly create $T$ vectors in a low-dimensional space, and we optimize the low-dimensional vectors so that their relative angles and relative magnitudes match those of the true weight updates as closely as possible.
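Concretely, writing $\Delta w_i$ for the true parameter change at step $i$ and $d_i$ for its low-dimensional stand-in, the objective implemented below is

$$\mathcal{L}(d) = \frac{1}{T^2}\sum_{i,j}\left(\frac{d_i \cdot d_j}{\lVert d_i \rVert \lVert d_j \rVert} - \frac{\Delta w_i \cdot \Delta w_j}{\lVert \Delta w_i \rVert \lVert \Delta w_j \rVert}\right)^2 + \frac{1}{T}\sum_i \left(\lVert d_i \rVert - \lVert \Delta w_i \rVert\right)^2$$

In the cells that follow, the angle term is minimized by gradient descent first, and the magnitudes are imposed afterwards with magnitude_update.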

In [14]:
parameter_deltas = np.array(parameter_deltas)
In [15]:
parameter_deltas.shape
Out[15]:
(100, 334336)
In [16]:
def similarity_map(vectors):
    # cosine similarity between every pair of row vectors
    magnitudes = jnp.linalg.norm(vectors, axis=-1)
    vectors = (vectors.T / magnitudes).T
    return vectors @ vectors.T
In [17]:
# find the cosine similarities between each parameter change and every other
parameter_similarities = similarity_map(parameter_deltas)

# find the magnitudes of each parameter change
parameter_delta_magnitudes = np.linalg.norm(parameter_deltas, axis=-1)
In [18]:
px.line(parameter_delta_magnitudes, title="magnitude of each parameter change").show()
In [19]:
fig = px.imshow(parameter_similarities, title="Similarity Matrix for Weight Deltas During Training")
fig.update_layout(
        autosize=False,
        width=800,
        height=800,
        margin=dict(
            l=50,
            r=50,
            b=100,
            t=100,
            pad=4
        ),
    )
fig.show()
In [20]:
deltas = np.random.randn(100, 3) * 0.1 
In [21]:
def visualize_delta_similarity_map(similarities, title=""):
    fig = px.imshow(similarities, title=title)
    fig.update_layout(
            autosize=False,
            width=800,
            height=800,
            margin=dict(
                l=50,
                r=50,
                b=100,
                t=100,
                pad=4
            ),
        )
    fig.show()
ld_similarities = similarity_map(deltas)
visualize_delta_similarity_map(ld_similarities, title="Initial Similarity Matrix for random 3D deltas")
In [22]:
def construct_points(deltas):
    # given a list of deltas, create points in space starting at the origin
    dim = deltas.shape[-1]
    points = [np.zeros(dim)]
    for delta in deltas: 
        last_point = points[-1]
        new_point = last_point + delta
        points.append(new_point)
    return np.array(points)
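An equivalent vectorized form, for reference (a sketch producing the same output):

def construct_points_vectorized(deltas):
    # prepend the origin, then take running sums of the displacements
    origin = np.zeros((1, deltas.shape[-1]))
    return np.vstack([origin, np.cumsum(deltas, axis=0)])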
In [23]:
def calculate_distances(deltas):
    # given a list of deltas, find the cumulative distance traveled at each point,
    # including the origin (one more point than there are deltas)
    magnitudes = np.linalg.norm(deltas, axis=-1)
    return np.concatenate([[0.0], np.cumsum(magnitudes)])
In [24]:
def visualize_deltas(deltas):
    points = construct_points(deltas)
    
    fig = go.Figure(data=go.Scatter3d(
        x=points[:,0], y=points[:,1], z=points[:,2],
        marker=dict(
            size=4,
            color=calculate_distances(deltas),
            colorscale='Viridis',
        ),
        line=dict(
            color='darkblue',
            width=2
        )
    ))
    
    fig.update_layout(
        width=800,
        height=700,
        autosize=False,
        scene=dict(
            camera=dict(
                up=dict(
                    x=0,
                    y=0,
                    z=1
                ),
                eye=dict(
                    x=0,
                    y=1.0707,
                    z=1,
                )
            ),
            aspectratio = dict( x=1, y=1, z=0.7 ),
            aspectmode = 'manual'
        ),
    )
    
    fig.show()

visualize_deltas(deltas)
In [25]:
def angle_similarity_loss(deltas):
    # compare the relative similarities between angles between deltas, and parameter deltas
    ld_similarities = similarity_map(deltas)
    difference = ld_similarities - parameter_similarities
    return (difference ** 2).mean()

angle_similarity_loss(deltas)
Out[25]:
Array(0.6063774, dtype=float32)
In [26]:
def magnitude_similarity_loss(deltas):
    # compare the similarity between the magnitudes of the deltas, and the magnitudes of the parameter changes
    magnitudes = jnp.linalg.norm(deltas, axis=-1)
    return ((magnitudes - parameter_delta_magnitudes) ** 2).mean()

magnitude_similarity_loss(deltas)
Out[26]:
Array(0.00914208, dtype=float32)
In [27]:
def total_loss(deltas):
    return magnitude_similarity_loss(deltas) + angle_similarity_loss(deltas)

total_loss(deltas)
Out[27]:
Array(0.6155195, dtype=float32)
In [28]:
deltas = np.random.randn(100, 3) * 0.1 
In [29]:
# optimize only the angle term here; the magnitudes are imposed later with magnitude_update
angle_similarity_grad_fn = jax.grad(angle_similarity_loss)
for i in range(100):
    grad = angle_similarity_grad_fn(deltas)
    deltas -= 0.5 * grad
    print(angle_similarity_loss(deltas))
0.5964816
0.5829333
0.5622697
0.5381429
0.50967044
0.47500637
0.4340699
0.39085728
0.34838098
0.3070019
0.2680783
0.2345094
0.207783
0.18676044
0.16971576
0.15537523
0.14294219
0.1319135
0.12195632
0.112854674
0.10449049
0.09683234
0.089912236
0.08378158
0.078463316
0.07392803
0.07010139
0.06688679
0.06418512
0.061906323
0.059973497
0.05832324
0.056903996
0.055674277
0.054600693
0.053656347
0.052819524
0.05207257
0.05140107
0.05079317
0.05023908
0.049730603
0.049260907
0.048824254
0.048415806
0.048031535
0.04766804
0.0473226
0.04699296
0.04667753
0.046374846
0.046084583
0.045805406
0.045538396
0.045280438
0.04503451
0.044792477
0.04456093
0.04432836
0.04410638
0.043886006
0.04367704
0.0434731
0.0432785
0.043088365
0.04290485
0.04272402
0.042547736
0.04237282
0.042200852
0.04202944
0.041859604
0.04168965
0.041520078
0.041349847
0.041179154
0.04100745
0.040834904
0.040661335
0.04048702
0.040312037
0.04013674
0.03996139
0.03978637
0.03961198
0.039438535
0.03926626
0.039095327
0.038925815
0.03875775
0.03859107
0.03842565
0.038261384
0.038098194
0.037936
0.037774857
0.037614897
0.037456393
0.03729974
0.03714541
In [30]:
visualize_delta_similarity_map(similarity_map(deltas))
visualize_delta_similarity_map(parameter_similarities)
In [31]:
def magnitude_update(deltas):
    # rescale each delta to have the same magnitude as the corresponding weight update
    normalized_deltas = (deltas.T / np.linalg.norm(deltas, axis=-1)).T
    return np.array([delta * parameter_delta_magnitudes[i] for i, delta in enumerate(normalized_deltas)])
In [32]:
visualize_deltas(magnitude_update(deltas))
In [33]:
delta_opt_history = {}
# build the gradient function once, outside the loops
angle_similarity_grad_fn = jax.grad(angle_similarity_loss)
for dim in range(2, 11):

    deltas = np.random.randn(100, dim) * 0.1
    history = []

    for i in range(100):
        grad = angle_similarity_grad_fn(deltas)
        deltas -= 0.5 * grad
        error = angle_similarity_loss(deltas)
        history.append(float(error))
    delta_opt_history[f"{dim}"] = history
    
px.line(delta_opt_history).show() 

Example Gradient Descent for a simple function

In [37]:
def _L(d):
    # loss evaluated over a meshgrid: d has shape (ny, nx, 2)
    return (0.7*d[:,:,0])**2 + d[:,:,1]**2
def L(d):
    # the same loss for a single 2-D point d = (x, y)
    return (0.7*d[0])**2 + (d[1])**2

x = np.arange(-10, 10, 0.5)
y = np.arange(-10, 10, 0.5)
X,Y = np.meshgrid(x,y)
X = np.expand_dims(X,axis=-1)
Y = np.expand_dims(Y,axis=-1)
domain = np.concatenate([X,Y],axis=-1)

Z = _L(domain)
Z.shape
Out[37]:
(40, 40)
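The descent below follows the analytic gradient of this bowl, $\nabla L(x, y) = (2 \cdot 0.7^2 x, \; 2y) = (0.98x, \; 2y)$, so each step applies $\theta \leftarrow \theta - 0.2 \, \nabla L(\theta)$.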
In [38]:
parameters = np.array([8.0,8.0])
loss_grad = jax.grad(L)
point_history = [np.concatenate([parameters, [L(parameters)]])]
for i in range(10):
    parameters -= 0.2 * loss_grad(parameters)
    # offset z by 0.1 so the descent path sits just above the surface
    point_history.append(np.concatenate([parameters, [L(parameters)+0.1]]))
point_history = np.array(point_history)
In [39]:
fig = go.Figure(go.Surface(
    contours = {
        "z": {"show": True, "start": 0, "end": 200, "size": 5, "color":"white"},
    },
    x = x,
    y = y,
    z = Z
))
line_marker = dict(color='#ffffff', width=4)
fig.add_scatter3d(x=point_history[:,0], y=point_history[:,1], z=point_history[:,2], mode='lines', line=line_marker, name='')
fig.update_layout(
            autosize=False,
            width=800,
            height=800,
            
        )
fig.update_layout(
        scene = {
            "aspectratio": {"x": 1, "y": 1, "z": 0.5}
        })
fig.show()