Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file modified packages/demos/3dot-smc/data/frame-0021.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified packages/demos/3dot-smc/data/frame-0022.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified packages/demos/3dot-smc/data/frame-0023.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified packages/demos/3dot-smc/data/frame-0024.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified packages/demos/3dot-smc/data/frame-0025.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified packages/demos/3dot-smc/data/frame-0026.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified packages/demos/3dot-smc/data/frame-0027.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified packages/demos/3dot-smc/data/frame-0028.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified packages/demos/3dot-smc/data/frame-0029.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
363 changes: 199 additions & 164 deletions packages/demos/3dot-smc/notebooks/demo.ipynb

Large diffs are not rendered by default.

179 changes: 179 additions & 0 deletions packages/demos/3dot-smc/notebooks/quest_demo_425.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
import genbrain_model_3dot as model
from genbrain_3dot_smc import collect_frames
from genbrain_smcnn_core.interpreter import run_smcnn_particle_filter
from genbrain_3dot_smc.viz import *
import genbrain_utils_genjax as gjutils
import jax.numpy as jnp
import jax
import genstudio.plot as Plot
from matplotlib import pyplot as plt
from matplotlib.animation import FuncAnimation
import numpy as np
%matplotlib tk

Check failure on line 12 in packages/demos/3dot-smc/notebooks/quest_demo_425.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff

packages/demos/3dot-smc/notebooks/quest_demo_425.py:12:1: SyntaxError: Expected a statement

Check failure on line 12 in packages/demos/3dot-smc/notebooks/quest_demo_425.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff

packages/demos/3dot-smc/notebooks/quest_demo_425.py:12:13: SyntaxError: Simple statements must be separated by newlines or semicolons

# First we will collect the generative functions from the 3dot model.
genfns = [
model.initial_proposal,
model.initial_model,
model.step_proposal,
model.step_model,
model.obs_model,
]

# Here we convert digital x,y numpy frames into a spherical occupancy map (i.e. a visual angle occupancy grid)
xy_obs_frames = collect_frames()[1:]
vis_angle_observations = jax.vmap(lambda obs: model.find_occupied_2d_angles(obs))(
jnp.array(xy_obs_frames)
)
len_sim = 58
obs_traces = model.generate_obs_traces(vis_angle_observations[0:len_sim])

# choose a random seed so you can repeat the experiment
# looks really cool w key = 2000
key = jax.random.PRNGKey(2000)
num_particles = 20

def animate_current_particle_locs(observations, points_3D_seq, scores, plot_history, visual_angles, xyz_bounds, interval=50):
observation_grids = [np.transpose(np.fliplr(o.reshape(len(visual_angles), len(visual_angles)))) for o in observations]
xs, ys, zs = xyz_bounds
fig = plt.figure(figsize=(12, 6))
cmap = my_tab20(len(scores[0]))
ax2D = fig.add_subplot(121)
ax3D = fig.add_subplot(122, projection='3d')
win = .3
ax2D.set_xlabel("θ")
ax2D.set_ylabel("ϕ")
ax3D.set_xlabel('X')
ax3D.set_ylabel('Y')
ax3D.set_zlabel('Z')
ax3D.set_title("3D Hypotheses (Y = Depth)")
ax2D.set_title("2D Observation")
ax3D.set_xlim([xs[0], xs[-1]])
ax3D.set_ylim([ys[0], ys[-1]])
ax3D.set_zlim([zs[0], zs[-1]])
num_particles = len(points_3D_seq[0])
im = ax2D.imshow(observation_grids[0], cmap='gray', interpolation='none', vmin=0, vmax=1)
particles_3D = [ax3D.plot([], [], [], 'o')[0] for _ in range(num_particles)]
tails_3D = [ax3D.plot([], [], [], '-', color=cmap[i], alpha=0.3, linewidth=.5)[0] for i in range(num_particles)]
def update(frame):
print(frame)
im.set_array(observation_grids[frame])
for i, (particle, (x, y, z)) in enumerate(zip(particles_3D, points_3D_seq[frame])):
particle.set_data([x], [y])
# particle.set_alpha(float(jnp.exp(float(scores[frame][i]))**.1))
particle.set_3d_properties([z])
particle.set_color(cmap[i])
if plot_history:
history = np.array(points_3D_seq[:frame+1])[:, i, :]
tails_3D[i].set_data(history[:, 0], history[:, 1])
tails_3D[i].set_3d_properties(history[:, 2])
return [im] + particles_3D + tails_3D

