In [1]:
import jax.numpy as jnp
import plotly.express as px
from plotly.subplots import make_subplots
import jax
import numpy as np
from datasets import mnist
import plotly.graph_objects as go
import copy
In [2]:
# Load MNIST via the local `datasets` helper — presumably returns flattened
# 784-pixel images and one-hot labels (the argmax comparisons later rely on
# one-hot labels; confirm against the helper's implementation).
train_images, train_labels, test_images, test_labels = mnist()
# Cast pixels to float32 for the float matmuls in the network.
train_images = train_images.astype(jnp.float32)
test_images = test_images.astype(jnp.float32)
# Convert labels to int32 JAX arrays.
train_labels = jnp.asarray(train_labels, dtype=jnp.int32)
test_labels = jnp.asarray(test_labels, dtype=jnp.int32)
In [3]:
def visualize_images(images_tensor, w=28, h=28, col_wrap=5):
    """Display one or more flattened images as a faceted image grid.

    Args:
        images_tensor: array of pixels reshapeable to (-1, w, h); a single
            flattened image or a batch both work.
        w, h: image dimensions (default 28x28, MNIST).
        col_wrap: number of facet columns per row.
    """
    img = images_tensor.reshape(-1, w, h)
    fig = px.imshow(img, binary_string=False, facet_col=0, facet_col_wrap=col_wrap)
    # px.imshow labels each facet "facet_col=i"; blank every annotation directly
    # instead of building a lookup dict of empty strings keyed by index.
    fig.for_each_annotation(lambda a: a.update(text=""))
    fig.show()
In [5]:
# Seed the generator so a Restart-and-Run-All reproduces the same weights;
# the 0.1 scale keeps initial activations (and the initial loss) small.
# Layer sizes: 784 -> 256 -> 256 -> 256 -> 10.
rng = np.random.default_rng(0)
net_parameters = {
    'w0': rng.standard_normal((256, 784)) * 0.1,
    'w1': rng.standard_normal((256, 256)) * 0.1,
    'w2': rng.standard_normal((256, 256)) * 0.1,
    'w3': rng.standard_normal((10, 256)) * 0.1,
}
In [6]:
def ReLU(x):
    """Element-wise rectified linear unit: max(0, x)."""
    return jnp.maximum(0, x)

def forward(parameters, x):
    """Forward pass of the MLP.

    Generalized to any depth: expects `parameters` to hold weight matrices
    keyed 'w0' .. 'w{L-1}', applied in index order with ReLU between layers
    and no activation after the last (output is raw logits).

    Args:
        parameters: dict of 2-D weight matrices keyed 'w0'..'w{L-1}'.
        x: input of shape (batch, in_features) — or (in_features,) for a
           single example, since .T on a 1-D array is a no-op.

    Returns:
        Logits of shape (batch, out_features) (or (out_features,) for 1-D x).
    """
    weights = [parameters[f'w{i}'] for i in range(len(parameters))]
    x = x.T  # work in (features, batch) so each layer is a plain matmul
    for w in weights[:-1]:
        x = ReLU(w @ x)
    x = weights[-1] @ x  # final layer: logits, no activation
    return x.T
In [7]:
def loss(parameters, x, y):
    """Mean cross-entropy between the network's class distribution and
    one-hot targets.

    Args:
        parameters: weight dict accepted by `forward`.
        x: inputs of shape (batch, in_features).
        y: one-hot targets of shape (batch, num_classes).

    Returns:
        Scalar mean cross-entropy over the batch.
    """
    logits = forward(parameters, x)
    # log_softmax is numerically stable; log(softmax(x)) can produce -inf
    # from underflowed probabilities and poison the gradients.
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    return -(y * log_probs).sum(axis=-1).mean()
loss(net_parameters, test_images, test_labels)
Out[7]:
Array(2.731128, dtype=float32)
In [8]:
# Training accuracy of the untrained net: fraction of examples where the
# predicted class (argmax over logits) matches the one-hot label's argmax —
# ~0.06 here, i.e. no better than chance before training.
(forward(net_parameters, train_images).argmax(axis=-1) == train_labels.argmax(axis=-1)).mean()
Out[8]:
Array(0.06178333, dtype=float32)
In [9]:
grad_loss = jax.grad(loss)  # gradient of scalar loss w.r.t. the parameter dict
lr = 0.1  # SGD step size
# keep track of all the previous gradients
grad_history = []
# keep track of all the previous parameters
parameter_deltas = []
# Fixed key order so the concatenated delta vector keeps a stable layout
# (w0, w1, w2, w3) across epochs.
param_keys = ('w0', 'w1', 'w2', 'w3')
for epoch in range(100):
    # Full-batch gradient on the entire training set.
    p_grad = grad_loss(net_parameters, train_images, train_labels)
    grad_history.append(p_grad)
    # One loop replaces the four copy-pasted per-parameter update lines and
    # reuses each computed step for the recorded delta vector.
    deltas = []
    for key in param_keys:
        step = -lr * p_grad[key]
        net_parameters[key] += step
        deltas.append(np.ravel(step))
    # record the changes that were made to the parameters for analysis
    parameter_deltas.append(np.concatenate(deltas))
    print(f"epoch {epoch}")
    print(f"validation loss: {loss(net_parameters, test_images, test_labels)}")
    print(f"train loss: {loss(net_parameters, train_images, train_labels)}")
    acc = (forward(net_parameters, train_images).argmax(axis=-1) == train_labels.argmax(axis=-1)).mean()
    print(f"accuracy: {acc}")
    print("\n")
