-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathremoveDuplicateClusters.m
69 lines (63 loc) · 1.56 KB
/
removeDuplicateClusters.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
function X = removeDuplicateClusters(results, threshold, T, refrac)
% REMOVEDUPLICATECLUSTERS  Pool several sorting results and drop duplicate clusters.
%
%   X = removeDuplicateClusters(results, threshold, T)
%   X = removeDuplicateClusters(results, threshold, T, refrac)
%
%   Inputs
%     results    struct array. For each element r the code reads:
%                  r.model  - object accepted by cluster(r.model, r.b); the
%                             number of clusters is numel(r.model.pi)
%                  r.b      - data passed to cluster() to get one cluster
%                             index per spike
%                  r.w      - waveform array indexed as r.w(:, spike, :)
%                  r.s      - spike times, used as row indices of X
%                             (presumably integer sample indices in 1..T -
%                             TODO confirm against caller)
%     threshold  a cluster is discarded when the fraction of its spikes that
%                survives duplicate removal is below this value
%     T          number of rows of the output matrix
%     refrac     (optional, default 4) two spikes closer than this many time
%                units are treated as the same event. The original code
%                hard-coded 4 and annotated it "1/3 ms".
%
%   Output
%     X          sparse T-by-K matrix, X(t, c) = 1 if surviving cluster c has
%                a spike at time t. Surviving clusters are numbered 1..K in
%                order of decreasing average-waveform energy. If no spike
%                survives, X is an empty T-by-0 sparse matrix.

if nargin < 4 || isempty(refrac)
    refrac = 4; % original hard-coded value (author's note: 1/3 ms)
end

% total number of clusters over all results
R = numel(results);
nPerResult = arrayfun(@(r) numel(r.model.pi), results);
M = sum(nPerResult(:));

% collect spike times and globally unique cluster ids; m(c) is the mean
% squared amplitude of the average waveform of global cluster c
spikes = cell(1, R);
clusters = cell(1, R);
m = zeros(1, M);
k = 0;
for i = 1 : R
    r = results(i);
    a = cluster(r.model, r.b);
    J = numel(r.model.pi);
    for j = 1 : J
        m(k + j) = mean(mean(mean(r.w(:, a == j, :), 2) .^ 2));
    end
    spikes{i} = r.s(:);
    clusters{i} = k + a(:); % offset local ids so they are unique across results
    k = k + J;
end
spikes = cat(1, spikes{:});
clusters = cat(1, clusters{:});

% nothing to do: the original code failed in sparse() here because
% max(clusters) was empty
if isempty(spikes)
    X = sparse(T, 0);
    return
end

% order cluster ids by magnitude of average waveform (1 = largest);
% rankOf is the inverse permutation of the sort order
[~, order] = sort(m, 'descend');
rankOf = zeros(M, 1);
rankOf(order) = 1 : M;
clusters = rankOf(clusters);
clusters = clusters(:);
total = accumarray(clusters, 1, [M 1]).'; % spikes per cluster before removal

% order spikes in time
[spikes, order] = sort(spikes);
clusters = clusters(order);

% remove spikes of smaller size: among spikes closer than refrac, keep the
% one whose cluster has the larger waveform (smaller id). prev always points
% to the most recent spike that is still kept.
N = numel(spikes);
keep = true(N, 1);
prev = 1;
for i = 2 : N
    if spikes(i) - spikes(prev) < refrac
        if clusters(i) > clusters(prev)
            keep(i) = false;
        else
            keep(prev) = false;
            prev = i;
        end
    else
        prev = i;
    end
end
spikes = spikes(keep);
clusters = clusters(keep);

% remove clusters that lost too many spikes to other clusters.
% Clusters that never had a spike give 0/0 = NaN, which compares false and
% is harmless because no spike refers to them.
frac = accumarray(clusters, 1, [M 1]).' ./ total;
lost = frac(:) < threshold;
keep = ~lost(clusters);
spikes = spikes(keep);
clusters = clusters(keep);

% every cluster fell below the threshold: return an empty matrix instead of
% failing in sparse() on an empty column count
if isempty(spikes)
    X = sparse(T, 0);
    return
end

% renumber surviving clusters consecutively, preserving their order
[~, ~, clusters] = unique(clusters);

% create spike matrix
X = sparse(spikes, clusters, 1, T, max(clusters));