-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathkeepMaxClusters.m
75 lines (67 loc) · 2.05 KB
/
keepMaxClusters.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
function X = keepMaxClusters(results, T, threshold, refrac)
% X = keepMaxClusters(results, T, threshold) keeps only those clusters that
% have their largest waveform on the center channel. Spikes are assumed to
% be sorted in groups of K channels, with K-1 channels overlap between
% adjacent groups. At the end, any duplicate spikes are removed. During this
% process we always keep the spike from the cluster with the larger
% waveform. Finally, clusters that lost more than a fraction 1 - threshold
% of their spikes to other clusters are discarded.
%
% X = keepMaxClusters(results, T, threshold, refrac) additionally sets the
% window (in samples) within which two spikes count as duplicates. The
% default of 4 samples was annotated as 1/3 ms in the original code, which
% implies a 12 kHz sampling rate -- adjust it if your data differ.
%
% Inputs
%   results    struct array, one element per channel group. Fields used:
%                w      D x N x K array of spike waveforms (N spikes,
%                       K channels; D presumably samples per waveform)
%                s      spike times of the same N spikes, used as row
%                       indices into the output (1..T)
%                b      input passed to cluster(model, b) to obtain one
%                       cluster label per spike
%                model  clustering model; numel(model.pi) is taken as the
%                       number of clusters
%   T          number of rows of X (presumably the number of samples)
%   threshold  minimum fraction of a cluster's spikes that must survive
%              duplicate removal for the cluster to be kept
%   refrac     optional duplicate window in samples (default 4)
%
% Output
%   X          T x C sparse matrix with X(t, c) = 1 if kept cluster c has a
%              spike at time t. C = 0 when no cluster survives.
%
% K is expected to be odd; otherwise no channel is exactly central and only
% the edge rules for the first and last group can select clusters.

if nargin < 4 || isempty(refrac)
    refrac = 4; % 1/3 ms
end

K = size(results(1).w, 3);
center = (K + 1) / 2;
R = numel(results);

% --- select clusters whose mean waveform peaks on the center channel ---
spikes = {}; %#ok<*AGROW>
clusters = {};
mag = [];
for i = 1 : R
    r = results(i);
    a = cluster(r.model, r.b);
    for j = 1 : numel(r.model.pi)
        member = (a == j);
        if ~any(member)
            continue % empty cluster: no spikes to keep and no defined peak
        end
        % squared norm of the mean waveform on each channel (K x 1)
        m = permute(sum(mean(r.w(:, member, :), 2) .^ 2, 1), [3 2 1]);
        [mm, ndx] = max(m);
        % the first/last group also own clusters peaking on their outer
        % channels, which no other group has at its center
        if ndx == center || (i == 1 && ndx < center) || (i == R && ndx > center)
            s = r.s(member);
            spikes{end + 1} = s(:); % force column so cat(1, ...) works
            clusters{end + 1} = numel(spikes) * ones(numel(s), 1);
            mag(end + 1) = mm;
        end
    end
end

M = numel(spikes);
if M == 0
    X = sparse(T, 0);
    return
end
total = cellfun(@numel, spikes(:)); % M x 1 spike counts before de-duplication
spikes = cat(1, spikes{:});
clusters = cat(1, clusters{:});

% order spikes in time
[spikes, order] = sort(spikes);
clusters = clusters(order);

% remove duplicates: within a window of refrac samples keep only the spike
% whose cluster has the larger peak energy (ties go to the later spike)
N = numel(spikes);
keep = true(N, 1);
prev = 1; % index of the most recent surviving spike
for i = 2 : N
    if spikes(i) - spikes(prev) < refrac
        if mag(clusters(i)) < mag(clusters(prev))
            keep(i) = false;
        else
            keep(prev) = false;
            prev = i;
        end
    else
        prev = i;
    end
end
spikes = spikes(keep);
clusters = clusters(keep);

% remove clusters that lost too many spikes to other clusters
frac = accumarray(clusters, 1, [M 1]) ./ total;
keep = ~(frac(clusters) < threshold);
spikes = spikes(keep);
clusters = clusters(keep);
if isempty(spikes)
    X = sparse(T, 0);
    return
end

% relabel surviving clusters as 1..C and build the spike matrix
[~, ~, clusters] = unique(clusters);
X = sparse(spikes, clusters(:), 1, T, max(clusters));