diff --git a/packages/demos/3dot-smc/data/frame-0021.png b/packages/demos/3dot-smc/data/frame-0021.png index e26ed5e..987d1e3 100644 Binary files a/packages/demos/3dot-smc/data/frame-0021.png and b/packages/demos/3dot-smc/data/frame-0021.png differ diff --git a/packages/demos/3dot-smc/data/frame-0022.png b/packages/demos/3dot-smc/data/frame-0022.png index 96abe83..987d1e3 100644 Binary files a/packages/demos/3dot-smc/data/frame-0022.png and b/packages/demos/3dot-smc/data/frame-0022.png differ diff --git a/packages/demos/3dot-smc/data/frame-0023.png b/packages/demos/3dot-smc/data/frame-0023.png index e246313..987d1e3 100644 Binary files a/packages/demos/3dot-smc/data/frame-0023.png and b/packages/demos/3dot-smc/data/frame-0023.png differ diff --git a/packages/demos/3dot-smc/data/frame-0024.png b/packages/demos/3dot-smc/data/frame-0024.png index 385e2ae..987d1e3 100644 Binary files a/packages/demos/3dot-smc/data/frame-0024.png and b/packages/demos/3dot-smc/data/frame-0024.png differ diff --git a/packages/demos/3dot-smc/data/frame-0025.png b/packages/demos/3dot-smc/data/frame-0025.png index d48bfff..987d1e3 100644 Binary files a/packages/demos/3dot-smc/data/frame-0025.png and b/packages/demos/3dot-smc/data/frame-0025.png differ diff --git a/packages/demos/3dot-smc/data/frame-0026.png b/packages/demos/3dot-smc/data/frame-0026.png index c2078f0..987d1e3 100644 Binary files a/packages/demos/3dot-smc/data/frame-0026.png and b/packages/demos/3dot-smc/data/frame-0026.png differ diff --git a/packages/demos/3dot-smc/data/frame-0027.png b/packages/demos/3dot-smc/data/frame-0027.png index e7d75e5..987d1e3 100644 Binary files a/packages/demos/3dot-smc/data/frame-0027.png and b/packages/demos/3dot-smc/data/frame-0027.png differ diff --git a/packages/demos/3dot-smc/data/frame-0028.png b/packages/demos/3dot-smc/data/frame-0028.png index 56f3abd..987d1e3 100644 Binary files a/packages/demos/3dot-smc/data/frame-0028.png and b/packages/demos/3dot-smc/data/frame-0028.png differ diff --git a/packages/demos/3dot-smc/data/frame-0029.png b/packages/demos/3dot-smc/data/frame-0029.png index 6ca0261..987d1e3 100644 Binary files a/packages/demos/3dot-smc/data/frame-0029.png and b/packages/demos/3dot-smc/data/frame-0029.png differ diff --git a/packages/demos/3dot-smc/notebooks/demo.ipynb b/packages/demos/3dot-smc/notebooks/demo.ipynb index 2ca6d3a..a210096 100644 --- a/packages/demos/3dot-smc/notebooks/demo.ipynb +++ b/packages/demos/3dot-smc/notebooks/demo.ipynb @@ -2,68 +2,37 @@ "cells": [ { "cell_type": "code", - "execution_count": 16, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import genbrain_model_3dot as model\n", - "import genbrain_smcnn_core.interpreter as smcnn\n", + "from genbrain_3dot_smc import collect_frames\n", + "from genbrain_3dot_smc.viz import *\n", "import genbrain_utils_genjax as gjutils\n", - "import numpy as np\n", "import jax.numpy as jnp\n", - "import jax\n", - "from PIL import Image\n", - "import re\n", - "import os" + "import jax" ] }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ - "# First we will collect the generative functions from the 3dot model. Lets start with 20 particles.\n", + "# First we will collect the generative functions from the 3dot model.\n", "genfns = [\n", " model.initial_proposal,\n", " model.initial_model,\n", " model.step_proposal,\n", " model.step_model,\n", " model.obs_model,\n", - "]\n", - "\n", - "num_particles = 20" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": {}, - "outputs": [], - "source": [ - "# This function allows us to grab the raw data stored in ../data and turn it into numpy frames indicating pixel occupancy\n", - "def collect_frames(directory=\"../data/\", file_pattern=\"frame-*.png\"):\n", - " frame_regex = re.compile(r\"frame-(\\d+)\\.png\")\n", - " frames = []\n", - " for filename in os.listdir(directory):\n", - " match = frame_regex.match(filename)\n", - " if match:\n", - " frame_number = int(match.group(1))\n", - " frames.append((frame_number, os.path.join(directory, filename)))\n", - " frames.sort(key=lambda x: x[0])\n", - " frame_paths = [path for _, path in frames]\n", - " numpy_frames = []\n", - " for path in frame_paths:\n", - " with Image.open(path) as img:\n", - " # this must be going downwards?\n", - " im = img.convert(\"L\").point(lambda p: 1 if p > 0 else 0)\n", - " numpy_frames.append((np.transpose(np.flipud(np.array(im))) > 0).astype(int))\n", - " return numpy_frames" + "]" ] }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ @@ -78,7 +47,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 4, "metadata": {}, "outputs": [ { @@ -109,6 +78,8 @@ "source": [ "# choose a random seed so you can repeat the experiment\n", "key = jax.random.PRNGKey(100)\n", + "num_particles = 20\n", + "\n", "init_states_and_scores, first_step_states_and_scores, unrolled_pf = (\n", " gjutils.smc.run_particle_filter(\n", " obs_traces,\n", @@ -124,165 +95,228 @@ }, { "cell_type": "code", - "execution_count": 13, - "metadata": {}, - "outputs": [], - "source": [ - "# connect to genstudio viz in ../viz folder rather than my old 3D visualizer.\n" - ] - }, - { - "cell_type": "code", - "execution_count": 14, + "execution_count": null, "metadata": {}, "outputs": [ { "ename": "TypeError", - "evalue": "run_smcnn_particle_filter() takes 9 positional arguments but 10 were given", + "evalue": "xyz_and_particle_scores() takes 3 positional arguments but 4 were given", "output_type": "error", "traceback": [ "\u001b[31m---------------------------------------------------------------------------\u001b[39m", "\u001b[31mTypeError\u001b[39m Traceback (most recent call last)", - "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[14]\u001b[39m\u001b[32m, line 49\u001b[39m\n\u001b[32m 45\u001b[39m assembly_size = \u001b[32m10\u001b[39m\n\u001b[32m 47\u001b[39m num_particles = \u001b[32m2\u001b[39m\n\u001b[32m---> \u001b[39m\u001b[32m49\u001b[39m pf_results = \u001b[43msmcnn\u001b[49m\u001b[43m.\u001b[49m\u001b[43mrun_smcnn_particle_filter\u001b[49m\u001b[43m(\u001b[49m\n\u001b[32m 50\u001b[39m \u001b[43m \u001b[49m\u001b[43mvariables\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 51\u001b[39m \u001b[43m \u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m.\u001b[49m\u001b[43minitial_model\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 52\u001b[39m \u001b[43m \u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m.\u001b[49m\u001b[43mstep_model\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 53\u001b[39m \u001b[43m \u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m.\u001b[49m\u001b[43minitial_proposal\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 54\u001b[39m \u001b[43m \u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m.\u001b[49m\u001b[43mstep_proposal\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 55\u001b[39m \u001b[43m \u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m.\u001b[49m\u001b[43mobs_model\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 56\u001b[39m \u001b[43m \u001b[49m\u001b[43massembly_size\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 57\u001b[39m \u001b[43m \u001b[49m\u001b[43mnum_particles\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 58\u001b[39m \u001b[43m \u001b[49m\u001b[43mvis_angle_observations\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 59\u001b[39m \u001b[43m \u001b[49m\u001b[33;43m\"\u001b[39;49m\u001b[33;43manalog\u001b[39;49m\u001b[33;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\n", - "\u001b[31mTypeError\u001b[39m: run_smcnn_particle_filter() takes 9 positional arguments but 10 were given" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Values: P = [ -inf -inf -inf -14.069142 -inf -inf\n", - " -31.47956 -32.92604 -inf -18.1999 -15.98837 -13.67204\n", - " -14.069142 -20.997154 -12.622659 -24.054615 -inf -inf\n", - " -inf -inf], Q = [ -3.754869 -3.754869 -9.319347 -5.619448 -7.2321143 -10.469468\n", - " -9.362024 -7.25119 -3.6446338 -5.6194477 -5.159434 -5.6194477\n", - " -5.619448 -7.711205 -5.6194477 -7.0030193 -4.8921866 -3.508615\n", - " -3.508615 -5.140358 ], O = [-34.666298 -34.666298 -38.142395 -38.142395 -38.142395 -31.190197\n", - " -48.57069 -41.618492 -38.142395 -38.142395 -38.142395 -38.142395\n", - " -38.142395 -41.618492 -38.142395 -38.142395 -31.190197 -31.190197\n", - " -31.190197 -34.666298]\n", - "Values: P = [-15.793759 -18.565975 -inf -inf -inf -inf\n", - " -inf -inf -inf -12.996509 -inf -22.79751\n", - " -inf -inf -12.996509 -inf -12.025179 -21.972832\n", - " -inf -inf], Q = [ -7.00302 -5.366965 -5.0282054 -5.619448 -5.0282054 -12.218768\n", - " -3.1846192 -15.684491 -3.294855 -3.294855 -5.865702 -12.072917\n", - " -7.225317 -8.97685 -3.294855 -3.754869 -5.6194477 -7.017783\n", - " -4.9069505 -7.4826813], O = [-38.142395 -34.666298 -38.142395 -38.142395 -38.142395 -41.618492\n", - " -38.142395 -45.094593 -34.666298 -34.666298 -41.618492 -41.618492\n", - " -38.142395 -38.142395 -34.666298 -34.666298 -38.142395 -41.618492\n", - " -34.666298 -41.618492]\n", - "Values: P = [ -inf -inf -inf -inf -inf -26.680742\n", - " -inf -inf -inf -inf -inf -inf\n", - " -inf -inf -inf -inf -inf -inf\n", - " -inf -inf], Q = [ -4.73735 -8.06703 -7.4696765 -4.3812065 -7.372854 -10.568775\n", - " -9.21254 -6.9365463 -5.8842616 -4.028142 -12.710102 -7.601531\n", - " -5.524967 -5.261316 -6.689106 -7.052661 -4.8412204 -4.1288624\n", - " -8.86578 -6.337224 ], O = [-38.142395 -38.142395 -41.618492 -38.142395 -41.618492 -48.570694\n", - " -45.094593 -41.618492 -38.142395 -34.666298 -48.570694 -34.666298\n", - " -34.666298 -34.666298 -45.094593 -41.618492 -38.142395 -31.190197\n", - " -45.094593 -38.142395]\n" + "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[6]\u001b[39m\u001b[32m, line 1\u001b[39m\n\u001b[32m----> \u001b[39m\u001b[32m1\u001b[39m xyz, pscores = \u001b[43mxyz_and_particle_scores\u001b[49m\u001b[43m(\u001b[49m\u001b[43minit_states_and_scores\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\n\u001b[32m 2\u001b[39m \u001b[43m \u001b[49m\u001b[43mfirst_step_states_and_scores\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 3\u001b[39m \u001b[43m \u001b[49m\u001b[43munrolled_pf\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m.\u001b[49m\u001b[43mvisual_angles\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 4\u001b[39m \u001b[38;5;66;03m#pscores = [jnp.zeros(num_particles) for i in range(len_sim+1)]\u001b[39;00m\n\u001b[32m 5\u001b[39m animate_current_particle_locs(xy_obs_frames[\u001b[32m0\u001b[39m:len_sim], \n\u001b[32m 6\u001b[39m xyz, \n\u001b[32m 7\u001b[39m pscores, \u001b[38;5;28;01mTrue\u001b[39;00m)\n", + "\u001b[31mTypeError\u001b[39m: xyz_and_particle_scores() takes 3 positional arguments but 4 were given" ] } ], "source": [ - "def extract_xyz_from_smcnn(pf_results, obs, particles_to_animate):\n", - " xyz_inferences = []\n", - " p_scores = []\n", - " particles_per_step = pf_results[1]\n", - " resampler_per_step = pf_results[2]\n", - " for step in range(len(obs)):\n", - " particles = particles_per_step[step]\n", - " particle_choicemaps = [\n", - " p.choicemap for i, p in enumerate(particles) if i in particles_to_animate\n", - " ]\n", - " particle_scores = resampler_per_step[step].log_weights\n", - " # note you would normally index the supports b/c\n", - " # choicemap is an assembly index. but\n", - " xyz = np.array(\n", - " [model.egocentric_3d_map[cm[\"xyz\"]] for cm in particle_choicemaps]\n", - " )\n", - " xyz_inferences.append(xyz)\n", - " p_scores.append(particle_scores)\n", - " return np.array(xyz_inferences), p_scores\n", - "\n", - "\n", - "model_variables = [\n", + "xyz, pscores = xyz_and_particle_scores(\n", + " init_states_and_scores,\n", + " first_step_states_and_scores,\n", + " unrolled_pf,\n", + " model.visual_angles,\n", + ")\n", + "# pscores = [jnp.zeros(num_particles) for i in range(len_sim+1)]\n", + "animate_current_particle_locs(\n", + " xy_obs_frames[0:len_sim], xyz, pscores, True, model.visual_angles\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "latent_variables = [\n", " {\n", - " \"variable\": \"lights\",\n", - " \"parents\": [],\n", - " \"support\": model.bool_support,\n", - " \"subtraced\": [],\n", + " \"variable\": \"v3d\",\n", + " \"q_id\": (\"dot\", \"v3d\"),\n", + " \"q_parents\": [],\n", + " \"p_parents\": [],\n", + " \"support\": model.xyz_vels,\n", + " \"type\": \"distribution\",\n", " },\n", - " {\"variable\": \"diam\", \"parents\": [], \"support\": model.diams, \"subtraced\": []},\n", - " {\"variable\": \"v3d\", \"parents\": [], \"support\": model.xyz_vels, \"subtraced\": []},\n", " {\n", " \"variable\": \"xyz\",\n", - " \"parents\": [\"v3d\"],\n", + " \"q_id\": (\"dot\", \"xyz\"),\n", + " \"q_parents\": [(\"ego_pos\", \"ego_matter\")],\n", + " \"p_parents\": [],\n", " \"support\": model.xyz_point_cloud,\n", - " \"subtraced\": [],\n", + " \"type\": \"distribution\",\n", " },\n", " {\n", - " \"variable\": \"ego_pos\",\n", - " \"parents\": [\"xyz\"],\n", - " \"support\": model.egocentric_3d_map,\n", - " \"subtraced\": [],\n", - " },\n", - "]\n", - "\n", - "proposal_variables = [\n", - " {\n", - " \"variable\": \"ego_pos\",\n", - " \"parents\": [],\n", - " \"support\": model.egocentric_3d_map,\n", - " \"subtraced\": [],\n", + " \"variable\": (\"ego_pos\", \"ego_matter\"),\n", + " \"q_id\": (\"dot\", \"ego_pos\", \"ego_matter\"),\n", + " \"q_parents\": [],\n", + " \"p_parents\": [\"xyz\"],\n", + " \"support\": model.bool_support,\n", + " \"type\": \"probmap\",\n", " },\n", " {\n", - " \"variable\": \"xyz\",\n", - " \"parents\": [\"ego_pos\"],\n", - " \"support\": model.xyz_point_cloud,\n", - " \"subtraced\": [],\n", + " \"variable\": \"lights\",\n", + " \"q_id\": (\"dot\", \"lights\"),\n", + " \"p_parents\": [],\n", + " \"q_parents\": [],\n", + " \"support\": model.bool_support,\n", + " \"type\": \"distribution\",\n", " },\n", - " {\"variable\": \"v3d\", \"parents\": [\"xyz\"], \"support\": model.xyz_vels, \"subtraced\": []},\n", " {\n", " \"variable\": \"diam\",\n", - " \"parents\": [\"ego_pos\"],\n", + " \"q_id\": (\"dot\", \"diam\"),\n", + " \"p_parents\": [],\n", + " \"q_parents\": [],\n", " \"support\": model.diams,\n", - " \"subtraced\": [],\n", - " },\n", - " {\n", - " \"variable\": \"lights\",\n", - " \"parents\": [],\n", - " \"support\": model.bool_support,\n", - " \"subtraced\": [],\n", + " \"type\": \"distribution\",\n", " },\n", "]\n", "\n", "obs_variables = [\n", " {\n", - " \"variable\": \"obs\",\n", + " \"variable\": (\"obs\", \"pix\"),\n", " \"parents\": [],\n", - " \"support\": model.egocentric_2d_map,\n", - " \"subtraced\": [],\n", + " \"support\": model.bool_support,\n", + " \"type\": \"probmap\",\n", " }\n", - "]\n", - "\n", - "variables = [model_variables, proposal_variables, obs_variables]\n", - "assembly_size = 10\n", - "\n", - "num_particles = 2\n", - "\n", - "pf_results = smcnn.run_smcnn_particle_filter(\n", - " variables,\n", - " model.initial_model,\n", - " model.step_model,\n", - " model.initial_proposal,\n", - " model.step_proposal,\n", - " model.obs_model,\n", - " assembly_size,\n", - " num_particles,\n", - " vis_angle_observations,\n", - " \"analog\",\n", - ")" + "]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Jitting Generative Functions\n", + "Initializing SMCNN Particle Filter\n", + "q not complete after max recursion\n", + "q not complete after max recursion\n", + "q not complete after max recursion\n", + "scoring pixels\n", + "0\n", + "()\n", + "[0 0 0 ... 0 0 0]\n", + "(1024,)\n", + "scoring pixels\n", + "0\n", + "()\n", + "[0 0 0 ... 0 0 0]\n", + "(1024,)\n", + "Initialized SMCNN Particle Filter\n", + "resampler score\n", + "[0.5 0.5]\n", + "step 0\n", + "scoring pixels\n", + "0\n", + "()\n", + "[0 0 0 ... 0 0 0]\n", + "(1024,)\n", + "scoring pixels\n", + "0\n", + "()\n", + "[0 0 0 ... 0 0 0]\n", + "(1024,)\n", + "resampler score\n", + "[0.5 0.5]\n" + ] + } + ], + "source": [ + "# results = run_smcnn_particle_filter(\n", + "# (latent_variables, obs_variables),\n", + "# model.initial_model,\n", + "# model.step_model,\n", + "# model.initial_proposal,\n", + "# model.step_proposal,\n", + "# model.obs_model,\n", + "# 5,\n", + "# 2,\n", + "# vis_angle_observations[1:3],\n", + "# )\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# xyz_spikes = gather_spikes_from_single_particle_samplescore(results, 0, \"xyz\", range(len_sim))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# xyz_filtered = {k: v for k, v in xyz_spikes[0].items() if len(v[0]) != 0}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# def plot_spike_dictionary_studio(spiketimes_dict):\n", + "# num_components = len(spiketimes_dict)\n", + "# plot_objects = []\n", + "# component_labels = []\n", + "# for neuron_id, spikes_and_label in enumerate(spiketimes_dict.values()):\n", + "# spikes, label = spikes_and_label\n", + "# component_labels.append(label)\n", + "# neuron_y = num_components - (neuron_id + 1)\n", + "# for sp in spikes:\n", + "# plot_objects.append({\"spiketimes\": sp, \"neuron_id\": neuron_y, \"component\": label})\n", + "# raster = Plot.tickX(plot_objects, x=\"spiketimes\", y=\"neuron_id\", stroke=\"component\") + Plot.axisY(\n", + "# {\"tickSize\": 1, \"tickFormat\": None, \"tickRotate\": 1, \"label\": component_labels})\n", + "# return raster" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# pa = plot_spike_dictionary_studio(xyz_filtered)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "2b50a7c017224ea8b0affa5c81d1cf7d", + "version_major": 2, + "version_minor": 1 + }, + "text/plain": [ + "Widget(data=)" + ] + }, + "execution_count": 44, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# pa" ] }, { diff --git a/packages/demos/3dot-smc/notebooks/quest_demo_425.py b/packages/demos/3dot-smc/notebooks/quest_demo_425.py new file mode 100644 index 0000000..995413f --- /dev/null +++ b/packages/demos/3dot-smc/notebooks/quest_demo_425.py @@ -0,0 +1,179 @@ +import genbrain_model_3dot as model +from genbrain_3dot_smc import collect_frames +from genbrain_smcnn_core.interpreter import run_smcnn_particle_filter +from genbrain_3dot_smc.viz import * +import genbrain_utils_genjax as gjutils +import jax.numpy as jnp +import jax +import genstudio.plot as Plot +from matplotlib import pyplot as plt +from matplotlib.animation import FuncAnimation +import numpy as np +# %matplotlib tk + +# First we will collect the generative functions from the 3dot model. +genfns = [ + model.initial_proposal, + model.initial_model, + model.step_proposal, + model.step_model, + model.obs_model, +] + +# Here we convert digital x,y numpy frames into a spherical occupancy map (i.e. a visual angle occupancy grid) +xy_obs_frames = collect_frames()[1:] +vis_angle_observations = jax.vmap(lambda obs: model.find_occupied_2d_angles(obs))( + jnp.array(xy_obs_frames) +) +len_sim = 58 +obs_traces = model.generate_obs_traces(vis_angle_observations[0:len_sim]) + +# choose a random seed so you can repeat the experiment +# looks really cool w key = 2000 +key = jax.random.PRNGKey(2000) +num_particles = 20 + +def animate_current_particle_locs(observations, points_3D_seq, scores, plot_history, visual_angles, xyz_bounds, interval=50): + observation_grids = [np.transpose(np.fliplr(o.reshape(len(visual_angles), len(visual_angles)))) for o in observations] + xs, ys, zs = xyz_bounds + fig = plt.figure(figsize=(12, 6)) + cmap = my_tab20(len(scores[0])) + ax2D = fig.add_subplot(121) + ax3D = fig.add_subplot(122, projection='3d') + win = .3 + ax2D.set_xlabel("θ") + ax2D.set_ylabel("ϕ") + ax3D.set_xlabel('X') + ax3D.set_ylabel('Y') + ax3D.set_zlabel('Z') + ax3D.set_title("3D Hypotheses (Y = Depth)") + ax2D.set_title("2D Observation") + ax3D.set_xlim([xs[0], xs[-1]]) + ax3D.set_ylim([ys[0], ys[-1]]) + ax3D.set_zlim([zs[0], zs[-1]]) + num_particles = len(points_3D_seq[0]) + im = ax2D.imshow(observation_grids[0], cmap='gray', interpolation='none', vmin=0, vmax=1) + particles_3D = [ax3D.plot([], [], [], 'o')[0] for _ in range(num_particles)] + tails_3D = [ax3D.plot([], [], [], '-', color=cmap[i], alpha=0.3, linewidth=.5)[0] for i in range(num_particles)] + def update(frame): + print(frame) + im.set_array(observation_grids[frame]) + for i, (particle, (x, y, z)) in enumerate(zip(particles_3D, points_3D_seq[frame])): + particle.set_data([x], [y]) +# particle.set_alpha(float(jnp.exp(float(scores[frame][i]))**.1)) + particle.set_3d_properties([z]) + particle.set_color(cmap[i]) + if plot_history: + history = np.array(points_3D_seq[:frame+1])[:, i, :] + tails_3D[i].set_data(history[:, 0], history[:, 1]) + tails_3D[i].set_3d_properties(history[:, 2]) + return [im] + particles_3D + tails_3D + + anim = FuncAnimation(fig, update, frames=len(observations), interval=interval, blit=True) + return anim + +def plot_spike_dictionary_studio(spiketimes_dict): + num_components = len(spiketimes_dict) + plot_objects = [] + component_labels = [] + for neuron_id, spikes_and_label in enumerate(spiketimes_dict.values()): + spikes, label = spikes_and_label + component_labels.append(label) + neuron_y = num_components - (neuron_id + 1) + for sp in spikes: + plot_objects.append({"spiketimes": sp, "neuron_id": neuron_y, "component": label}) + raster = Plot.tickX(plot_objects, x="spiketimes", y="neuron_id", stroke="component") + Plot.axisY( + {"tickSize": 1, "tickFormat": None, "tickRotate": 1, "label": component_labels}) + return raster + +init_states_and_scores, first_step_states_and_scores, unrolled_pf = ( + gjutils.smc.run_particle_filter( + obs_traces, + num_particles, + len_sim, + genfns, + key, + model.translate_proposal_cm_to_model, + ) +) +# keep print of scores or not? its useful for debugging. +xyz, pscores = xyz_and_particle_scores(init_states_and_scores, + first_step_states_and_scores, + unrolled_pf) +#pscores = [jnp.zeros(num_particles) for i in range(len_sim+1)] +an = animate_current_particle_locs(vis_angle_observations[0:len_sim], + xyz, + pscores, True, model.visual_angles, (model.xs, model.ys, model.zs)) + +latent_variables = [ + { + "variable": "v3d", + "q_id": ("dot", "v3d"), + "q_parents": [], + "p_parents": [], + "support": model.xyz_vels, + "type": "distribution", + }, + { + "variable": "xyz", + "q_id": ("dot", "xyz"), + "q_parents": [("ego_pos", "ego_matter")], + "p_parents": [], + "support": model.xyz_point_cloud, + "type": "distribution", + }, + { + "variable": ("ego_pos", "ego_matter"), + "q_id": ("dot", "ego_pos", "ego_matter"), + "q_parents": [], + "p_parents": ["xyz"], + "support": model.bool_support, + "type": "probmap", + }, + { + "variable": "lights", + "q_id": ("dot", "lights"), + "p_parents": [], + "q_parents": [], + "support": model.bool_support, + "type": "distribution", + }, + { + "variable": "diam", + "q_id": ("dot", "diam"), + "p_parents": [], + "q_parents": [], + "support": model.diams, + "type": "distribution", + }, +] + +obs_variables = [ + { + "variable": ("obs", "pix"), + "parents": [], + "support": model.bool_support, + "type": "probmap", + } +] + +results = run_smcnn_particle_filter( + (latent_variables, obs_variables), + model.initial_model, + model.step_model, + model.initial_proposal, + model.step_proposal, + model.obs_model, + 5, + 2, + vis_angle_observations[0:35], +) + +xyz_spikes = gather_spikes_from_single_particle_samplescore(results, 0, "xyz", range(len_sim)) + +xyz_filtered = {k: v for k, v in xyz_spikes[0].items() if len(v[0]) != 0} + +pa = plot_spike_dictionary_studio(xyz_filtered) +pa + + diff --git a/packages/demos/3dot-smc/pixi.lock b/packages/demos/3dot-smc/pixi.lock index d7f2447..bcfe03c 100644 --- a/packages/demos/3dot-smc/pixi.lock +++ b/packages/demos/3dot-smc/pixi.lock @@ -118,7 +118,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/8e/37/efad0257dc6e593a18957422533ff0f87ede7c9c6ea010a2177d738fb82f/pure_eval-0.2.3-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/e5/e0/050018d855d26d3c0b4a7d1b2ed692be758ce276d8289e2a2b44ba1014a5/pyerfa-2.0.1.5-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl - pypi: https://files.pythonhosted.org/packages/8a/0b/9fcc47d19c48b59121088dd6da2488a49d5f72dacf8262e2790a1d2c7d15/pygments-2.19.1-py3-none-any.whl - - pypi: https://files.pythonhosted.org/packages/f9/83/80c17698f41131f7157a26ae985e2c1f5526db79f277c4416af145f3e12b/pyparsing-3.2.2-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/05/e7/df2285f3d08fee213f2d041540fa4fc9ca6c2d44cf36d3a035bf2a8d2bcc/pyparsing-3.2.3-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/ec/57/56b9bcc3c9c6a792fcbaf139543cee77261f3651ca9da0c93f5c1221264b/python_dateutil-2.9.0.post0-py2.py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/eb/38/ac33370d784287baa1c3d538978b5e2ea064d4c1b93ffbd12826c190dd10/pytz-2025.1-py2.py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/b9/2b/614b4752f2e127db5cc206abc23a8c19678e92b23c3db30fc86ab731d3bd/PyYAML-6.0.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl @@ -251,7 +251,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/8e/37/efad0257dc6e593a18957422533ff0f87ede7c9c6ea010a2177d738fb82f/pure_eval-0.2.3-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/11/4a/31a363370478b63c6289a34743f2ba2d3ae1bd8223e004d18ab28fb92385/pyerfa-2.0.1.5-cp39-abi3-macosx_11_0_arm64.whl - pypi: https://files.pythonhosted.org/packages/8a/0b/9fcc47d19c48b59121088dd6da2488a49d5f72dacf8262e2790a1d2c7d15/pygments-2.19.1-py3-none-any.whl - - pypi: https://files.pythonhosted.org/packages/f9/83/80c17698f41131f7157a26ae985e2c1f5526db79f277c4416af145f3e12b/pyparsing-3.2.2-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/05/e7/df2285f3d08fee213f2d041540fa4fc9ca6c2d44cf36d3a035bf2a8d2bcc/pyparsing-3.2.3-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/ec/57/56b9bcc3c9c6a792fcbaf139543cee77261f3651ca9da0c93f5c1221264b/python_dateutil-2.9.0.post0-py2.py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/eb/38/ac33370d784287baa1c3d538978b5e2ea064d4c1b93ffbd12826c190dd10/pytz-2025.1-py2.py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/a8/0c/38374f5bb272c051e2a69281d71cba6fdb983413e6758b84482905e29a5d/PyYAML-6.0.2-cp312-cp312-macosx_11_0_arm64.whl @@ -2555,10 +2555,10 @@ packages: requires_dist: - colorama>=0.4.6 ; extra == 'windows-terminal' requires_python: '>=3.8' -- pypi: https://files.pythonhosted.org/packages/f9/83/80c17698f41131f7157a26ae985e2c1f5526db79f277c4416af145f3e12b/pyparsing-3.2.2-py3-none-any.whl +- pypi: https://files.pythonhosted.org/packages/05/e7/df2285f3d08fee213f2d041540fa4fc9ca6c2d44cf36d3a035bf2a8d2bcc/pyparsing-3.2.3-py3-none-any.whl name: pyparsing - version: 3.2.2 - sha256: 6ab05e1cb111cc72acc8ed811a3ca4c2be2af8d7b6df324347f04fd057d8d793 + version: 3.2.3 + sha256: a749938e02d6fd0b59b356ca504a24982314bb090c383e3cf201c95ef7e2bfcf requires_dist: - railroad-diagrams ; extra == 'diagrams' - jinja2 ; extra == 'diagrams' diff --git a/packages/demos/3dot-smc/src/genbrain_3dot_smc/__init__.py b/packages/demos/3dot-smc/src/genbrain_3dot_smc/__init__.py index e69de29..2bf47a7 100644 --- a/packages/demos/3dot-smc/src/genbrain_3dot_smc/__init__.py +++ b/packages/demos/3dot-smc/src/genbrain_3dot_smc/__init__.py @@ -0,0 +1,32 @@ +import numpy as np +import os +import re +from PIL import Image +import genbrain_model_3dot as model + + +def collect_frames(directory="../data/", file_pattern="frame-*.png"): + frame_regex = re.compile(r"frame-(\d+)\.png") + frames = [] + for filename in os.listdir(directory): + match = frame_regex.match(filename) + if match: + frame_number = int(match.group(1)) + frames.append((frame_number, os.path.join(directory, filename))) + frames.sort(key=lambda x: x[0]) + frame_paths = [path for _, path in frames] + numpy_frames = [] + for path in frame_paths: + with Image.open(path) as img: + # this must be going downwards? + im = img.convert("L").point(lambda p: 1 if p > 0 else 0) + numpy_frames.append((np.transpose(np.flipud(np.array(im))) > 0).astype(int)) + return numpy_frames + + +def get_xyz(results): + particles_per_step = results[1] + xyz_vals = [] + for particles in particles_per_step: + xyz_vals.append([model.xyz_point_cloud[p.choicemap["xyz"]] for p in particles]) + return xyz_vals diff --git a/packages/demos/3dot-smc/src/genbrain_3dot_smc/viz.py b/packages/demos/3dot-smc/src/genbrain_3dot_smc/viz.py index 5b08068..1d42645 100644 --- a/packages/demos/3dot-smc/src/genbrain_3dot_smc/viz.py +++ b/packages/demos/3dot-smc/src/genbrain_3dot_smc/viz.py @@ -5,6 +5,11 @@ import seaborn as sns import copy from astropy.convolution import convolve_fft, Gaussian1DKernel +from mpl_toolkits.mplot3d.art3d import Line3DCollection +from matplotlib import colormaps +import jax +import jax.numpy as jnp + np.seterr(divide="ignore") """ ANALOG PLOT LIB """ @@ -154,15 +159,13 @@ def particle_id_decode(label): return int(label[-2:]) fig, ax = plt.subplots() - # cpal = sns.color_palette("tab20b", 20) - cpal = my_tab20(100) - + if resampler != (): resampler_spikedict = resampler[0] indexed_resampler_spikes = { i + len(ss_spiketimes_list[0]): v for i, v in resampler_spikedict.items() } - # return indexed_resampler_spikesmerged_dict = {**dict1, **dict2} + ss_spiketimes_list[0] = {**ss_spiketimes_list[0], **indexed_resampler_spikes} # this is impervious to whether resampler has been added or not. @@ -172,14 +175,8 @@ def particle_id_decode(label): for neuron_id, spikes_and_label in spiketimes.items(): spikes, label = spikes_and_label neuron_y = num_components - (neuron_id + 1) - if label[0:3] in ["res", "nor"]: - color_id = particle_id_decode(label) - else: - color_id = c - ax.vlines( - spikes, neuron_y, neuron_y + 0.8, color=cpal[color_id], linewidth=1.0 - ) - + ax.vlines(spikes, neuron_y, neuron_y + 0.8, color="k", linewidth=1.0) + comp_labels = [v[1] for k, v in ss_spiketimes_list[0].items()] # xlim = np.max(np.concatenate([v[0] for v in ss_spiketimes_list.values()])) + 1 comp_labels.reverse() @@ -328,13 +325,12 @@ def merge_component_dicts(spiketimes, component_order): # this will combine all spikes across particles def merge_particle_dicts(spikes_by_particle): - def add_scalar_to_keys(scalar, d): new_d = {} for k, v in d.items(): new_d[k + scalar] = v return new_d - + final_spiketimes = {} max_key = 0 for st in spikes_by_particle: @@ -669,45 +665,90 @@ def update(frame): return ani -# x_p_assemblies_particle_0 = snmc_spikes_wrapper(pf_results, 'x', range(num_snmc_steps), range(0,1), ["assemblies_p13", "assemblies_p14", "assemblies_p15"])[0] - -# there are 3 neuroscience results i think we can explain from stryker. -# one, the average selectivity index of neurons in the different layers -# two, the flow of activity (source / sink) from layer to layer. -# three, the presence of direction selective neurons (state buffer). - -# mountcastle contains a second analysis method – what if you go off -# by just a bit, do the receptive fields change? - -# you'll have a separate diagram for excitatory neurons vs inhibitory neurons. -# you'll do a drawing of what the circuit looks like with inhibitory neurons -# added. - -# ctx = merge_particle_dicts(snmc_spikes_wrapper - -# bg = merge_particle_dicts(snmc_spikes_wrapper(pf_results, 'x', range(num_snmc_steps), range(num_particles), get_components(positions, range(num_particles), ['bg']))) - -# x_spikes = merge_particle_dicts(snmc_spikes_wrapper(pf_results, 'x', range(num_snmc_steps), range(num_particles), get_components(positions, range(num_particles), ['ctx']))[0]) - -# vx_spikes = merge_particle_dicts(snmc_spikes_wrapper(pf_results, 'vx', range(num_snmc_steps), range(num_particles), get_components(positions, range(num_particles), ['ctx']))) - -# to all spikes at once, just merge the return vals, which are -# all spiketimes for ss and all spiketimes for rs. also use all particles. - -# merged_assemblies = merge_component_dicts(x_p_assemblies_particle_0[0], ["assemblies_p13", "assemblies_p14", "assemblies_p15"]) +def xyz_and_particle_scores(init_pf, fs_pf, unrolled_pf): + xyz_init = init_pf[0].get_retval()[1] + xyz_step1 = fs_pf[0].get_retval()[1] + xyz_rest = unrolled_pf[1][1].get_retval()[1] + all_xyz = jnp.vstack([xyz_init[None, :], xyz_step1[None, :], xyz_rest]) + + init_scores = jax.tree.reduce(lambda x, y: x + y, init_pf[4]) + fs_scores = jax.tree.reduce(lambda x, y: x + y, fs_pf[4]) + unrolled_scores = jax.vmap( + lambda scores: jax.tree.reduce(lambda x, y: x + y, scores) + )(unrolled_pf[1][4]) + all_scores = [init_scores, fs_scores, *unrolled_scores] + return all_xyz, all_scores + + +def animate_current_particle_locs( + observations, + points_3D_seq, + scores, + plot_history, + visual_angles, + xyz_bounds, + interval=50, +): + observation_grids = [ + np.transpose(np.fliplr(o.reshape(len(visual_angles), len(visual_angles)))) + for o in observations + ] + xs, ys, zs = xyz_bounds + fig = plt.figure(figsize=(12, 6)) + cmap = my_tab20(len(scores[0])) + ax2D = fig.add_subplot(121) + ax3D = fig.add_subplot(122, projection="3d") + ax2D.set_xlabel("θ") + ax2D.set_ylabel("ϕ") + ax3D.set_xlabel("X") + ax3D.set_ylabel("Y") + ax3D.set_zlabel("Z") + ax3D.set_title("3D Hypotheses (Y = Depth)") + ax2D.set_title("2D Observation") + ax3D.set_xlim([xs[0], xs[-1]]) + ax3D.set_ylim([ys[0], ys[-1]]) + ax3D.set_zlim([zs[0], zs[-1]]) + num_particles = len(points_3D_seq[0]) + im = ax2D.imshow( + observation_grids[0], cmap="gray", interpolation="none", vmin=0, vmax=1 + ) + particles_3D = [ax3D.plot([], [], [], "o")[0] for _ in range(num_particles)] + tails_3D = [ + ax3D.plot([], [], [], "-", color=cmap[i], alpha=0.3, linewidth=0.5)[0] + for i in range(num_particles) + ] -# wta_components = ['wta_' + str(int(i)) for i in positions] + def update(frame): + im.set_array(observation_grids[frame]) + for i, (particle, (x, y, z)) in enumerate( + zip(particles_3D, points_3D_seq[frame]) + ): + particle.set_data([x], [y]) + # particle.set_alpha(float(jnp.exp(float(scores[frame][i]))**.1)) + particle.set_3d_properties([z]) + particle.set_color(cmap[i]) + if plot_history: + history = np.array(points_3D_seq[: frame + 1])[:, i, :] + tails_3D[i].set_data(history[:, 0], history[:, 1]) + tails_3D[i].set_3d_properties(history[:, 2]) + return [im] + particles_3D + tails_3D + + anim = FuncAnimation( + fig, update, frames=len(observations), interval=interval, blit=True + ) + return anim -# this is for the traveling wave. -# merged_wtas = invert_spiketimes(merge_component_dicts(x_spikes, wta_components)) -# lfp_and_spikes(invert_spiketime_labels(merged_assemblies), eeg(x_spikes)) -# have to incorporate multiple levels of the bayes net. -# so even if you're querying on a single variable, have to check -# the last spike time for ALL variables per step for ALL particles. That's when -# the resampler starts. +def make_probability_heatmap(matrix): + if len(matrix.shape) == 3: + x, y, z = np.indices(matrix.shape) + fig = plt.figure(figsize=(10, 7)) + ax = fig.add_subplot(111, projection="3d") + ax.set_xlabel("X") + ax.set_ylabel("Y") + ax.set_zlabel("Z") + scatter = ax.scatter(x, y, z, c=matrix.flatten(), cmap="viridis") + fig.colorbar(scatter, ax=ax, label="Value") + plt.show() -# for any given step, times within a samplescore are all correct and synched. -# the resampler should start at the very end of all samplescore times. -# first go through all the samplescores and save their spiketimes by step. diff --git a/packages/demos/3dot-smc/tests/genbrain_3dot_smc/tests.py b/packages/demos/3dot-smc/tests/genbrain_3dot_smc/tests.py index 691e4bb..afa6708 100644 --- a/packages/demos/3dot-smc/tests/genbrain_3dot_smc/tests.py +++ b/packages/demos/3dot-smc/tests/genbrain_3dot_smc/tests.py @@ -1,35 +1,13 @@ import genbrain_model_3dot as model +from genbrain_3dot_smc import collect_frames from genbrain_smcnn_core.interpreter import ( get_categorical_probs, run_smcnn_particle_filter, ) import genbrain_smcnn_core.interpreter.master as smcnn -import numpy as np import jax.numpy as jnp import jax from genjax import ChoiceMapBuilder as CMB -from PIL import Image -import re -import os - - -def collect_frames(directory="../data/", file_pattern="frame-*.png"): - frame_regex = re.compile(r"frame-(\d+)\.png") - frames = [] - for filename in os.listdir(directory): - match = frame_regex.match(filename) - if match: - frame_number = int(match.group(1)) - frames.append((frame_number, os.path.join(directory, filename))) - frames.sort(key=lambda x: x[0]) - frame_paths = [path for _, path in frames] - numpy_frames = [] - for path in frame_paths: - with Image.open(path) as img: - # this must be going downwards? - im = img.convert("L").point(lambda p: 1 if p > 0 else 0) - numpy_frames.append((np.transpose(np.flipud(np.array(im))) > 0).astype(int)) - return numpy_frames key = jax.random.PRNGKey(100) @@ -107,7 +85,6 @@ def collect_frames(directory="../data/", file_pattern="frame-*.png"): obs_tr_constrained.get_choices()["obs", "pix"] == obs_constraint[("obs", "pix")] ).all() - # Test classes from master interpreter q_probs = jnp.array([0.1, 0.2, 0.7]) @@ -223,11 +200,3 @@ def collect_frames(directory="../data/", file_pattern="frame-*.png"): 2, vis_angle_observations[1:3], ) - - -def get_xyz(results): - particles_per_step = results[1] - xyz_vals = [] - for particles in particles_per_step: - xyz_vals.append([model.xyz_point_cloud[p.choicemap["xyz"]] for p in particles]) - return xyz_vals diff --git a/packages/models/perception/3dot/src/genbrain_model_3dot/__init__.py b/packages/models/perception/3dot/src/genbrain_model_3dot/__init__.py index 36f8f4d..6dfb34d 100644 --- a/packages/models/perception/3dot/src/genbrain_model_3dot/__init__.py +++ b/packages/models/perception/3dot/src/genbrain_model_3dot/__init__.py @@ -361,26 +361,6 @@ def find_occupied_2d_angles(frame): return dig_2d_array(occupied_inds) -# observations = jax.vmap(lambda obs: find_occupied_2d_angles(obs))(jnp.array(obs_frames)) - -# np.save("observations.npy", np.array(observations[0:5])) -# loaded_obs = np.load("observations.npy") -# observations = jnp.array(loaded_obs) - - -# def show_obs_as_image(obs): -# im = obs.reshape(len(visual_angles), len(visual_angles)) -# fig, ax = plt.subplots() -# ax.imshow(im, cmap="viridis") -# ax.set_xticks(jnp.arange(len(visual_angles))) -# ax.set_yticks(jnp.arange(len(visual_angles))) -# ax.set_xticklabels(visual_angles, fontsize=4, rotation=90) -# ax.set_yticklabels((-1 * visual_angles), fontsize=4) -# ax.set_xlabel("Θ") -# ax.set_ylabel("ϕ") -# plt.imshow(im) - - def generate_obs_traces(observations): key = jax.random.PRNGKey(1000) random_args = initial_model.simulate(key, ()).get_retval()