IHT

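Iterative hard thresholding (IHT) is an iterative method for the M-sparse approximation problem. A gradient update followed by a hard-thresholding step forces the solution to have a prescribed sparsity. MATLAB implementations are included.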

Basic principle

Iterative Hard Thresholding (IHT) was called the M-Sparse Algorithm when it was first proposed. The $M$-sparse problem can be written as

$$ \min_\mathbf y\lVert \mathbf x - \Phi\mathbf y \rVert^2_2\quad \text{s.t.}\quad \lVert\mathbf y\rVert_0\le M, $$

The corresponding iteration is

$$ \mathbf y^{n+1}=H_M(\mathbf y^n+\Phi^{\rm H}(\mathbf x-\Phi\mathbf y^n)), $$

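where $H_M$ is a nonlinear operator that keeps the $M$ entries of largest magnitude and sets all other entries to zero:

$$ H_M(y_i)=\begin{cases} 0, & \lvert y_i\rvert\lt\lambda^{0.5}_M(\mathbf y),\\ y_i, & \lvert y_i\rvert\ge\lambda^{0.5}_M(\mathbf y), \end{cases} $$

and $\lambda^{0.5}_M(\mathbf y)$ is the $M$-th largest magnitude among the entries of $\mathbf y$ (if several entries tie at this value, one of them is chosen so that exactly $M$ entries remain).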
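For a concrete feel for $H_M$, the snippet below applies it to a made-up five-element vector with $M=2$; the values are arbitrary and chosen only for illustration. Sorting the magnitudes gives the indices of the $M$ largest entries, and the $M$-th sorted magnitude plays the role of the threshold $\lambda^{0.5}_M(\mathbf y)$. With no ties, the index form and the element-wise threshold form give the same result.

```matlab
% H_M applied to a small example vector with M = 2 (values chosen arbitrarily)
v = [0.3; -2.0; 1.5; 0.1; -0.7];
M = 2;
[mag, idx] = sort(abs(v), 'descend');   % magnitudes in decreasing order
Hv = zeros(size(v));
Hv(idx(1:M)) = v(idx(1:M));             % keep the M largest-magnitude entries
lam = mag(M);                           % M-th largest magnitude, here 1.5
Hv_thr = v .* (abs(v) >= lam);          % element-wise threshold form of H_M
disp(Hv')                               % entries 0, -2, 1.5, 0, 0
```

Sorting costs $O(N\log N)$ per call; for large $N$ a partial selection would be cheaper, but the sorted form mirrors the definition most directly.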

Derivation

Surrogate objective function

$$ C_M^S(\mathbf{y,z})=\lVert \mathbf{x-\Phi y} \rVert_2^2 - \lVert \mathbf{\Phi y-\Phi z} \rVert_2^2 + \lVert \mathbf{y-z} \rVert_2^2 \qquad \lVert\mathbf\Phi\rVert_2\lt1, $$

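When $\mathbf{y=z}$, the surrogate reduces to the original objective; for every other $\mathbf y$ it is strictly larger. Indeed, the two extra terms combine into $(\mathbf y-\mathbf z)^{\rm H}(\mathbf I-\Phi^{\rm H}\Phi)(\mathbf y-\mathbf z)$, which is positive for all $\mathbf y\ne\mathbf z$ exactly when every eigenvalue of $\mathbf I-\Phi^{\rm H}\Phi$ is positive, i.e. when $\lVert\Phi\rVert_2\lt1$. The eigenvalues then lie in $(0,1]$.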
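The majorization property can be checked numerically. The sketch below uses random real data (sizes and the scaling factor 0.9 are assumptions) with $\lVert\Phi\rVert_2=0.9$, evaluates the surrogate against the original objective, and compares the gap with the quadratic form $(\mathbf y-\mathbf z)^{\rm T}(\mathbf I-\Phi^{\rm T}\Phi)(\mathbf y-\mathbf z)$.

```matlab
% Check: C^S(y,z) >= ||x - Phi*y||^2, with equality at y = z, when ||Phi||_2 < 1
rng(1);
n = 20; N = 40;
Phi = randn(n, N);
Phi = 0.9 * Phi / norm(Phi);                                % spectral norm 0.9
x = randn(n, 1); y = randn(N, 1); z = randn(N, 1);
f  = @(y) norm(x - Phi*y)^2;                                % original objective
CS = @(y, z) f(y) - norm(Phi*y - Phi*z)^2 + norm(y - z)^2;  % surrogate
gap   = CS(y, z) - f(y);
qform = (y - z)' * (eye(N) - Phi'*Phi) * (y - z);           % same quantity as a quadratic form
fprintf('gap = %.4f, quadratic form = %.4f\n', gap, qform);
```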

Rewriting the surrogate

$$ C_M^S(\mathbf{y,z})=\sum_i[y_i^2-2y_i(z_i+\phi_i^{\rm T}\mathbf x-\phi_i^{\rm T}\mathbf{\Phi z})]+\lVert\mathbf x\rVert_2^2+\lVert\mathbf z\rVert_2^2-\lVert\mathbf{\Phi z}\rVert_2^2, $$

where $\phi_i$ denotes the $i$-th column of $\Phi$ and the data are assumed real-valued.
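The expansion can be verified term by term on random real data (sizes are arbitrary assumptions). The check also confirms that $\lVert\Phi\mathbf z\rVert_2^2$ enters with a minus sign, since it comes from $-\lVert\Phi\mathbf y-\Phi\mathbf z\rVert_2^2$.

```matlab
% Check the expansion of C_M^S(y,z) term by term (real-valued data)
rng(2);
n = 15; N = 30;
Phi = randn(n, N);
Phi = 0.9 * Phi / norm(Phi);
x = randn(n, 1); y = randn(N, 1); z = randn(N, 1);
direct   = norm(x - Phi*y)^2 - norm(Phi*y - Phi*z)^2 + norm(y - z)^2;
b        = z + Phi'*x - Phi'*(Phi*z);      % b_i = z_i + phi_i^T x - phi_i^T Phi z
expanded = sum(y.^2 - 2*y.*b) + norm(x)^2 + norm(z)^2 - norm(Phi*z)^2;
fprintf('direct = %.6f, expanded = %.6f\n', direct, expanded);
```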

Since the optimization is over $\mathbf y$ and the last three terms do not depend on it, dropping them does not change the minimizer:

$$ C_M^S(\mathbf{y,z})=\sum_i[y_i^2-2y_i(z_i+\phi_i^{\rm T}\mathbf x-\phi_i^{\rm T}\mathbf{\Phi z})]+\text{const}. $$

Minimizer

$$ y_i^*=z_i+\phi_i^{\rm T}\mathbf x-\phi_i^{\rm T}\mathbf{\Phi z},\quad\text{i.e.}\quad \mathbf y^*=\mathbf z+\Phi^{\rm T}(\mathbf x-\Phi\mathbf z), $$

This minimizer follows from completing the square in each coordinate. At $\mathbf y=\mathbf y^*$ the sum attains its minimum $-\sum_i(y_i^*)^2$.
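Written out coordinate by coordinate, with $y_i^*$ as defined above, the completing-the-square step is:

$$ \sum_i\left[y_i^2-2y_iy_i^*\right]=\sum_i\left[(y_i-y_i^*)^2-(y_i^*)^2\right]\ge-\sum_i(y_i^*)^2, $$

with equality if and only if $\mathbf y=\mathbf y^*$.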

Deriving the iteration

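Under the constraint $\lVert\mathbf y\rVert_0\le M$, each coordinate either equals $y_i^*$, contributing $-(y_i^*)^2$, or is zero, contributing $0$. The sum is therefore smallest when the $M$ entries of $\mathbf y^*$ with the largest magnitudes are kept and the rest are set to zero, which is exactly the hard-thresholding operator $H_M(\mathbf y^*)$. Setting $\mathbf z=\mathbf y^n$ gives the iteration $\mathbf y^{n+1}=H_M(\mathbf y^n+\Phi^{\rm H}(\mathbf x-\Phi\mathbf y^n))$ stated above (for real data $\Phi^{\rm H}=\Phi^{\rm T}$).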
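Putting the pieces together, the iteration can be sketched as a short MATLAB script. This is a minimal illustration, not the author's code: the sizes (`n`, `N`, `M`), the random Gaussian `Phi` rescaled to $\lVert\Phi\rVert_2=0.99$ so that the surrogate argument applies, and the fixed 200 iterations are all assumptions. Because the data are real, `Phi'` plays the role of $\Phi^{\rm H}$. The surrogate also explains why the cost cannot increase: $\lVert\mathbf x-\Phi\mathbf y^{n+1}\rVert_2^2\le C_M^S(\mathbf y^{n+1},\mathbf y^n)\le C_M^S(\mathbf y^n,\mathbf y^n)=\lVert\mathbf x-\Phi\mathbf y^n\rVert_2^2$, and the recorded `cost` vector reflects this.

```matlab
% Minimal IHT sketch for min ||x - Phi*y||_2^2 s.t. ||y||_0 <= M
% (illustrative sizes; Phi is scaled so that ||Phi||_2 < 1)
rng(0);
n = 64; N = 128; M = 5;                 % measurements, unknowns, sparsity (assumed)
Phi = randn(n, N);
Phi = 0.99 * Phi / norm(Phi);           % spectral norm 0.99 < 1
y_true = zeros(N, 1);
y_true(randperm(N, M)) = randn(M, 1);
x = Phi * y_true;                       % noiseless observation

y = zeros(N, 1);                        % y^0 = 0 (feasible starting point)
cost = zeros(200, 1);
for it = 1:200
    a = y + Phi' * (x - Phi * y);       % gradient step with unit step size
    [~, idx] = sort(abs(a), 'descend'); % H_M: keep the M largest magnitudes
    y = zeros(N, 1);
    y(idx(1:M)) = a(idx(1:M));
    cost(it) = norm(x - Phi * y)^2;
end
fprintf('final cost %.3e, support size %d\n', cost(end), nnz(y));
```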

Code

Simple version

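function [s_est, info] = iht(g, Phi, max_iter, K)
    % IHT  Recover a K-sparse s_est with g ~ Phi*s_est by iterative hard thresholding.
    %   Each iteration takes a gradient step on 0.5*||Phi*s - g||^2 with step 1/L,
    %   L = ||Phi||_2^2, then keeps the K largest-magnitude entries.
    %   Stops when ||Phi*s - g|| < 1e-2*||g||, when the relative residual decrease
    %   stays below 1e-6 for 20 consecutive iterations, or after max_iter iterations.
    %   The optional struct info reports the stopping reason and final residual.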
    [~, N] = size(Phi);
    s_est = zeros(N, 1);
    g = g(:);

    g_norm = norm(g);
    L = compute_lipschitz_constant(Phi);
    t = 1 / L;  % step size
    
    residual_tol = 1e-2;
    plateau_tol = 1e-6;
    plateau_patience = 20;
    plateau_count = 0;

    residual = Phi * s_est - g;  % initial residual
    residual_norm = norm(residual);
    stop_reason = "max_iter";
    iter_done = 0;

    for iter = 1:max_iter
        z = s_est - Phi' * residual * t;    % gradient step
        s_est = hard_thresholding(z, K); % apply hard thresholding

        prev_residual_norm = residual_norm;
        residual = Phi * s_est - g;  % update residual
        residual_norm = norm(residual);
        iter_done = iter;

        if residual_norm < residual_tol * g_norm
            stop_reason = "residual_tol";
            break;
        end

        improvement = (prev_residual_norm - residual_norm) / max(prev_residual_norm, eps);
        if improvement < plateau_tol
            plateau_count = plateau_count + 1;
        else
            plateau_count = 0;
        end

        if plateau_count >= plateau_patience
            stop_reason = "plateau";
            break;
        end
    end

    if nargout > 1
        info = struct();
        info.iter = iter_done;
        info.residual_norm = residual_norm;
        info.relative_residual = residual_norm / max(g_norm, eps);
        info.stop_reason = stop_reason;
        info.step_size = t;
        info.L = L;
        info.plateau_count = plateau_count;
    end
end

function r = hard_thresholding(x, K)
    % Hard thresholding: keep the K largest-magnitude entries of x, zero the rest
    r = zeros(size(x));
    [~, idx] = sort(abs(x), 'descend');
    r(idx(1:K)) = x(idx(1:K));
end

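function L = compute_lipschitz_constant(Phi)
    % Lipschitz constant of the gradient of 0.5*||Phi*s - g||^2, i.e. ||Phi||_2^2.
    % norm(Phi) avoids forming the N-by-N matrix Phi'*Phi.
    L = norm(Phi)^2;
end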
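The simple version uses the step size $t=1/L$ with $L=\lVert\Phi\rVert_2^2$ rather than rescaling $\Phi$. The short check below (random data, sizes and the sparse iterate are assumptions) illustrates that one such step coincides with the unit-step update of the basic iteration applied to the rescaled pair $\tilde\Phi=\Phi/\sqrt L$, $\tilde{\mathbf x}=\mathbf x/\sqrt L$, for which $\lVert\tilde\Phi\rVert_2=1$, the boundary of the condition $\lVert\Phi\rVert_2\lt1$.

```matlab
% Step t = 1/L on (Phi, g) equals the unit step on (Phi/sqrt(L), g/sqrt(L))
rng(4);
n = 30; N = 60; K = 4;
Phi = randn(n, N); g = randn(n, 1);
s = zeros(N, 1);
s(randperm(N, K)) = randn(K, 1);             % an arbitrary K-sparse iterate
L = norm(Phi)^2;  t = 1 / L;                 % spectral norm squared, as in the simple version
step_a = s - Phi' * (Phi*s - g) * t;         % update used by iht()
Phi_t = Phi / sqrt(L);  g_t = g / sqrt(L);   % rescaled problem with ||Phi_t||_2 = 1
step_b = s + Phi_t' * (g_t - Phi_t*s);       % unit-step update of the basic iteration
fprintf('max difference = %.2e\n', max(abs(step_a - step_b)));
```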

Original implementation (hard_l0_Mterm by T. Blumensath)

function [s, err_mse, iter_time]=hard_l0_Mterm(x,A,m,M,varargin)
% hard_l0_Mterm: Hard thresholding algorithm that keeps exactly M elements 
% in each iteration. 
%
% This algorithm has certain performance guarantees as described in [1],
% [2] and [3].
%
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% Usage
%
%   [s, err_mse, iter_time]=hard_l0_Mterm(x,P,m,M,'option_name','option_value')
%
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%
% Input
%
%   Mandatory:
%               x   Observation vector to be decomposed
%               P   Either:
%                       1) An nxm matrix (n must be dimension of x)
%                       2) A function handle (type "help function_format" 
%                          for more information)
%                          Also requires specification of P_trans option.
%                       3) An object handle (type "help object_format" for 
%                          more information)
%               m   length of s 
%               M   non-zero elements to keep in each iteration
%
%   Possible additional options:
%   (specify as many as you want using 'option_name','option_value' pairs)
%   See below for explanation of options:
%__________________________________________________________________________
%   option_name    |     available option_values                | default
%--------------------------------------------------------------------------
%   stopTol        | number (see below)                         | 1e-16
%   P_trans        | function_handle (see below)                | 
%   maxIter        | positive integer (see below)               | n^2
%   verbose        | true, false                                | false
%   start_val      | vector of length m                         | zeros
%   step_size      | number                                     | 0 (auto)
%
%   stopping criteria used : (OldRMS-NewRMS)/RMS(x) < stopTol
%
%   stopTol: Value for stopping criterion.
%
%   P_trans: If P is a function handle, then P_trans has to be specified and 
%            must be a function handle. 
%
%   maxIter: Maximum number of allowed iterations.
%
%   verbose: Logical value to allow algorithm progress to be displayed.
%
%   start_val: Allows algorithms to start from partial solution.
%
%
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%
% Outputs
%
%    s              Solution vector 
%    err_mse        Vector containing mse of approximation error for each 
%                   iteration
%    iter_time      Vector containing computation times for each iteration
%
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%
% Description
%
%   Implements the M-sparse algorithm described in [1], [2] and [3].
%   This algorithm takes a gradient step and then thresholds to only retain
%   M non-zero elements. It allows the step-size to be calculated
%   automatically as described in [3] and is therefore now independent from 
%   a rescaling of P.
%   
%   
% References
%   [1]  T. Blumensath and M.E. Davies, "Iterative Thresholding for Sparse 
%        Approximations", submitted, 2007
%   [2]  T. Blumensath and M. Davies; "Iterative Hard Thresholding for 
%        Compressed Sensing" to appear Applied and Computational Harmonic 
%        Analysis 
%   [3] T. Blumensath and M. Davies; "A modified Iterative Hard 
%        Thresholding algorithm with guaranteed performance and stability" 
%        in preparation (title may change) 
% See Also
%   hard_l0_reg
%
% Copyright (c) 2007 Thomas Blumensath
%
% The University of Edinburgh
% Email: thomas.blumensath@ed.ac.uk
% Comments and bug reports welcome
%
% This file is part of sparsity Version 0.4
% Created: April 2007
% Modified January 2009
%
% Part of this toolbox was developed with the support of EPSRC Grant
% D000246/1
%
% Please read COPYRIGHT.m for terms and conditions.


%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%                    Default values and initialisation
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%



[n1 n2]=size(x);
if n2 == 1
    n=n1;
elseif n1 == 1
    x=x';
    n=n2;
else
   error('x must be a vector.');
end
    
sigsize     = x'*x/n;
oldERR      = sigsize;
err_mse     = [];
iter_time   = [];
STOPTOL     = 1e-16;
MAXITER     = n^2;
verbose     = false;
initial_given=0;
s_initial   = zeros(m,1);
MU          = 0;

if verbose
   display('Initialising...') 
end

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%                           Output variables
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

switch nargout 
    case 3
        comp_err=true;
        comp_time=true;
    case 2 
        comp_err=true;
        comp_time=false;
    case 1
        comp_err=false;
        comp_time=false;
    case 0
        error('Please assign output variable.')        
    otherwise
        error('Too many output arguments specified')
end

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%                       Look through options
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

% Put option into nice format
Options={};
OS=nargin-4;
c=1;
for i=1:OS
    if isa(varargin{i},'cell')
        CellSize=length(varargin{i});
        ThisCell=varargin{i};
        for j=1:CellSize
            Options{c}=ThisCell{j};
            c=c+1;
        end
    else
        Options{c}=varargin{i};
        c=c+1;
    end
end
OS=length(Options);
if rem(OS,2)
   error('Something is wrong with argument name and argument value pairs.') 
end
for i=1:2:OS
   switch Options{i}
        case {'stopTol'}
            if isa(Options{i+1},'numeric') ; STOPTOL     = Options{i+1};   
            else error('stopTol must be number. Exiting.'); end
        case {'P_trans'} 
            if isa(Options{i+1},'function_handle'); Pt = Options{i+1};   
            else error('P_trans must be function _handle. Exiting.'); end
        case {'maxIter'}
            if isa(Options{i+1},'numeric'); MAXITER     = Options{i+1};             
            else error('maxIter must be a number. Exiting.'); end
        case {'verbose'}
            if isa(Options{i+1},'logical'); verbose     = Options{i+1};   
            else error('verbose must be a logical. Exiting.'); end 
        case {'start_val'}
            if isa(Options{i+1},'numeric') && length(Options{i+1}) == m ;
                s_initial     = Options{i+1};  
                initial_given=1;
            else error('start_val must be a vector of length m. Exiting.'); end
        case {'step_size'}
            if isa(Options{i+1},'numeric') && (Options{i+1}) > 0 ;
                MU     = Options{i+1};   
            else error('Stepsize must be between a positive number. Exiting.'); end
        otherwise
            error('Unrecognised option. Exiting.') 
   end
end

if nargout >=2
    err_mse = zeros(MAXITER,1);
end
if nargout ==3
    iter_time = zeros(MAXITER,1);
end


%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%                        Make P and Pt functions
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

if          isa(A,'float')      P =@(z) A*z;  Pt =@(z) A'*z;
elseif      isobject(A)         P =@(z) A*z;  Pt =@(z) A'*z;
elseif      isa(A,'function_handle') 
    try
        if          isa(Pt,'function_handle'); P=A;
        else        error('If P is a function handle, Pt also needs to be a function handle. Exiting.'); end
    catch error('If P is a function handle, Pt needs to be specified. Exiting.'); end
else        error('P is of unsupported type. Use matrix, function_handle or object. Exiting.'); end

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%                        Do we start from zero or not?
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%



if initial_given ==1;
    
    if length(find(s_initial)) > M
        display('Initial vector has more than M non-zero elements. Keeping only M largest.')
    
    end
    s                   =   s_initial;
    [ssort sortind]     =   sort(abs(s),'descend');
    s(sortind(M+1:end)) =   0;
    Ps                  =   P(s);
    Residual            =   x-Ps;
    oldERR      = Residual'*Residual/n;
else
    s_initial   = zeros(m,1);
    Residual    = x;
    s           = s_initial;
    Ps          = zeros(n,1);
    oldERR      = sigsize;
end


%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%                 Random Check to see if dictionary norm is below 1 
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

        
        x_test=randn(m,1);
        x_test=x_test/norm(x_test);
        nP=norm(P(x_test));
        if abs(MU*nP)>1;
            display('WARNING! Algorithm likely to become unstable.')
            display('Use smaller step-size or || P ||_2 < 1.')
        end


%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%                        Main algorithm
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
if verbose
   display('Main iterations...') 
end
tic
t=0;
done = 0;
iter=1;

while ~done
    
    if MU == 0

        %Calculate optimal step size and do line search
        olds                =   s;
        oldPs               =   Ps;
        IND                 =   s~=0;
        d                   =   Pt(Residual);
        % If the current vector is zero, we take the largest elements in d
        if sum(IND)==0
            [dsort sortdind]    =   sort(abs(d),'descend');
            IND(sortdind(1:M))  =   1;    
         end  

        id                  =   (IND.*d);
        Pd                  =   P(id);
        mu                  =   id'*id/(Pd'*Pd);
        s                   =   olds + mu * d;
        [ssort sortind]     =   sort(abs(s),'descend');
        s(sortind(M+1:end)) =   0;
        Ps                  =   P(s);
        
        % Calculate step-size requirement 
        omega               =   (norm(s-olds)/norm(Ps-oldPs))^2;

        % As long as the support changes and mu > omega, we decrease mu
        while mu > (0.99)*omega && sum(xor(IND,s~=0))~=0 && sum(IND)~=0
%             display(['decreasing mu'])
                    
                    % We use a simple line search, halving mu in each step
                    mu                  =   mu/2;
                    s                   =   olds + mu * d;
                    [ssort sortind]     =   sort(abs(s),'descend');
                    s(sortind(M+1:end)) =   0;
                    Ps                  =   P(s);
                    % Calculate step-size requirement 
                    omega               =   (norm(s-olds)/norm(Ps-oldPs))^2;
        end
        
    else
        % Use fixed step size
        s                   =   s + MU * Pt(Residual);
        [ssort sortind]     =   sort(abs(s),'descend');
        s(sortind(M+1:end)) =   0;
        Ps                  =   P(s);
        
    end
        Residual            =   x-Ps;

        
     ERR=Residual'*Residual/n;
     if comp_err
         err_mse(iter)=ERR;
     end
     
     if comp_time
         iter_time(iter)=toc;
     end

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%                        Are we done yet?
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
     
         if comp_err && iter >=2
             if ((err_mse(iter-1)-err_mse(iter))/sigsize<STOPTOL);
                 if verbose
                    display(['Stopping. Approximation error changed less than ' num2str(STOPTOL)])
                 end
                done = 1; 
             elseif verbose && toc-t>10
                display(sprintf('Iteration %i. --- %i mse change',iter ,(err_mse(iter-1)-err_mse(iter))/sigsize)) 
                t=toc;
             end
         else
             if ((oldERR - ERR)/sigsize < STOPTOL) && iter >=2;
                 if verbose
                    display(['Stopping. Approximation error changed less than ' num2str(STOPTOL)])
                 end
                done = 1; 
             elseif verbose && toc-t>10
                display(sprintf('Iteration %i. --- %i mse change',iter ,(oldERR - ERR)/sigsize)) 
                t=toc;
             end
         end
         
    % Also stop if residual gets too small or maxIter reached
     if comp_err
         if err_mse(iter)<1e-16
             display('Stopping. Exact signal representation found!')
             done=1;
         end
     elseif iter>1 
         if ERR<1e-16
             display('Stopping. Exact signal representation found!')
             done=1;
         end
     end

     if iter >= MAXITER
         display('Stopping. Maximum number of iterations reached!')
         done = 1; 
     end
 
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%                    If not done, take another round
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
   
     if ~done
        iter=iter+1; 
        oldERR=ERR;        
     end
end


%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%                  Only return as many elements as iterations
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

if nargout >=2
    err_mse = err_mse(1:iter);
end
if nargout ==3
    iter_time = iter_time(1:iter);
end
if verbose
   display('Done') 
end
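When `step_size` is left at its default of 0, `hard_l0_Mterm` chooses the step adaptively: it computes `mu = id'*id/(Pd'*Pd)` from the gradient restricted to the current support, takes the step, and halves `mu` while the support changes and `mu > 0.99*omega`. The script below is a stripped-down re-implementation of only that step-size logic for an explicit matrix, written for illustration; it is not the toolbox code, and the sizes, the noise level 0.01, the 30 iterations and the names `maxit` and `nit` are assumptions. With this rule the squared residual should not increase between iterations: when `mu <= omega` this follows from the same majorization argument as in the derivation, and when the support is unchanged the step is an exact (or halved) line search along the restricted gradient.

```matlab
% Stripped-down sketch of the default (step_size = 0) step-size rule of
% hard_l0_Mterm for an explicit matrix A; option parsing, timing and the
% toolbox's stopping rules are omitted
rng(5);
n = 40; m = 100; M = 6;                          % sizes (assumed)
A = randn(n, m) / sqrt(n);
s_true = zeros(m, 1);
s_true(randperm(m, M)) = randn(M, 1);
x = A * s_true + 0.01 * randn(n, 1);             % noisy observation (assumed)

s = zeros(m, 1); Ps = zeros(n, 1); Residual = x;
maxit = 30; err = zeros(maxit, 1); nit = 0;
for iter = 1:maxit
    olds = s; oldPs = Ps;
    IND = (s ~= 0);                              % current support
    d = A' * Residual;                           % descent direction
    if ~any(IND)                                 % zero start: take the M largest entries of d
        [~, di] = sort(abs(d), 'descend');
        IND(di(1:M)) = true;
    end
    id = IND .* d;                               % d restricted to the support
    if norm(id) == 0, break; end                 % nothing left to improve on this support
    mu = (id' * id) / norm(A * id)^2;            % exact line search along id
    s = olds + mu * d;
    [~, si] = sort(abs(s), 'descend'); s(si(M+1:end)) = 0;   % H_M
    Ps = A * s;
    omega = (norm(s - olds) / norm(Ps - oldPs))^2;
    % While the support changes and mu is too large, halve mu and re-threshold
    while mu > 0.99 * omega && any(xor(IND, s ~= 0))
        mu = mu / 2;
        s = olds + mu * d;
        [~, si] = sort(abs(s), 'descend'); s(si(M+1:end)) = 0;
        Ps = A * s;
        omega = (norm(s - olds) / norm(Ps - oldPs))^2;
    end
    Residual = x - Ps;
    nit = nit + 1;
    err(nit) = Residual' * Residual / n;         % mean squared error, like err_mse
end
err = err(1:nit);
fprintf('%d iterations, final mse %.3e\n', nit, err(end));
```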