anim = FuncAnimation(fig, update, frames=len(observations), interval=interval, blit=True)
return anim

def plot_spike_dictionary_studio(spiketimes_dict):
num_components = len(spiketimes_dict)
plot_objects = []
component_labels = []
for neuron_id, spikes_and_label in enumerate(spiketimes_dict.values()):
spikes, label = spikes_and_label
component_labels.append(label)
neuron_y = num_components - (neuron_id + 1)
for sp in spikes:
plot_objects.append({"spiketimes": sp, "neuron_id": neuron_y, "component": label})
raster = Plot.tickX(plot_objects, x="spiketimes", y="neuron_id", stroke="component") + Plot.axisY(
{"tickSize": 1, "tickFormat": None, "tickRotate": 1, "label": component_labels})
return raster

init_states_and_scores, first_step_states_and_scores, unrolled_pf = (
gjutils.smc.run_particle_filter(
obs_traces,
num_particles,
len_sim,
genfns,
key,
model.translate_proposal_cm_to_model,
)
)
# keep print of scores or not? its useful for debugging.
xyz, pscores = xyz_and_particle_scores(init_states_and_scores,
first_step_states_and_scores,
unrolled_pf)
#pscores = [jnp.zeros(num_particles) for i in range(len_sim+1)]
an = animate_current_particle_locs(vis_angle_observations[0:len_sim],
xyz,
pscores, True, model.visual_angles, (model.xs, model.ys, model.zs))

latent_variables = [
{
"variable": "v3d",
"q_id": ("dot", "v3d"),
"q_parents": [],
"p_parents": [],
"support": model.xyz_vels,
"type": "distribution",
},
{
"variable": "xyz",
"q_id": ("dot", "xyz"),
"q_parents": [("ego_pos", "ego_matter")],
"p_parents": [],
"support": model.xyz_point_cloud,
"type": "distribution",
},
{
"variable": ("ego_pos", "ego_matter"),
"q_id": ("dot", "ego_pos", "ego_matter"),
"q_parents": [],
"p_parents": ["xyz"],
"support": model.bool_support,
"type": "probmap",
},
{
"variable": "lights",
"q_id": ("dot", "lights"),
"p_parents": [],
"q_parents": [],
"support": model.bool_support,
"type": "distribution",
},
{
"variable": "diam",
"q_id": ("dot", "diam"),
"p_parents": [],
"q_parents": [],
"support": model.diams,
"type": "distribution",
},
]

obs_variables = [
{
"variable": ("obs", "pix"),
"parents": [],
"support": model.bool_support,
"type": "probmap",
}
]

results = run_smcnn_particle_filter(
(latent_variables, obs_variables),
model.initial_model,
model.step_model,
model.initial_proposal,
model.step_proposal,
model.obs_model,
5,
2,
vis_angle_observations[0:35],
)

xyz_spikes = gather_spikes_from_single_particle_samplescore(results, 0, "xyz", range(len_sim))

xyz_filtered = {k: v for k, v in xyz_spikes[0].items() if len(v[0]) != 0}

pa = plot_spike_dictionary_studio(xyz_filtered)
pa


10 changes: 5 additions & 5 deletions packages/demos/3dot-smc/pixi.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

31 changes: 31 additions & 0 deletions packages/demos/3dot-smc/src/genbrain_3dot_smc/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import numpy as np
import os
import re
from PIL import Image
import genbrain_model_3dot as model

def collect_frames(directory="../data/", file_pattern="frame-*.png"):
frame_regex = re.compile(r"frame-(\d+)\.png")
frames = []
for filename in os.listdir(directory):
match = frame_regex.match(filename)
if match:
frame_number = int(match.group(1))
frames.append((frame_number, os.path.join(directory, filename)))
frames.sort(key=lambda x: x[0])
frame_paths = [path for _, path in frames]
numpy_frames = []
for path in frame_paths:
with Image.open(path) as img:
# this must be going downwards?
im = img.convert("L").point(lambda p: 1 if p > 0 else 0)
numpy_frames.append((np.transpose(np.flipud(np.array(im))) > 0).astype(int))
return numpy_frames

def get_xyz(results):
particles_per_step = results[1]
xyz_vals = []
for particles in particles_per_step:
xyz_vals.append([model.xyz_point_cloud[p.choicemap["xyz"]] for p in particles])
return xyz_vals

125 changes: 124 additions & 1 deletion packages/demos/3dot-smc/src/genbrain_3dot_smc/viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@
import seaborn as sns
import copy
from astropy.convolution import convolve_fft, Gaussian1DKernel
from mpl_toolkits.mplot3d.art3d import Line3DCollection
from matplotlib import patches, colormaps

