rtime.felk.cvut.cz Git - frescor/ffmpeg.git/blobdiff - libavcodec/wmaprodec.c
WMA: extend exponent range to 95
[frescor/ffmpeg.git] / libavcodec / wmaprodec.c
index f96e4dbed3d71a7126f10b51f0e8aabd18242803..cbc97d9c82aaf50c3ea41a969bcb4e0af472e8fd 100644 (file)
@@ -98,7 +98,7 @@
 #define WMAPRO_MAX_CHANNELS    8                             ///< max number of handled channels
 #define MAX_SUBFRAMES  32                                    ///< max number of subframes per channel
 #define MAX_BANDS      29                                    ///< max number of scale factor bands
-#define MAX_FRAMESIZE  16384                                 ///< maximum compressed frame size
+#define MAX_FRAMESIZE  32768                                 ///< maximum compressed frame size
 
 #define WMAPRO_BLOCK_MAX_BITS 12                                           ///< log2 of max block size
 #define WMAPRO_BLOCK_MAX_SIZE (1 << WMAPRO_BLOCK_MAX_BITS)                 ///< maximum block size
@@ -137,8 +137,9 @@ typedef struct {
     int8_t   reuse_sf;                                ///< share scale factors between subframes
     int8_t   scale_factor_step;                       ///< scaling step for the current subframe
     int      max_scale_factor;                        ///< maximum scale factor for the current subframe
-    int      scale_factors[MAX_BANDS];                ///< scale factor values for the current subframe
-    int      saved_scale_factors[MAX_BANDS];          ///< scale factors from a previous subframe
+    int      saved_scale_factors[2][MAX_BANDS];       ///< resampled and (previously) transmitted scale factor values
+    int8_t   scale_factor_idx;                        ///< index for the transmitted scale factor values (used for resampling)
+    int*     scale_factors;                           ///< pointer to the scale factor values used for decoding
     uint8_t  table_idx;                               ///< index in sf_offsets for the scale factor reference block
     float*   coeffs;                                  ///< pointer to the subframe decode buffer
     DECLARE_ALIGNED_16(float, out[WMAPRO_BLOCK_MAX_SIZE + WMAPRO_BLOCK_MAX_SIZE / 2]); ///< output buffer
@@ -165,7 +166,7 @@ typedef struct WMAProDecodeCtx {
     uint8_t          frame_data[MAX_FRAMESIZE +
                       FF_INPUT_BUFFER_PADDING_SIZE];///< compressed frame data
     PutBitContext    pb;                            ///< context for filling the frame_data buffer
-    MDCTContext      mdct_ctx[WMAPRO_BLOCK_SIZES];  ///< MDCT context per block size
+    FFTContext       mdct_ctx[WMAPRO_BLOCK_SIZES];  ///< MDCT context per block size
     DECLARE_ALIGNED_16(float, tmp[WMAPRO_BLOCK_MAX_SIZE]); ///< IMDCT output buffer
     float*           windows[WMAPRO_BLOCK_SIZES];   ///< windows for the different block sizes
 
@@ -188,11 +189,14 @@ typedef struct WMAProDecodeCtx {
     int16_t          subwoofer_cutoffs[WMAPRO_BLOCK_SIZES]; ///< subwoofer cutoff values
 
     /* packet decode state */
+    GetBitContext    pgb;                           ///< bitstream reader context for the packet
+    uint8_t          packet_offset;                 ///< frame offset in the packet
     uint8_t          packet_sequence_number;        ///< current packet number
     int              num_saved_bits;                ///< saved number of bits
     int              frame_offset;                  ///< frame offset in the bit reservoir
     int              subframe_offset;               ///< subframe offset in the bit reservoir
     uint8_t          packet_loss;                   ///< set in case of bitstream error
+    uint8_t          packet_done;                   ///< set when a packet is fully decoded
 
     /* frame decode state */
     uint32_t         frame_num;                     ///< current frame number (not used for decoding)
@@ -262,7 +266,7 @@ static av_cold int decode_end(AVCodecContext *avctx)
 static av_cold int decode_init(AVCodecContext *avctx)
 {
     WMAProDecodeCtx *s = avctx->priv_data;
-    uint8_t *edata_ptr   = avctx->extradata;
+    uint8_t *edata_ptr = avctx->extradata;
     unsigned int channel_mask;
     int i;
     int log2_max_num_subframes;
@@ -381,11 +385,11 @@ static av_cold int decode_init(AVCodecContext *avctx)
 
         s->sfb_offsets[i][0] = 0;
 
-        for (x = 0; x < MAX_BANDS-1 && s->sfb_offsets[i][band-1] < subframe_len; x++) {
+        for (x = 0; x < MAX_BANDS-1 && s->sfb_offsets[i][band - 1] < subframe_len; x++) {
             int offset = (subframe_len * 2 * critical_freq[x])
                           / s->avctx->sample_rate + 2;
             offset &= ~3;
-            if ( offset > s->sfb_offsets[i][band - 1] )
+            if (offset > s->sfb_offsets[i][band - 1])
                 s->sfb_offsets[i][band++] = offset;
         }
         s->sfb_offsets[i][band - 1] = subframe_len;
@@ -403,7 +407,7 @@ static av_cold int decode_init(AVCodecContext *avctx)
         for (b = 0; b < s->num_sfb[i]; b++) {
             int x;
             int offset = ((s->sfb_offsets[i][b]
-                         + s->sfb_offsets[i][b + 1] - 1)<<i) >> 1;
+                           + s->sfb_offsets[i][b + 1] - 1) << i) >> 1;
             for (x = 0; x < num_possible_block_sizes; x++) {
                 int v = 0;
                 while (s->sfb_offsets[x][v + 1] << x < offset)
@@ -416,15 +420,15 @@ static av_cold int decode_init(AVCodecContext *avctx)
     /** init MDCT, FIXME: only init needed sizes */
     for (i = 0; i < WMAPRO_BLOCK_SIZES; i++)
         ff_mdct_init(&s->mdct_ctx[i], BLOCK_MIN_BITS+1+i, 1,
-                     1.0 / (1 <<(BLOCK_MIN_BITS + i - 1))
+                     1.0 / (1 << (BLOCK_MIN_BITS + i - 1))
                      / (1 << (s->bits_per_sample - 1)));
 
     /** init MDCT windows: simple sinus window */
     for (i = 0; i < WMAPRO_BLOCK_SIZES; i++) {
         const int n       = 1 << (WMAPRO_BLOCK_MAX_BITS - i);
-        const int win_idx = WMAPRO_BLOCK_MAX_BITS - i - 7;
+        const int win_idx = WMAPRO_BLOCK_MAX_BITS - i;
         ff_sine_window_init(ff_sine_windows[win_idx], n);
-        s->windows[WMAPRO_BLOCK_SIZES-i-1] = ff_sine_windows[win_idx];
+        s->windows[WMAPRO_BLOCK_SIZES - i - 1] = ff_sine_windows[win_idx];
     }
 
     /** calculate subwoofer cutoff values */
@@ -471,8 +475,8 @@ static int decode_subframe_length(WMAProDecodeCtx *s, int offset)
     subframe_len = s->samples_per_frame >> frame_len_shift;
 
     /** sanity check the length */
-    if (subframe_len < s->min_samples_per_subframe
-              || subframe_len > s->samples_per_frame) {
+    if (subframe_len < s->min_samples_per_subframe ||
+        subframe_len > s->samples_per_frame) {
         av_log(s->avctx, AV_LOG_ERROR, "broken frame: subframe_len %i\n",
                subframe_len);
         return AVERROR_INVALIDDATA;
@@ -559,11 +563,11 @@ static int decode_tilehdr(WMAProDecodeCtx *s)
                 num_samples[c] += subframe_len;
                 ++chan->num_subframes;
                 if (num_samples[c] > s->samples_per_frame) {
-                    av_log(s->avctx, AV_LOG_ERROR,"broken frame: "
+                    av_log(s->avctx, AV_LOG_ERROR, "broken frame: "
                            "channel len > samples_per_frame\n");
                     return AVERROR_INVALIDDATA;
                 }
-            } else if(num_samples[c] <= min_channel_len) {
+            } else if (num_samples[c] <= min_channel_len) {
                 if (num_samples[c] < min_channel_len) {
                     channels_for_cur_subframe = 0;
                     min_channel_len = num_samples[c];
@@ -578,7 +582,8 @@ static int decode_tilehdr(WMAProDecodeCtx *s)
         int offset = 0;
         for (i = 0; i < s->channel[c].num_subframes; i++) {
             dprintf(s->avctx, "frame[%i] channel[%i] subframe[%i]"
-                   " len %i\n", s->frame_num, c, i, s->channel[c].subframe_len[i]);
+                    " len %i\n", s->frame_num, c, i,
+                    s->channel[c].subframe_len[i]);
             s->channel[c].subframe_offset[i] = offset;
             offset += s->channel[c].subframe_len[i];
         }
@@ -606,7 +611,7 @@ static void decode_decorrelation_matrix(WMAProDecodeCtx *s,
 
     for (i = 0; i < chgroup->num_channels; i++)
         chgroup->decorrelation_matrix[chgroup->num_channels * i + i] =
-                                                get_bits1(&s->gb) ? 1.0 : -1.0;
+            get_bits1(&s->gb) ? 1.0 : -1.0;
 
     for (i = 1; i < chgroup->num_channels; i++) {
         int x;
@@ -621,10 +626,10 @@ static void decode_decorrelation_matrix(WMAProDecodeCtx *s,
 
                 if (n < 32) {
                     sinv = sin64[n];
-                    cosv = sin64[32-n];
+                    cosv = sin64[32 - n];
                 } else {
-                    sinv = sin64[64-n];
-                    cosv = -sin64[n-32];
+                    sinv =  sin64[64 -  n];
+                    cosv = -sin64[n  - 32];
                 }
 
                 chgroup->decorrelation_matrix[y + x * chgroup->num_channels] =
@@ -646,7 +651,7 @@ static int decode_channel_transform(WMAProDecodeCtx* s)
 {
     int i;
     /* should never consume more than 1921 bits for the 8 channel case
-     * 1 + MAX_CHANNELS * ( MAX_CHANNELS + 2 + 3 * MAX_CHANNELS * MAX_CHANNELS
+     * 1 + MAX_CHANNELS * (MAX_CHANNELS + 2 + 3 * MAX_CHANNELS * MAX_CHANNELS
      * + MAX_CHANNELS + MAX_BANDS + 1)
      */
 
@@ -662,7 +667,7 @@ static int decode_channel_transform(WMAProDecodeCtx* s)
         }
 
         for (s->num_chgroups = 0; remaining_channels &&
-            s->num_chgroups < s->channels_for_cur_subframe; s->num_chgroups++) {
+             s->num_chgroups < s->channels_for_cur_subframe; s->num_chgroups++) {
             WMAProChannelGrp* chgroup = &s->chgroup[s->num_chgroups];
             float** channel_data = chgroup->channel_data;
             chgroup->num_channels = 0;
@@ -694,7 +699,7 @@ static int decode_channel_transform(WMAProDecodeCtx* s)
                 if (get_bits1(&s->gb)) {
                     if (get_bits1(&s->gb)) {
                         av_log_ask_for_sample(s->avctx,
-                               "unsupported channel transform type\n");
+                                              "unsupported channel transform type\n");
                     }
                 } else {
                     chgroup->transform = 1;
@@ -720,12 +725,12 @@ static int decode_channel_transform(WMAProDecodeCtx* s)
                         /** FIXME: more than 6 coupled channels not supported */
                         if (chgroup->num_channels > 6) {
                             av_log_ask_for_sample(s->avctx,
-                                   "coupled channels > 6\n");
+                                                  "coupled channels > 6\n");
                         } else {
                             memcpy(chgroup->decorrelation_matrix,
-                              default_decorrelation[chgroup->num_channels],
-                              chgroup->num_channels * chgroup->num_channels *
-                              sizeof(*chgroup->decorrelation_matrix));
+                                   default_decorrelation[chgroup->num_channels],
+                                   chgroup->num_channels * chgroup->num_channels *
+                                   sizeof(*chgroup->decorrelation_matrix));
                         }
                     }
                 }
@@ -757,6 +762,15 @@ static int decode_channel_transform(WMAProDecodeCtx* s)
  */
 static int decode_coeffs(WMAProDecodeCtx *s, int c)
 {
+    /* Integers 0..15 as single-precision floats.  The table saves a
+       costly int to float conversion, and storing the values as
+       integers allows fast sign-flipping. */
+    static const int fval_tab[16] = {
+        0x00000000, 0x3f800000, 0x40000000, 0x40400000,
+        0x40800000, 0x40a00000, 0x40c00000, 0x40e00000,
+        0x41000000, 0x41100000, 0x41200000, 0x41300000,
+        0x41400000, 0x41500000, 0x41600000, 0x41700000,
+    };
     int vlctable;
     VLC* vlc;
     WMAProChannelCtx* ci = &s->channel[c];
@@ -764,7 +778,7 @@ static int decode_coeffs(WMAProDecodeCtx *s, int c)
     int cur_coeff = 0;
     int num_zeros = 0;
     const uint16_t* run;
-    const uint16_t* level;
+    const float* level;
 
     dprintf(s->avctx, "decode coefficients for channel %i\n", c);
 
@@ -792,35 +806,38 @@ static int decode_coeffs(WMAProDecodeCtx *s, int c)
             for (i = 0; i < 4; i += 2) {
                 idx = get_vlc2(&s->gb, vec2_vlc.table, VLCBITS, VEC2MAXDEPTH);
                 if (idx == HUFF_VEC2_SIZE - 1) {
-                    vals[i] = get_vlc2(&s->gb, vec1_vlc.table, VLCBITS, VEC1MAXDEPTH);
-                    if (vals[i] == HUFF_VEC1_SIZE - 1)
-                        vals[i] += ff_wma_get_large_val(&s->gb);
-                    vals[i+1] = get_vlc2(&s->gb, vec1_vlc.table, VLCBITS, VEC1MAXDEPTH);
-                    if (vals[i+1] == HUFF_VEC1_SIZE - 1)
-                        vals[i+1] += ff_wma_get_large_val(&s->gb);
+                    int v0, v1;
+                    v0 = get_vlc2(&s->gb, vec1_vlc.table, VLCBITS, VEC1MAXDEPTH);
+                    if (v0 == HUFF_VEC1_SIZE - 1)
+                        v0 += ff_wma_get_large_val(&s->gb);
+                    v1 = get_vlc2(&s->gb, vec1_vlc.table, VLCBITS, VEC1MAXDEPTH);
+                    if (v1 == HUFF_VEC1_SIZE - 1)
+                        v1 += ff_wma_get_large_val(&s->gb);
+                    ((float*)vals)[i  ] = v0;
+                    ((float*)vals)[i+1] = v1;
                 } else {
-                    vals[i]   = symbol_to_vec2[idx] >> 4;
-                    vals[i+1] = symbol_to_vec2[idx] & 0xF;
+                    vals[i]   = fval_tab[symbol_to_vec2[idx] >> 4 ];
+                    vals[i+1] = fval_tab[symbol_to_vec2[idx] & 0xF];
                 }
             }
         } else {
-             vals[0] =  symbol_to_vec4[idx] >> 12;
-             vals[1] = (symbol_to_vec4[idx] >> 8) & 0xF;
-             vals[2] = (symbol_to_vec4[idx] >> 4) & 0xF;
-             vals[3] =  symbol_to_vec4[idx]       & 0xF;
+            vals[0] = fval_tab[ symbol_to_vec4[idx] >> 12      ];
+            vals[1] = fval_tab[(symbol_to_vec4[idx] >> 8) & 0xF];
+            vals[2] = fval_tab[(symbol_to_vec4[idx] >> 4) & 0xF];
+            vals[3] = fval_tab[ symbol_to_vec4[idx]       & 0xF];
         }
 
         /** decode sign */
         for (i = 0; i < 4; i++) {
             if (vals[i]) {
                 int sign = get_bits1(&s->gb) - 1;
-                ci->coeffs[cur_coeff] = (vals[i]^sign) - sign;
+                *(uint32_t*)&ci->coeffs[cur_coeff] = vals[i] ^ sign<<31;
                 num_zeros = 0;
             } else {
                 ci->coeffs[cur_coeff] = 0;
                 /** switch to run level mode when subframe_len / 128 zeros
-                   were found in a row */
-                rl_mode |= (++num_zeros > s->subframe_len>>8);
+                    were found in a row */
+                rl_mode |= (++num_zeros > s->subframe_len >> 8);
             }
             ++cur_coeff;
         }
@@ -856,7 +873,9 @@ static int decode_scale_factors(WMAProDecodeCtx* s)
     for (i = 0; i < s->channels_for_cur_subframe; i++) {
         int c = s->channel_indexes_for_cur_subframe[i];
         int* sf;
-        int* sf_end = s->channel[c].scale_factors + s->num_bands;
+        int* sf_end;
+        s->channel[c].scale_factors = s->channel[c].saved_scale_factors[!s->channel[c].scale_factor_idx];
+        sf_end = s->channel[c].scale_factors + s->num_bands;
 
         /** resample scale factors for the new block size
          *  as the scale factors might need to be resampled several times
@@ -868,7 +887,7 @@ static int decode_scale_factors(WMAProDecodeCtx* s)
             int b;
             for (b = 0; b < s->num_bands; b++)
                 s->channel[c].scale_factors[b] =
-                                   s->channel[c].saved_scale_factors[*sf_offsets++];
+                    s->channel[c].saved_scale_factors[s->channel[c].scale_factor_idx][*sf_offsets++];
         }
 
         if (!s->channel[c].cur_subframe || get_bits1(&s->gb)) {
@@ -893,7 +912,7 @@ static int decode_scale_factors(WMAProDecodeCtx* s)
 
                     idx = get_vlc2(&s->gb, sf_rl_vlc.table, VLCBITS, SCALERLMAXDEPTH);
 
-                    if ( !idx ) {
+                    if (!idx) {
                         uint32_t code = get_bits(&s->gb, 14);
                         val  =  code >> 6;
                         sign = (code & 1) - 1;
@@ -908,19 +927,15 @@ static int decode_scale_factors(WMAProDecodeCtx* s)
 
                     i += skip;
                     if (i >= s->num_bands) {
-                        av_log(s->avctx,AV_LOG_ERROR,
+                        av_log(s->avctx, AV_LOG_ERROR,
                                "invalid scale factor coding\n");
                         return AVERROR_INVALIDDATA;
                     }
                     s->channel[c].scale_factors[i] += (val ^ sign) - sign;
                 }
             }
-
-            /** save transmitted scale factors so that they can be reused for
-                the next subframe */
-            memcpy(s->channel[c].saved_scale_factors,
-                   s->channel[c].scale_factors, s->num_bands *
-                   sizeof(*s->channel[c].saved_scale_factors));
+            /** swap buffers */
+            s->channel[c].scale_factor_idx = !s->channel[c].scale_factor_idx;
             s->channel[c].table_idx = s->table_idx;
             s->channel[c].reuse_sf  = 1;
         }
@@ -955,7 +970,7 @@ static void inverse_channel_transform(WMAProDecodeCtx *s)
 
             /** multichannel decorrelation */
             for (sfb = s->cur_sfb_offsets;
-                sfb < s->cur_sfb_offsets + s->num_bands;sfb++) {
+                 sfb < s->cur_sfb_offsets + s->num_bands; sfb++) {
                 int y;
                 if (*tb++ == 1) {
                     /** multiply values with the decorrelation_matrix */
@@ -966,7 +981,7 @@ static void inverse_channel_transform(WMAProDecodeCtx *s)
                         float** ch;
 
                         for (ch = ch_data; ch < ch_end; ch++)
-                           *data_ptr++ = (*ch)[y];
+                            *data_ptr++ = (*ch)[y];
 
                         for (ch = ch_data; ch < ch_end; ch++) {
                             float sum = 0;
@@ -978,10 +993,13 @@ static void inverse_channel_transform(WMAProDecodeCtx *s)
                         }
                     }
                 } else if (s->num_channels == 2) {
-                    for (y = sfb[0]; y < FFMIN(sfb[1], s->subframe_len); y++) {
-                        ch_data[0][y] *= 181.0 / 128;
-                        ch_data[1][y] *= 181.0 / 128;
-                    }
+                    int len = FFMIN(sfb[1], s->subframe_len) - sfb[0];
+                    s->dsp.vector_fmul_scalar(ch_data[0] + sfb[0],
+                                              ch_data[0] + sfb[0],
+                                              181.0 / 128, len);
+                    s->dsp.vector_fmul_scalar(ch_data[1] + sfb[0],
+                                              ch_data[1] + sfb[0],
+                                              181.0 / 128, len);
                 }
             }
         }
@@ -995,18 +1013,18 @@ static void inverse_channel_transform(WMAProDecodeCtx *s)
 static void wmapro_window(WMAProDecodeCtx *s)
 {
     int i;
-    for (i = 0; i< s->channels_for_cur_subframe; i++) {
+    for (i = 0; i < s->channels_for_cur_subframe; i++) {
         int c = s->channel_indexes_for_cur_subframe[i];
         float* window;
         int winlen = s->channel[c].prev_block_len;
         float* start = s->channel[c].coeffs - (winlen >> 1);
 
         if (s->subframe_len < winlen) {
-            start += (winlen - s->subframe_len)>>1;
+            start += (winlen - s->subframe_len) >> 1;
             winlen = s->subframe_len;
         }
 
-        window = s->windows[av_log2(winlen)-BLOCK_MIN_BITS];
+        window = s->windows[av_log2(winlen) - BLOCK_MIN_BITS];
 
         winlen >>= 1;
 
@@ -1047,7 +1065,7 @@ static int decode_subframe(WMAProDecodeCtx *s)
     }
 
     dprintf(s->avctx,
-           "processing subframe with offset %i len %i\n", offset, subframe_len);
+            "processing subframe with offset %i len %i\n", offset, subframe_len);
 
     /** get a list of all channels that contain the estimated block */
     s->channels_for_cur_subframe = 0;
@@ -1058,7 +1076,7 @@ static int decode_subframe(WMAProDecodeCtx *s)
 
         /** and count if there are multiple subframes that match our profile */
         if (offset == s->channel[i].decoded_samples &&
-           subframe_len == s->channel[i].subframe_len[cur_subframe]) {
+            subframe_len == s->channel[i].subframe_len[cur_subframe]) {
             total_samples -= s->channel[i].subframe_len[cur_subframe];
             s->channel[i].decoded_samples +=
                 s->channel[i].subframe_len[cur_subframe];
@@ -1074,7 +1092,7 @@ static int decode_subframe(WMAProDecodeCtx *s)
 
 
     dprintf(s->avctx, "subframe is part of %i channels\n",
-           s->channels_for_cur_subframe);
+            s->channels_for_cur_subframe);
 
     /** calculate number of scale factor bands and their offsets */
     s->table_idx         = av_log2(s->samples_per_frame/subframe_len);
@@ -1086,7 +1104,7 @@ static int decode_subframe(WMAProDecodeCtx *s)
     for (i = 0; i < s->channels_for_cur_subframe; i++) {
         int c = s->channel_indexes_for_cur_subframe[i];
 
-        s->channel[c].coeffs = &s->channel[c].out[(s->samples_per_frame>>1)
+        s->channel[c].coeffs = &s->channel[c].out[(s->samples_per_frame >> 1)
                                                   + offset];
     }
 
@@ -1103,7 +1121,7 @@ static int decode_subframe(WMAProDecodeCtx *s)
 
         if (num_fill_bits >= 0) {
             if (get_bits_count(&s->gb) + num_fill_bits > s->num_saved_bits) {
-                av_log(s->avctx,AV_LOG_ERROR,"invalid number of fill bits\n");
+                av_log(s->avctx, AV_LOG_ERROR, "invalid number of fill bits\n");
                 return AVERROR_INVALIDDATA;
             }
 
@@ -1143,13 +1161,13 @@ static int decode_subframe(WMAProDecodeCtx *s)
             const int sign = (step == 31) - 1;
             int quant = 0;
             while (get_bits_count(&s->gb) + 5 < s->num_saved_bits &&
-                   (step = get_bits(&s->gb, 5)) == 31 ) {
-                     quant += 31;
+                   (step = get_bits(&s->gb, 5)) == 31) {
+                quant += 31;
             }
             quant_step += ((quant + step) ^ sign) - sign;
         }
         if (quant_step < 0) {
-            av_log(s->avctx,AV_LOG_DEBUG,"negative quant step\n");
+            av_log(s->avctx, AV_LOG_DEBUG, "negative quant step\n");
         }
 
         /** decode quantization step modifiers for every channel */
@@ -1163,8 +1181,7 @@ static int decode_subframe(WMAProDecodeCtx *s)
                 s->channel[c].quant_step = quant_step;
                 if (get_bits1(&s->gb)) {
                     if (modifier_len) {
-                        s->channel[c].quant_step +=
-                                get_bits(&s->gb, modifier_len) + 1;
+                        s->channel[c].quant_step += get_bits(&s->gb, modifier_len) + 1;
                     } else
                         ++s->channel[c].quant_step;
                 }
@@ -1177,21 +1194,21 @@ static int decode_subframe(WMAProDecodeCtx *s)
     }
 
     dprintf(s->avctx, "BITSTREAM: subframe header length was %i\n",
-           get_bits_count(&s->gb) - s->subframe_offset);
+            get_bits_count(&s->gb) - s->subframe_offset);
 
     /** parse coefficients */
     for (i = 0; i < s->channels_for_cur_subframe; i++) {
         int c = s->channel_indexes_for_cur_subframe[i];
         if (s->channel[c].transmit_coefs &&
-           get_bits_count(&s->gb) < s->num_saved_bits) {
-                decode_coeffs(s, c);
+            get_bits_count(&s->gb) < s->num_saved_bits) {
+            decode_coeffs(s, c);
         } else
             memset(s->channel[c].coeffs, 0,
                    sizeof(*s->channel[c].coeffs) * subframe_len);
     }
 
     dprintf(s->avctx, "BITSTREAM: subframe length was %i\n",
-           get_bits_count(&s->gb) - s->subframe_offset);
+            get_bits_count(&s->gb) - s->subframe_offset);
 
     if (transmit_coeffs) {
         /** reconstruct the per channel data */
@@ -1212,14 +1229,14 @@ static int decode_subframe(WMAProDecodeCtx *s)
                             (s->channel[c].max_scale_factor - *sf++) *
                             s->channel[c].scale_factor_step;
                 const float quant = pow(10.0, exp / 20.0);
-                int start;
-
-                for (start = s->cur_sfb_offsets[b]; start < end; start++)
-                    s->tmp[start] = s->channel[c].coeffs[start] * quant;
+                int start = s->cur_sfb_offsets[b];
+                s->dsp.vector_fmul_scalar(s->tmp + start,
+                                          s->channel[c].coeffs + start,
+                                          quant, end - start);
             }
 
             /** apply imdct (ff_imdct_half == DCTIV with reverse) */
-            ff_imdct_half(&s->mdct_ctx[av_log2(subframe_len)-BLOCK_MIN_BITS],
+            ff_imdct_half(&s->mdct_ctx[av_log2(subframe_len) - BLOCK_MIN_BITS],
                           s->channel[c].coeffs, s->tmp);
         }
     }
@@ -1231,7 +1248,7 @@ static int decode_subframe(WMAProDecodeCtx *s)
     for (i = 0; i < s->channels_for_cur_subframe; i++) {
         int c = s->channel_indexes_for_cur_subframe[i];
         if (s->channel[c].cur_subframe >= s->channel[c].num_subframes) {
-            av_log(s->avctx,AV_LOG_ERROR,"broken subframe\n");
+            av_log(s->avctx, AV_LOG_ERROR, "broken subframe\n");
             return AVERROR_INVALIDDATA;
         }
         ++s->channel[c].cur_subframe;
@@ -1255,7 +1272,8 @@ static int decode_frame(WMAProDecodeCtx *s)
 
     /** check for potential output buffer overflow */
     if (s->num_channels * s->samples_per_frame > s->samples_end - s->samples) {
-        av_log(s->avctx,AV_LOG_ERROR,
+        /** return an error if no frame could be decoded at all */
+        av_log(s->avctx, AV_LOG_ERROR,
                "not enough space for the output samples\n");
         s->packet_loss = 1;
         return 0;
@@ -1306,7 +1324,7 @@ static int decode_frame(WMAProDecodeCtx *s)
     }
 
     dprintf(s->avctx, "BITSTREAM: frame header length was %i\n",
-           get_bits_count(gb) - s->frame_offset);
+            get_bits_count(gb) - s->frame_offset);
 
     /** reset subframe states */
     s->parsed_all_subframes = 0;
@@ -1351,7 +1369,7 @@ static int decode_frame(WMAProDecodeCtx *s)
 
     if (len != (get_bits_count(gb) - s->frame_offset) + 2) {
         /** FIXME: not sure if this is always an error */
-        av_log(s->avctx,AV_LOG_ERROR,"frame[%i] would have to skip %i bits\n",
+        av_log(s->avctx, AV_LOG_ERROR, "frame[%i] would have to skip %i bits\n",
                s->frame_num, len - (get_bits_count(gb) - s->frame_offset) - 1);
         s->packet_loss = 1;
         return 0;
@@ -1373,7 +1391,7 @@ static int decode_frame(WMAProDecodeCtx *s)
  *@param gb bitstream reader context
  *@return remaining size in bits
  */
-static int remaining_bits(WMAProDecodeCtx *s, GetBitContext* gb)
+static int remaining_bits(WMAProDecodeCtx *s, GetBitContext *gb)
 {
     return s->buf_bit_size - get_bits_count(gb);
 }
@@ -1386,7 +1404,7 @@ static int remaining_bits(WMAProDecodeCtx *s, GetBitContext* gb)
  *@param append decides wether to reset the buffer or not
  */
 static void save_bits(WMAProDecodeCtx *s, GetBitContext* gb, int len,
-                          int append)
+                      int append)
 {
     int buflen;
 
@@ -1403,14 +1421,15 @@ static void save_bits(WMAProDecodeCtx *s, GetBitContext* gb, int len,
     buflen = (s->num_saved_bits + len + 8) >> 3;
 
     if (len <= 0 || buflen > MAX_FRAMESIZE) {
-         av_log_ask_for_sample(s->avctx, "input buffer too small\n");
-         s->packet_loss = 1;
-         return;
+        av_log_ask_for_sample(s->avctx, "input buffer too small\n");
+        s->packet_loss = 1;
+        return;
     }
 
     s->num_saved_bits += len;
     if (!append) {
-        ff_copy_bits(&s->pb, gb->buffer + (get_bits_count(gb) >> 3), s->num_saved_bits);
+        ff_copy_bits(&s->pb, gb->buffer + (get_bits_count(gb) >> 3),
+                     s->num_saved_bits);
     } else {
         int align = 8 - (get_bits_count(gb) & 7);
         align = FFMIN(align, len);
@@ -1421,8 +1440,8 @@ static void save_bits(WMAProDecodeCtx *s, GetBitContext* gb, int len,
     skip_bits_long(gb, len);
 
     {
-    PutBitContext tmp = s->pb;
-    flush_put_bits(&tmp);
+        PutBitContext tmp = s->pb;
+        flush_put_bits(&tmp);
     }
 
     init_get_bits(&s->gb, s->frame_data, s->num_saved_bits);
@@ -1438,92 +1457,90 @@ static void save_bits(WMAProDecodeCtx *s, GetBitContext* gb, int len,
  *@return number of bytes that were read from the input buffer
  */
 static int decode_packet(AVCodecContext *avctx,
-                             void *data, int *data_size, AVPacket* avpkt)
+                         void *data, int *data_size, AVPacket* avpkt)
 {
-    GetBitContext gb;
     WMAProDecodeCtx *s = avctx->priv_data;
-    const uint8_t* buf   = avpkt->data;
-    int buf_size         = avpkt->size;
-    int more_frames      = 1;
+    GetBitContext* gb  = &s->pgb;
+    const uint8_t* buf = avpkt->data;
+    int buf_size       = avpkt->size;
     int num_bits_prev_frame;
     int packet_sequence_number;
 
-    s->samples      = data;
-    s->samples_end  = (float*)((int8_t*)data + *data_size);
-    s->buf_bit_size = buf_size << 3;
-
-
+    s->samples       = data;
+    s->samples_end   = (float*)((int8_t*)data + *data_size);
     *data_size = 0;
 
-    /** sanity check for the buffer length */
-    if (buf_size < avctx->block_align)
-        return 0;
-
-    buf_size = avctx->block_align;
+    if (s->packet_done || s->packet_loss) {
+        s->packet_done = 0;
+        s->buf_bit_size = buf_size << 3;
 
-    /** parse packet header */
-    init_get_bits(&gb, buf, s->buf_bit_size);
-    packet_sequence_number = get_bits(&gb, 4);
-    skip_bits(&gb, 2);
+        /** sanity check for the buffer length */
+        if (buf_size < avctx->block_align)
+            return 0;
 
-    /** get number of bits that need to be added to the previous frame */
-    num_bits_prev_frame = get_bits(&gb, s->log2_frame_size);
-    dprintf(avctx, "packet[%d]: nbpf %x\n", avctx->frame_number,
-                  num_bits_prev_frame);
+        buf_size = avctx->block_align;
 
-    /** check for packet loss */
-    if (!s->packet_loss &&
-        ((s->packet_sequence_number + 1)&0xF) != packet_sequence_number) {
-        s->packet_loss = 1;
-        av_log(avctx, AV_LOG_ERROR, "Packet loss detected! seq %x vs %x\n",
-                      s->packet_sequence_number, packet_sequence_number);
-    }
-    s->packet_sequence_number = packet_sequence_number;
-
-    if (num_bits_prev_frame > 0) {
-        /** append the previous frame data to the remaining data from the
-            previous packet to create a full frame */
-        save_bits(s, &gb, num_bits_prev_frame, 1);
-        dprintf(avctx, "accumulated %x bits of frame data\n",
-                      s->num_saved_bits - s->frame_offset);
-
-        /** decode the cross packet frame if it is valid */
-        if (!s->packet_loss)
-            decode_frame(s);
-    } else if (s->num_saved_bits - s->frame_offset) {
-        dprintf(avctx, "ignoring %x previously saved bits\n",
-                      s->num_saved_bits - s->frame_offset);
-    }
+        /** parse packet header */
+        init_get_bits(gb, buf, s->buf_bit_size);
+        packet_sequence_number = get_bits(gb, 4);
+        skip_bits(gb, 2);
 
-    s->packet_loss = 0;
-    /** decode the rest of the packet */
-    while (!s->packet_loss && more_frames &&
-          remaining_bits(s, &gb) > s->log2_frame_size) {
-        int frame_size = show_bits(&gb, s->log2_frame_size);
+        /** get number of bits that need to be added to the previous frame */
+        num_bits_prev_frame = get_bits(gb, s->log2_frame_size);
+        dprintf(avctx, "packet[%d]: nbpf %x\n", avctx->frame_number,
+                num_bits_prev_frame);
 
-        /** there is enough data for a full frame */
-        if (remaining_bits(s,&gb) >= frame_size && frame_size > 0) {
-            save_bits(s, &gb, frame_size, 0);
+        /** check for packet loss */
+        if (!s->packet_loss &&
+            ((s->packet_sequence_number + 1) & 0xF) != packet_sequence_number) {
+            s->packet_loss = 1;
+            av_log(avctx, AV_LOG_ERROR, "Packet loss detected! seq %x vs %x\n",
+                   s->packet_sequence_number, packet_sequence_number);
+        }
+        s->packet_sequence_number = packet_sequence_number;
+
+        if (num_bits_prev_frame > 0) {
+            /** append the previous frame data to the remaining data from the
+                previous packet to create a full frame */
+            save_bits(s, gb, num_bits_prev_frame, 1);
+            dprintf(avctx, "accumulated %x bits of frame data\n",
+                    s->num_saved_bits - s->frame_offset);
+
+            /** decode the cross packet frame if it is valid */
+            if (!s->packet_loss)
+                decode_frame(s);
+        } else if (s->num_saved_bits - s->frame_offset) {
+            dprintf(avctx, "ignoring %x previously saved bits\n",
+                    s->num_saved_bits - s->frame_offset);
+        }
 
-            /** decode the frame */
-            more_frames = decode_frame(s);
+        s->packet_loss = 0;
 
-            if (!more_frames) {
-                dprintf(avctx, "no more frames\n");
-            }
+    } else {
+        int frame_size;
+        s->buf_bit_size = avpkt->size << 3;
+        init_get_bits(gb, avpkt->data, s->buf_bit_size);
+        skip_bits(gb, s->packet_offset);
+        if (remaining_bits(s, gb) > s->log2_frame_size &&
+            (frame_size = show_bits(gb, s->log2_frame_size)) &&
+            frame_size <= remaining_bits(s, gb)) {
+            save_bits(s, gb, frame_size, 0);
+            s->packet_done = !decode_frame(s);
         } else
-            more_frames = 0;
+            s->packet_done = 1;
     }
 
-    if (!s->packet_loss && remaining_bits(s,&gb) > 0) {
+    if (s->packet_done && !s->packet_loss &&
+        remaining_bits(s, gb) > 0) {
         /** save the rest of the data so that it can be decoded
             with the next packet */
-        save_bits(s, &gb, remaining_bits(s,&gb), 0);
+        save_bits(s, gb, remaining_bits(s, gb), 0);
     }
 
     *data_size = (int8_t *)s->samples - (int8_t *)data;
+    s->packet_offset = get_bits_count(gb) & 7;
 
-    return avctx->block_align;
+    return (s->packet_loss) ? AVERROR_INVALIDDATA : get_bits_count(gb) >> 3;
 }
 
 /**
@@ -1555,6 +1572,7 @@ AVCodec wmapro_decoder = {
     NULL,
     decode_end,
     decode_packet,
+    .capabilities = CODEC_CAP_SUBFRAMES,
     .flush= flush,
     .long_name = NULL_IF_CONFIG_SMALL("Windows Media Audio 9 Professional"),
 };