In [1]:
import jax.numpy as jnp
import plotly.express as px
from plotly.subplots import make_subplots
import jax
import numpy as np
from datasets import mnist
import plotly.graph_objects as go
In [2]:
# Load the four MNIST arrays, then cast the image arrays to float32 and the
# label arrays to int32.
train_images, train_labels, test_images, test_labels = mnist()
train_images, test_images = (
    pixels.astype(jnp.float32) for pixels in (train_images, test_images)
)
train_labels, test_labels = (
    jnp.asarray(labels, dtype=jnp.int32) for labels in (train_labels, test_labels)
)
In [3]:
# Keep only the first N_TRAIN_SAMPLES training images so the notebook stays
# runnable on low-powered hardware; change the constant to use more data.
# NOTE(review): train_labels is not truncated here -- confirm that later cells
# slice the labels to the same length before pairing them with train_images.
N_TRAIN_SAMPLES = 100
train_images = train_images[:N_TRAIN_SAMPLES]
train_images.shape
Out[3]:
(100, 784)
In [4]:
def visualize_images(images_tensor, image_shape=(28, 28), facet_col_wrap=5):
    """Display a batch of flattened images as a grid of unlabeled facets.

    Args:
        images_tensor: Array exposing ``reshape`` whose total size is a
            multiple of the number of pixels in ``image_shape``. It is
            reshaped to ``(-1, *image_shape)``, so both a single flattened
            image and a batch of them are accepted.
        image_shape: ``(height, width)`` of one image. Defaults to 28x28,
            the shape previously hard-coded in this function.
        facet_col_wrap: Number of images per row of the grid.

    Returns:
        None. The figure is rendered with ``fig.show()``; it is deliberately
        not returned so that a notebook cell ending in a call to this
        function does not display the figure a second time.
    """
    images = images_tensor.reshape(-1, *image_shape)
    fig = px.imshow(
        images,
        binary_string=False,
        facet_col=0,
        facet_col_wrap=facet_col_wrap,
    )
    # Each facet carries a "<name>=<index>" title annotation. Blank every one
    # directly instead of parsing the text and looking it up in a table whose
    # values were all empty strings anyway.
    fig.for_each_annotation(lambda annotation: annotation.update(text=""))
    fig.show()
In [5]:
# Preview the first ten training images as a grid.
visualize_images(train_images[:10])