Check failure on line 9 in packages/demos/3dot-smc/src/genbrain_3dot_smc/viz.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (F401)

packages/demos/3dot-smc/src/genbrain_3dot_smc/viz.py:9:24: F401 `matplotlib.patches` imported but unused
import jax
import jax.numpy as jnp
np.seterr(divide="ignore")

""" ANALOG PLOT LIB """
Expand Down Expand Up @@ -155,7 +159,7 @@

fig, ax = plt.subplots()
# cpal = sns.color_palette("tab20b", 20)
cpal = my_tab20(100)

Check failure on line 162 in packages/demos/3dot-smc/src/genbrain_3dot_smc/viz.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (F841)

packages/demos/3dot-smc/src/genbrain_3dot_smc/viz.py:162:5: F841 Local variable `cpal` is assigned to but never used

if resampler != ():
resampler_spikedict = resampler[0]
Expand All @@ -175,10 +179,13 @@
if label[0:3] in ["res", "nor"]:
color_id = particle_id_decode(label)
else:
color_id = c

Check failure on line 182 in packages/demos/3dot-smc/src/genbrain_3dot_smc/viz.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (F841)

packages/demos/3dot-smc/src/genbrain_3dot_smc/viz.py:182:17: F841 Local variable `color_id` is assigned to but never used
ax.vlines(
spikes, neuron_y, neuron_y + 0.8, color=cpal[color_id], linewidth=1.0
spikes, neuron_y, neuron_y + 0.8, color='k', linewidth=1.0
)
# ax.vlines(
# spikes, neuron_y, neuron_y + 0.8, color=cpal[color_id], linewidth=1.0
# )

comp_labels = [v[1] for k, v in ss_spiketimes_list[0].items()]
# xlim = np.max(np.concatenate([v[0] for v in ss_spiketimes_list.values()])) + 1
Expand Down Expand Up @@ -668,6 +675,122 @@
)
return ani

def xyz_and_particle_scores(init_pf, fs_pf, unrolled_pf):
xyz_init = init_pf[0].get_retval()[1]
xyz_step1 = fs_pf[0].get_retval()[1]
xyz_rest = unrolled_pf[1][1].get_retval()[1]
all_xyz = jnp.vstack([xyz_init[None, :], xyz_step1[None, :], xyz_rest])

init_scores = jax.tree.reduce(lambda x, y: x + y, init_pf[4])
fs_scores = jax.tree.reduce(lambda x, y: x + y, fs_pf[4])
unrolled_scores = jax.vmap(lambda scores: jax.tree.reduce(lambda x, y: x + y, scores))(unrolled_pf[1][4])
all_scores = [init_scores, fs_scores, *unrolled_scores]
return all_xyz, all_scores

def animate_current_particle_locs(observations, points_3D_seq, scores, plot_history, visual_angles, xyz_bounds, interval=50):
observation_grids = [np.transpose(np.fliplr(o.reshape(len(visual_angles), len(visual_angles)))) for o in observations]
xs, ys, zs = xyz_bounds
fig = plt.figure(figsize=(12, 6))
cmap = my_tab20(len(scores[0]))
ax2D = fig.add_subplot(121)
ax3D = fig.add_subplot(122, projection='3d')
win = .3

Check failure on line 697 in packages/demos/3dot-smc/src/genbrain_3dot_smc/viz.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (F841)

packages/demos/3dot-smc/src/genbrain_3dot_smc/viz.py:697:5: F841 Local variable `win` is assigned to but never used
ax2D.set_xlabel("θ")
ax2D.set_ylabel("ϕ")
ax3D.set_xlabel('X')
ax3D.set_ylabel('Y')
ax3D.set_zlabel('Z')
ax3D.set_title("3D Hypotheses (Y = Depth)")
ax2D.set_title("2D Observation")
ax3D.set_xlim([xs[0], xs[-1]])
ax3D.set_ylim([ys[0], ys[-1]])
ax3D.set_zlim([zs[0], zs[-1]])
num_particles = len(points_3D_seq[0])
im = ax2D.imshow(observation_grids[0], cmap='gray', interpolation='none', vmin=0, vmax=1)
particles_3D = [ax3D.plot([], [], [], 'o')[0] for _ in range(num_particles)]
tails_3D = [ax3D.plot([], [], [], '-', color=cmap[i], alpha=0.3, linewidth=.5)[0] for i in range(num_particles)]
def update(frame):
im.set_array(observation_grids[frame])
for i, (particle, (x, y, z)) in enumerate(zip(particles_3D, points_3D_seq[frame])):
particle.set_data([x], [y])
# particle.set_alpha(float(jnp.exp(float(scores[frame][i]))**.1))
particle.set_3d_properties([z])
particle.set_color(cmap[i])
if plot_history:
history = np.array(points_3D_seq[:frame+1])[:, i, :]
tails_3D[i].set_data(history[:, 0], history[:, 1])
tails_3D[i].set_3d_properties(history[:, 2])
return [im] + particles_3D + tails_3D

