In [1]:
import jax.numpy as jnp
import plotly.express as px
from plotly.subplots import make_subplots
import jax
import numpy as np
from datasets import mnist  # local helper that returns flattened images and one-hot labels
import plotly.graph_objects as go
import copy
In [2]:
train_images, train_labels, test_images, test_labels = mnist()
train_images = train_images.astype(jnp.float32)
test_images = test_images.astype(jnp.float32)
train_labels = jnp.asarray(train_labels, dtype=jnp.int32)
test_labels = jnp.asarray(test_labels, dtype=jnp.int32)
In [3]:
def visualize_images(images_tensor, w=28, h=28, col_wrap=5):
    img = images_tensor.reshape(-1, w, h)
    fig = px.imshow(img, binary_string=False, facet_col=0, facet_col_wrap=col_wrap)
    # blank out the default "facet_col=i" annotation above each image
    item_map = {f'{i}': "" for i in range(img.shape[0])}
    fig.for_each_annotation(lambda a: a.update(text=item_map[a.text.split("=")[1]]))
    fig.show()
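As a quick check of the helper (this call is my addition, not part of the original run), you can preview the first few training digits:

visualize_images(train_images[:10])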
In [5]:
net_parameters = {
'w0' : np.random.randn(256, 784) * 0.1,
'w1' : np.random.randn(256, 256) * 0.1,
'w2' : np.random.randn(256, 256) * 0.1,
'w3' : np.random.randn(10, 256) * 0.1,
}
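The NumPy initialization above is unseeded, so each run starts from different weights. A minimal sketch of a reproducible equivalent using JAX's own PRNG (same shapes and 0.1 scale; the rest of the notebook works unchanged, since the update loop rebinds the dict entries rather than mutating arrays in place):

key = jax.random.PRNGKey(0)
keys = jax.random.split(key, 4)
shapes = [(256, 784), (256, 256), (256, 256), (10, 256)]
net_parameters = {
    f'w{i}': jax.random.normal(k, shape) * 0.1
    for i, (k, shape) in enumerate(zip(keys, shapes))
}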
In [6]:
def ReLU(x):
    return jnp.maximum(0, x)

def forward(parameters, x):
    # transpose so each column is one example: (784, batch)
    x = x.T
    x = parameters['w0'] @ x
    x = ReLU(x)
    x = parameters['w1'] @ x
    x = ReLU(x)
    x = parameters['w2'] @ x
    x = ReLU(x)
    x = parameters['w3'] @ x
    # transpose back to (batch, 10) logits
    return x.T
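The transposes are only there because each weight matrix is stored as (outputs, inputs). An equivalent sketch (the `forward_rows` name is hypothetical, not from the original) keeps the batch in rows by right-multiplying with the transposed weights, since `x @ W.T` equals `(W @ x.T).T`:

def forward_rows(parameters, x):
    # x: (batch, 784)
    for name in ['w0', 'w1', 'w2']:
        x = ReLU(x @ parameters[name].T)
    return x @ parameters['w3'].T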
In [7]:
def loss(parameters, x, y):
    # cross-entropy between the softmax of the logits and the labels;
    # y is one-hot, so the product picks out the log-probability of the correct class
    out = forward(parameters, x)
    out = jax.nn.softmax(out)
    _loss = -(y * jnp.log(out)).sum(axis=-1).mean()
    return _loss
loss(net_parameters, test_images, test_labels)
Out[7]:
Array(2.731128, dtype=float32)
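Computing `log(softmax(...))` in two steps can lose precision when one logit dominates. A numerically stabler sketch of the same loss (the `loss_stable` name is mine) uses `jax.nn.log_softmax`, which is equivalent up to floating-point error:

def loss_stable(parameters, x, y):
    # log-softmax avoids taking the log of an underflowed probability
    log_probs = jax.nn.log_softmax(forward(parameters, x))
    return -(y * log_probs).sum(axis=-1).mean()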
In [8]:
(forward(net_parameters, train_images).argmax(axis=-1) == train_labels.argmax(axis=-1)).mean()
Out[8]:
Array(0.06178333, dtype=float32)
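This is roughly chance level for a ten-class problem (an untrained network has no reason to beat a 1-in-10 guess), which confirms the forward pass and accuracy computation are wired up correctly before training starts.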
In [9]:
grad_loss = jax.grad(loss)
lr = 0.1
# keep track of all the previous gradients
grad_history = []
# keep track of all the previous parameters
parameter_deltas = []
for epoch in range(100):
    p_grad = grad_loss(net_parameters, train_images, train_labels)
    grad_history.append(p_grad)
    # full-batch gradient descent step on every weight matrix
    net_parameters['w0'] -= lr * p_grad['w0']
    net_parameters['w1'] -= lr * p_grad['w1']
    net_parameters['w2'] -= lr * p_grad['w2']
    net_parameters['w3'] -= lr * p_grad['w3']
    # record the changes that were made to the parameters for analysis
    parameter_delta_vector = np.concatenate((
        -lr * p_grad['w0'].flatten(),
        -lr * p_grad['w1'].flatten(),
        -lr * p_grad['w2'].flatten(),
        -lr * p_grad['w3'].flatten()
    ))
    parameter_deltas.append(parameter_delta_vector)
    print(f"epoch {epoch}")
    print(f"validation loss: {loss(net_parameters, test_images, test_labels)}")
    print(f"train loss: {loss(net_parameters, train_images, train_labels)}")
    acc = (forward(net_parameters, train_images).argmax(axis=-1) == train_labels.argmax(axis=-1)).mean()
    print(f"accuracy: {acc}")
    print("\n")
epoch 0
validation loss: 2.277862548828125
train loss: 2.276801347732544
accuracy: 0.20038333535194397

epoch 1
validation loss: 1.9986860752105713
train loss: 2.0062105655670166
accuracy: 0.28600001335144043

...

epoch 98
validation loss: 0.30506303906440735
train loss: 0.3081871271133423
accuracy: 0.9098833203315735

epoch 99
validation loss: 0.3045886754989624
train loss: 0.3070499300956726
accuracy: 0.9108999967575073
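The four hand-written update lines generalize to any parameter dictionary. A minimal sketch of the same full-batch SGD step using `jax.tree_util.tree_map`, with the gradient computation jitted for speed (`sgd_step` is my name for it, not part of the original):

@jax.jit
def sgd_step(parameters, x, y):
    grads = jax.grad(loss)(parameters, x, y)
    # apply p <- p - lr * g to every leaf of the parameter pytree
    new_parameters = jax.tree_util.tree_map(lambda p, g: p - lr * g, parameters, grads)
    return new_parameters, grads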
In [10]:
im = 0
visualize_images(test_images[im])
forward(net_parameters, test_images[im])
Out[10]:
Array([ 0.34763533, -3.1584437 , 1.6465373 , 0.7349794 , -0.6372952 , -0.7725915 , -3.224408 , 9.081954 , -0.8517007 , 2.2727547 ], dtype=float32)
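The largest logit (about 9.08) is at index 7, so the trained network classifies the first test image, a handwritten 7, correctly.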
In [11]:
# the magnitude of the gradient at each training step
grad_norms = {'w0': [], 'w1': [], 'w2': [], 'w3': []}
for grad_vector in grad_history:
    for key in grad_norms:
        grad_norms[key].append(np.linalg.norm(grad_vector[key].flatten()))
fig = px.line(grad_norms)
fig.show()
In [12]:
# for each training step, calculate the angle between the current gradient and the previous one
grad_cosines = {'w0': [], 'w1': [], 'w2': [], 'w3': []}
grad_angles = {'w0': [], 'w1': [], 'w2': [], 'w3': []}
for i in range(1, len(grad_history)):
    for key in ['w0', 'w1', 'w2', 'w3']:
        g_i = grad_history[i][key].flatten()
        g_i_norm = g_i / np.linalg.norm(g_i)
        g_im1 = grad_history[i-1][key].flatten()
        g_im1_norm = g_im1 / np.linalg.norm(g_im1)
        cos = g_i_norm @ g_im1_norm
        grad_cosines[key].append(cos)
        # clip to [-1, 1] so floating-point error can't push arccos out of its domain
        angle = np.degrees(np.arccos(np.clip(cos, -1.0, 1.0)))
        grad_angles[key].append(angle)
fig_0 = px.line(grad_cosines, title="Cosine Between Consecutive Gradients")
fig_0.show()
fig_1 = px.line(grad_angles, title="Angle Between Consecutive Gradients")
fig_1.show()
In [13]:
# Here we find the cosine similarity between every pair of gradients (per weight matrix)
for key in ['w0', 'w1', 'w2', 'w3']:
    # the history of every gradient for this parameter during the training process
    history = [gradient_dict[key] for gradient_dict in grad_history]
    # convert from a list to a NumPy array
    history = np.array(history)
    # shape: (training_epochs, output_dim, input_dim)
    training_epochs, output_dim, input_dim = history.shape
    history = history.reshape(training_epochs, output_dim * input_dim)
    # normalize the gradient vector for each time step
    magnitudes = np.linalg.norm(history, axis=-1)
    history = (history.T / magnitudes).T
    # the cosine of the angle between the gradient of each step and every other step
    similarity_matrix = history @ history.T
    fig = px.imshow(similarity_matrix, title=f"Similarity Matrix for {key} Gradients")
    fig.update_layout(
        autosize=False,
        width=800,
        height=800,
        margin=dict(l=50, r=50, b=100, t=100, pad=4),
    )
    fig.show()
Visualizing the weight changes
We will use a process similar to stochastic neighbor embedding: we have $T$ vectors representing the changes that the weights undergo at each time step. We randomly create $T$ vectors in a low-dimensional space and optimize them so that their relative angles and relative magnitudes match those of the weight changes during training as closely as possible.
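Concretely, writing $\Delta_i$ for the $i$-th high-dimensional weight change and $d_i$ for its low-dimensional counterpart, the two losses implemented below are

$$\mathcal{L}_{\text{angle}} = \frac{1}{T^2}\sum_{i,j}\left(\frac{d_i \cdot d_j}{\lVert d_i \rVert \lVert d_j \rVert} - \frac{\Delta_i \cdot \Delta_j}{\lVert \Delta_i \rVert \lVert \Delta_j \rVert}\right)^2, \qquad \mathcal{L}_{\text{mag}} = \frac{1}{T}\sum_{i}\left(\lVert d_i \rVert - \lVert \Delta_i \rVert\right)^2$$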
In [14]:
parameter_deltas = np.array(parameter_deltas)
In [15]:
parameter_deltas.shape
Out[15]:
(100, 334336)
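As a sanity check, $256 \cdot 784 + 256 \cdot 256 + 256 \cdot 256 + 10 \cdot 256 = 200704 + 65536 + 65536 + 2560 = 334336$, so each of the 100 rows is indeed the concatenation of all four flattened weight updates.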
In [16]:
def similarity_map(vectors):
    # cosine similarity between every pair of row vectors
    magnitudes = jnp.linalg.norm(vectors, axis=-1)
    vectors = (vectors.T / magnitudes).T
    return vectors @ vectors.T
In [17]:
# find the similarities between each parameter change and every other parameter change
parameter_similarities = similarity_map(parameter_deltas)
# find the magnitudes of each parameter change
parameter_delta_magnitudes = np.linalg.norm(parameter_deltas, axis=-1)
In [18]:
px.line(parameter_delta_magnitudes, title="magnitude of each parameter change").show()
In [19]:
fig = px.imshow(parameter_similarities, title="Similarity Matrix for Weight Deltas During Training")
fig.update_layout(
    autosize=False,
    width=800,
    height=800,
    margin=dict(l=50, r=50, b=100, t=100, pad=4),
)
fig.show()
In [20]:
deltas = np.random.randn(100, 3) * 0.1
In [21]:
def visualize_delta_similarity_map(similarities, title=""):
    fig = px.imshow(similarities, title=title)
    fig.update_layout(
        autosize=False,
        width=800,
        height=800,
        margin=dict(l=50, r=50, b=100, t=100, pad=4),
    )
    fig.show()
ld_similarities = similarity_map(deltas)
visualize_delta_similarity_map(ld_similarities, title="Initial Similarity Matrix for Random 3D Deltas")
In [22]:
def construct_points(deltas):
    # given a list of deltas, create points in space starting at the origin
    dim = deltas.shape[-1]
    points = [np.zeros(dim)]
    for delta in deltas:
        points.append(points[-1] + delta)
    return np.array(points)
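The loop above is equivalent to a cumulative sum. A vectorized sketch (the `construct_points_vectorized` name is mine, not from the original):

def construct_points_vectorized(deltas):
    # prepend the origin, then accumulate the deltas along the time axis
    origin = np.zeros((1, deltas.shape[-1]))
    return np.concatenate([origin, np.cumsum(deltas, axis=0)])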
In [23]:
def calculate_distances(deltas):
    # given a list of deltas, find the cumulative distance traveled at each point
    # find the distance that each delta covers
    magnitudes = np.linalg.norm(deltas, axis=-1)
    # distance traveled at each of the len(deltas) + 1 points produced by construct_points
    distances = np.concatenate([[0.0], np.cumsum(magnitudes)])
    return distances
In [24]:
def visualize_deltas(deltas):
    points = construct_points(deltas)
    fig = go.Figure(data=go.Scatter3d(
        x=points[:, 0], y=points[:, 1], z=points[:, 2],
        marker=dict(
            size=4,
            # color each point by the cumulative distance traveled along the path
            color=calculate_distances(deltas),
            colorscale='Viridis',
        ),
        line=dict(color='darkblue', width=2)
    ))
    fig.update_layout(
        width=800,
        height=700,
        autosize=False,
        scene=dict(
            camera=dict(
                up=dict(x=0, y=0, z=1),
                eye=dict(x=0, y=1.0707, z=1),
            ),
            aspectratio=dict(x=1, y=1, z=0.7),
            aspectmode='manual'
        ),
    )
    fig.show()
visualize_deltas(deltas)
In [25]:
def angle_similarity_loss(deltas):
    # compare the pairwise cosine similarities of the low-dimensional deltas
    # with those of the true parameter deltas
    ld_similarities = similarity_map(deltas)
    difference = ld_similarities - parameter_similarities
    return (difference ** 2).mean()
angle_similarity_loss(deltas)
Out[25]:
Array(0.6063774, dtype=float32)
In [26]:
def magnitude_similarity_loss(deltas):
    # compare the magnitudes of the low-dimensional deltas with the magnitudes of the parameter changes
    magnitudes = jnp.linalg.norm(deltas, axis=-1)
    return ((magnitudes - parameter_delta_magnitudes) ** 2).mean()
magnitude_similarity_loss(deltas)
Out[26]:
Array(0.00914208, dtype=float32)
In [27]:
def total_loss(deltas):
    return magnitude_similarity_loss(deltas) + angle_similarity_loss(deltas)
total_loss(deltas)
Out[27]:
Array(0.6155195, dtype=float32)
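The cells below optimize the angle loss alone and fix up the magnitudes afterwards. An alternative sketch (my variation, not what the notebook runs) would descend on `total_loss` directly so that angles and magnitudes are fitted jointly:

total_loss_grad_fn = jax.grad(total_loss)
for i in range(100):
    deltas -= 0.5 * total_loss_grad_fn(deltas)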
In [28]:
deltas = np.random.randn(100, 3) * 0.1
In [29]:
angle_similarity_grad_fn = jax.grad(angle_similarity_loss)
for i in range(100):
    grad = angle_similarity_grad_fn(deltas)
    deltas -= 0.5 * grad
    print(angle_similarity_loss(deltas))
0.5964816
0.5829333
0.5622697
...
0.037456393
0.03729974
0.03714541
In [30]:
visualize_delta_similarity_map(similarity_map(deltas))
visualize_delta_similarity_map(parameter_similarities)
In [31]:
def magnitude_update(deltas):
    # rescale the deltas to have the same magnitudes as the corresponding weight updates
    normalized_deltas = (deltas.T / np.linalg.norm(deltas, axis=-1)).T
    return np.array([delta * parameter_delta_magnitudes[i] for i, delta in enumerate(normalized_deltas)])
In [32]:
visualize_deltas(magnitude_update(deltas))
In [33]:
delta_opt_history = {}
angle_similarity_grad_fn = jax.grad(angle_similarity_loss)
for dim in range(2, 11):
    deltas = np.random.randn(100, dim) * 0.1
    history = []
    for i in range(100):
        grad = angle_similarity_grad_fn(deltas)
        deltas -= 0.5 * grad
        error = angle_similarity_loss(deltas)
        history.append(error)
    delta_opt_history[f"{dim}"] = history
px.line(delta_opt_history).show()
Example Gradient Descent for a simple function
In [37]:
def _L(d):
    # batched version of the loss, for evaluating over the whole meshgrid
    return (0.7 * d[:, :, 0])**2 + d[:, :, 1]**2
def L(d):
    # elliptical bowl: shallower along x than along y
    return (0.7 * d[0])**2 + d[1]**2
x = np.arange(-10, 10, 0.5)
y = np.arange(-10, 10, 0.5)
X, Y = np.meshgrid(x, y)
X = np.expand_dims(X, axis=-1)
Y = np.expand_dims(Y, axis=-1)
domain = np.concatenate([X, Y], axis=-1)
Z = _L(domain)
Z.shape
Out[37]:
(40, 40)
In [38]:
parameters = np.array([8.0, 8.0])
loss_grad = jax.grad(L)
point_history = [np.concatenate([parameters, [L(parameters)]])]
for i in range(10):
    parameters -= 0.2 * loss_grad(parameters)
    # small z-offset so the descent path renders just above the surface
    point_history.append(np.concatenate([parameters, [L(parameters) + 0.1]]))
point_history = np.array(point_history)
In [39]:
fig = go.Figure(go.Surface(
    contours={
        "z": {"show": True, "start": 0, "end": 200, "size": 5, "color": "white"},
    },
    x=x,
    y=y,
    z=Z
))
line_marker = dict(color='#ffffff', width=4)
fig.add_scatter3d(x=point_history[:,0], y=point_history[:,1], z=point_history[:,2], mode='lines', line=line_marker, name='')
fig.update_layout(
    autosize=False,
    width=800,
    height=800,
    scene={"aspectratio": {"x": 1, "y": 1, "z": 0.5}},
)
fig.show()