epoch 0 validation loss: 2.277862548828125 train loss: 2.276801347732544 accuracy: 0.20038333535194397 epoch 1 validation loss: 1.9986860752105713 train loss: 2.0062105655670166 accuracy: 0.28600001335144043 epoch 2 validation loss: 1.8096214532852173 train loss: 1.8151549100875854 accuracy: 0.3808833360671997 epoch 3 validation loss: 1.7291929721832275 train loss: 1.7464510202407837 accuracy: 0.4345833361148834 epoch 4 validation loss: 1.49941086769104 train loss: 1.5121557712554932 accuracy: 0.5228833556175232 epoch 5 validation loss: 1.4020278453826904 train loss: 1.424275279045105 accuracy: 0.5604666471481323 epoch 6 validation loss: 1.3016060590744019 train loss: 1.3177744150161743 accuracy: 0.5866666436195374 epoch 7 validation loss: 1.2220079898834229 train loss: 1.247183084487915 accuracy: 0.609499990940094 epoch 8 validation loss: 1.157178282737732 train loss: 1.175179362297058 accuracy: 0.6294500231742859 epoch 9 validation loss: 1.0745911598205566 train loss: 1.1006921529769897 accuracy: 0.6504666805267334 epoch 10 validation loss: 1.0312285423278809 train loss: 1.0506861209869385 accuracy: 0.6718666553497314 epoch 11 validation loss: 0.9537557363510132 train loss: 0.9796228408813477 accuracy: 0.6926500201225281 epoch 12 validation loss: 0.9217555522918701 train loss: 0.9424632787704468 accuracy: 0.7120500206947327 epoch 13 validation loss: 0.8608860373497009 train loss: 0.8858886957168579 accuracy: 0.726983368396759 epoch 14 validation loss: 0.8320112824440002 train loss: 0.8535038232803345 accuracy: 0.7462166547775269 epoch 15 validation loss: 0.7871823906898499 train loss: 0.8111382126808167 accuracy: 0.7528499960899353 epoch 16 validation loss: 0.7618271708488464 train loss: 0.7834699153900146 accuracy: 0.7703499794006348 epoch 17 validation loss: 0.7334457635879517 train loss: 0.7562983632087708 accuracy: 0.76746666431427 epoch 18 validation loss: 0.7144235372543335 train loss: 0.735580325126648 accuracy: 0.7815499901771545 epoch 19 validation loss: 
0.7055370211601257 train loss: 0.7271606922149658 accuracy: 0.7684000134468079 epoch 20 validation loss: 0.6890653967857361 train loss: 0.7091876268386841 accuracy: 0.7810333371162415 epoch 21 validation loss: 0.6934828758239746 train loss: 0.7136143445968628 accuracy: 0.7663333415985107 epoch 22 validation loss: 0.6664817929267883 train loss: 0.6853572726249695 accuracy: 0.7836333513259888 epoch 23 validation loss: 0.6628146171569824 train loss: 0.681693434715271 accuracy: 0.7766333222389221 epoch 24 validation loss: 0.6299082636833191 train loss: 0.6479393839836121 accuracy: 0.7970499992370605 epoch 25 validation loss: 0.6206050515174866 train loss: 0.6385985612869263 accuracy: 0.7936999797821045 epoch 26 validation loss: 0.5956608057022095 train loss: 0.612824559211731 accuracy: 0.810283362865448 epoch 27 validation loss: 0.5865873694419861 train loss: 0.6037764549255371 accuracy: 0.807033360004425 epoch 28 validation loss: 0.568708598613739 train loss: 0.5849623680114746 accuracy: 0.8195500373840332 epoch 29 validation loss: 0.5596689581871033 train loss: 0.5761591196060181 accuracy: 0.8181333541870117 epoch 30 validation loss: 0.5462689995765686 train loss: 0.5616114735603333 accuracy: 0.82833331823349 epoch 31 validation loss: 0.5368631482124329 train loss: 0.5527257919311523 accuracy: 0.8262166976928711 epoch 32 validation loss: 0.5265526175498962 train loss: 0.541017472743988 accuracy: 0.8357666730880737 epoch 33 validation loss: 0.5168095827102661 train loss: 0.5320915579795837 accuracy: 0.8339666724205017 epoch 34 validation loss: 0.5087394118309021 train loss: 0.5223720669746399 accuracy: 0.8425499796867371 epoch 35 validation loss: 0.49895179271698 train loss: 0.513683557510376 accuracy: 0.8413166999816895 epoch 36 validation loss: 0.49255242943763733 train loss: 0.505401074886322 accuracy: 0.8484833240509033 epoch 37 validation loss: 0.4828387498855591 train loss: 0.497036337852478 accuracy: 0.8472999930381775 epoch 38 validation loss: 
0.47769641876220703 train loss: 0.4898192882537842 accuracy: 0.8543333411216736 epoch 39 validation loss: 0.4683235287666321 train loss: 0.48197564482688904 accuracy: 0.8528000116348267 epoch 40 validation loss: 0.4641515016555786 train loss: 0.47560182213783264 accuracy: 0.8591166734695435 epoch 41 validation loss: 0.45532241463661194 train loss: 0.4684279263019562 accuracy: 0.8571500182151794 epoch 42 validation loss: 0.4518907964229584 train loss: 0.4627283215522766 accuracy: 0.8634333610534668 epoch 43 validation loss: 0.44364094734191895 train loss: 0.4562109708786011 accuracy: 0.8617166876792908 epoch 44 validation loss: 0.4408518373966217 train loss: 0.4511384963989258 accuracy: 0.8675333261489868 epoch 45 validation loss: 0.43332454562187195 train loss: 0.4453541040420532 accuracy: 0.8657333254814148 epoch 46 validation loss: 0.43103447556495667 train loss: 0.4408362805843353 accuracy: 0.8708666563034058 epoch 47 validation loss: 0.4243217706680298 train loss: 0.4358236789703369 accuracy: 0.8690000176429749 epoch 48 validation loss: 0.42252859473228455 train loss: 0.43190935254096985 accuracy: 0.8734166622161865 epoch 49 validation loss: 0.4167732298374176 train loss: 0.4277592599391937 accuracy: 0.8722833395004272 epoch 50 validation loss: 0.41566863656044006 train loss: 0.42469915747642517 accuracy: 0.875249981880188 epoch 51 validation loss: 0.4111364185810089 train loss: 0.42160138487815857 accuracy: 0.8744833469390869 epoch 52 validation loss: 0.41114169359207153 train loss: 0.4198891520500183 accuracy: 0.8761000037193298 epoch 53 validation loss: 0.40817520022392273 train loss: 0.418091744184494 accuracy: 0.8753499984741211 epoch 54 validation loss: 0.41018104553222656 train loss: 0.4187377095222473 accuracy: 0.8759333491325378 epoch 55 validation loss: 0.4091891646385193 train loss: 0.41852423548698425 accuracy: 0.8740666508674622 epoch 56 validation loss: 0.41472214460372925 train loss: 0.42317917943000793 accuracy: 0.87336665391922 epoch 57 
validation loss: 0.41550925374031067 train loss: 0.4242120385169983 accuracy: 0.8715167045593262 epoch 58 validation loss: 0.42603418231010437 train loss: 0.4344639480113983 accuracy: 0.8686167001724243 epoch 59 validation loss: 0.4259662926197052 train loss: 0.43401408195495605 accuracy: 0.8682000041007996 epoch 60 validation loss: 0.43949228525161743 train loss: 0.4479300081729889 accuracy: 0.8633833527565002 epoch 61 validation loss: 0.43219423294067383 train loss: 0.4395970106124878 accuracy: 0.8658833503723145 epoch 62 validation loss: 0.44001853466033936 train loss: 0.4486149549484253 accuracy: 0.8625333309173584 epoch 63 validation loss: 0.4229450225830078 train loss: 0.42973002791404724 accuracy: 0.8697500228881836 epoch 64 validation loss: 0.4187626242637634 train loss: 0.42750319838523865 accuracy: 0.8693333268165588 epoch 65 validation loss: 0.40092435479164124 train loss: 0.40716466307640076 accuracy: 0.8785333633422852 epoch 66 validation loss: 0.390180379152298 train loss: 0.39872950315475464 accuracy: 0.8792499899864197 epoch 67 validation loss: 0.3786790370941162 train loss: 0.38462117314338684 accuracy: 0.8875499963760376 epoch 68 validation loss: 0.3687390983104706 train loss: 0.37683331966400146 accuracy: 0.887583315372467 epoch 69 validation loss: 0.3628421425819397 train loss: 0.3686543107032776 accuracy: 0.8930166959762573 epoch 70 validation loss: 0.355516642332077 train loss: 0.36311420798301697 accuracy: 0.8926500082015991 epoch 71 validation loss: 0.35248467326164246 train loss: 0.3581932783126831 accuracy: 0.8963333368301392 epoch 72 validation loss: 0.3471384048461914 train loss: 0.3542581796646118 accuracy: 0.895633339881897 epoch 73 validation loss: 0.345338374376297 train loss: 0.35090339183807373 accuracy: 0.8987833261489868 epoch 74 validation loss: 0.3412402868270874 train loss: 0.3479164242744446 accuracy: 0.897683322429657 epoch 75 validation loss: 0.33992841839790344 train loss: 0.3453098237514496 accuracy: 0.9003166556358337 
epoch 76 validation loss: 0.336586058139801 train loss: 0.34285449981689453 accuracy: 0.899566650390625 epoch 77 validation loss: 0.3354648947715759 train loss: 0.34064599871635437 accuracy: 0.9019666910171509 epoch 78 validation loss: 0.3326025903224945 train loss: 0.3385034203529358 accuracy: 0.9012333154678345 epoch 79 validation loss: 0.3315693438053131 train loss: 0.33653268218040466 accuracy: 0.9030333161354065 epoch 80 validation loss: 0.3290373980998993 train loss: 0.33460214734077454 accuracy: 0.9025499820709229 epoch 81 validation loss: 0.328059583902359 train loss: 0.33279159665107727 accuracy: 0.9039833545684814 epoch 82 validation loss: 0.32575860619544983 train loss: 0.3310023546218872 accuracy: 0.9035166501998901 epoch 83 validation loss: 0.3248150944709778 train loss: 0.3293081820011139 accuracy: 0.9050166606903076 epoch 84 validation loss: 0.3226830065250397 train loss: 0.32762911915779114 accuracy: 0.9047166705131531 epoch 85 validation loss: 0.3217836916446686 train loss: 0.3260313868522644 accuracy: 0.9059000015258789 epoch 86 validation loss: 0.3197770416736603 train loss: 0.3244391679763794 accuracy: 0.9055666923522949 epoch 87 validation loss: 0.31892064213752747 train loss: 0.3229207694530487 accuracy: 0.9067833423614502 epoch 88 validation loss: 0.31702181696891785 train loss: 0.3214121162891388 accuracy: 0.9064666628837585 epoch 89 validation loss: 0.31621429324150085 train loss: 0.31996384263038635 accuracy: 0.907633364200592 epoch 90 validation loss: 0.3143947422504425 train loss: 0.31851959228515625 accuracy: 0.9072999954223633 epoch 91 validation loss: 0.31363949179649353 train loss: 0.3171364665031433 accuracy: 0.90829998254776 epoch 92 validation loss: 0.31189385056495667 train loss: 0.3157590627670288 accuracy: 0.9079999923706055 epoch 93 validation loss: 0.3111976981163025 train loss: 0.31443601846694946 accuracy: 0.9089666604995728 epoch 94 validation loss: 0.3095097243785858 train loss: 0.31311991810798645 accuracy: 
0.9086000323295593 epoch 95 validation loss: 0.3088778257369995 train loss: 0.3118544816970825 accuracy: 0.9096166491508484 epoch 96 validation loss: 0.30723172426223755 train loss: 0.3105928599834442 accuracy: 0.9092666506767273 epoch 97 validation loss: 0.30667465925216675 train loss: 0.3093916177749634 accuracy: 0.9103333353996277 epoch 98 validation loss: 0.30506303906440735 train loss: 0.3081871271133423 accuracy: 0.9098833203315735 epoch 99 validation loss: 0.3045886754989624 train loss: 0.3070499300956726 accuracy: 0.9108999967575073
In [10]:
# Inspect a single test example: render the digit, then show the trained
# network's raw (pre-softmax) logits for it.
im = 0
visualize_images(test_images[im])
# test_images[im] is a single flattened image; forward handles it because
# .T on a 1-D array is a no-op, yielding a length-10 logit vector.
forward(net_parameters, test_images[im])