anim = FuncAnimation(fig, update, frames=len(observations), interval=interval, blit=True)
return anim



def animate_trajectories_only_3D(xyz_inferences, scores, interval=50):
fig = plt.figure(figsize=(12, 6))
cmap = colormaps['plasma']
ax3D = fig.add_subplot(111, projection='3d')
win = .3

Check failure on line 734 in packages/demos/3dot-smc/src/genbrain_3dot_smc/viz.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (F841)

packages/demos/3dot-smc/src/genbrain_3dot_smc/viz.py:734:5: F841 Local variable `win` is assigned to but never used
ax3D.set_xlabel('X')
ax3D.set_ylabel('Y')
ax3D.set_zlabel('Z')
ax3D.set_xlim([xs[0], xs[-1]])

Check failure on line 738 in packages/demos/3dot-smc/src/genbrain_3dot_smc/viz.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (F821)

packages/demos/3dot-smc/src/genbrain_3dot_smc/viz.py:738:20: F821 Undefined name `xs`

Check failure on line 738 in packages/demos/3dot-smc/src/genbrain_3dot_smc/viz.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (F821)

packages/demos/3dot-smc/src/genbrain_3dot_smc/viz.py:738:27: F821 Undefined name `xs`
ax3D.set_ylim([ys[0], ys[-1]])

Check failure on line 739 in packages/demos/3dot-smc/src/genbrain_3dot_smc/viz.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (F821)

packages/demos/3dot-smc/src/genbrain_3dot_smc/viz.py:739:20: F821 Undefined name `ys`
ax3D.set_zlim([zs[0], zs[-1]])
ax3D.set_title("Inferred 3D Trajectories")
num_steps = len(scores)
num_particles = len(scores[0])
norm = plt.Normalize(0, num_steps)
particle_segments = []
num_steps = len(xyz_inferences)
num_particles = len(scores[0])
cmap_by_step = np.linspace(0, num_steps, num_steps)
segments_by_step = []
seg_colors_by_step = []
for curr_step in range(num_steps):
step_segs = []
step_colors = []
for p_ind in range(num_particles):
xyz = xyz_inferences[:, p_ind]
segs = [[list(xyz[i]), list(xyz[i+1])] for i in range(curr_step)]
colors = [cmap_by_step[i] for i in range(curr_step)]
step_segs = step_segs + segs
step_colors = step_colors + colors
segments_by_step.append(step_segs)
seg_colors_by_step.append(step_colors)
segments_by_step = segments_by_step[1:]
seg_colors_by_step = seg_colors_by_step[1:]

def update(frame):
ax3D.cla()
ax3D.set_xlabel('X')
ax3D.set_ylabel('Y')
ax3D.set_zlabel('Z')
ax3D.set_xlim([xs[0], xs[-1]])
ax3D.set_ylim([ys[0], ys[-1]])
ax3D.set_zlim([zs[0], zs[-1]])
ax3D.set_title("3D Hypotheses (Y = Depth)")
lc = Line3DCollection(segments_by_step[frame],
cmap=cmap, norm=norm, linewidths=1)
lc.set_array(seg_colors_by_step[frame])
ax3D.add_collection3d(lc)
return lc,

anim = FuncAnimation(fig, update, frames=num_steps-1, interval=interval, blit=False)
return anim

def make_probability_heatmap(matrix):
if len(matrix.shape) == 3:
x, y, z = np.indices(matrix.shape)
fig = plt.figure(figsize=(10, 7))
ax = fig.add_subplot(111, projection='3d')
ax.set_xlabel('X')
ax.set_ylabel('Y')
ax.set_zlabel('Z')
scatter = ax.scatter(x, y, z, c=matrix.flatten(), cmap='viridis')
fig.colorbar(scatter, ax=ax, label='Value')
plt.show()

# x_p_assemblies_particle_0 = snmc_spikes_wrapper(pf_results, 'x', range(num_snmc_steps), range(0,1), ["assemblies_p13", "assemblies_p14", "assemblies_p15"])[0]

Expand Down
Loading