rtime.felk.cvut.cz Git - opencv.git/commitdiff
added grabCut
author mdim <mdim@73c94f0f-984f-4a5f-82bc-2d8db8d8ee08>
Tue, 17 Nov 2009 16:44:16 +0000 (16:44 +0000)
committer mdim <mdim@73c94f0f-984f-4a5f-82bc-2d8db8d8ee08>
Tue, 17 Nov 2009 16:44:16 +0000 (16:44 +0000)
git-svn-id: https://code.ros.org/svn/opencv/trunk@2295 73c94f0f-984f-4a5f-82bc-2d8db8d8ee08

opencv/include/opencv/cv.hpp
opencv/src/cv/cvgcgraph.hpp [new file with mode: 0644]
opencv/src/cv/cvgrabcut.cpp [new file with mode: 0644]

index 90a16c7049fbddb2dac8d41537b867013a0f7388..7aee44ed2f1d3cea090b6b7ffd90abc62ab19698 100644 (file)
@@ -432,6 +432,21 @@ CV_EXPORTS void equalizeHist( const Mat& src, Mat& dst );
 
 CV_EXPORTS void watershed( const Mat& image, Mat& markers );
 
+enum { GC_BGD    = 0,  // background
+GC_FGD    = 1,  // foreground
+GC_PR_BGD = 2,  // most probably background
+GC_PR_FGD = 3   // most probably foreground 
+};
+
+enum { GC_INIT_WITH_RECT  = 0,
+GC_INIT_WITH_MASK  = 1,
+GC_EVAL            = 2
+};
+
+CV_EXPORTS void grabCut( const Mat& img, Mat& mask, Rect rect, 
+                        Mat& bgdModel, Mat& fgdModel,
+                        int iterCount, int flag = GC_EVAL );
+
 enum { INPAINT_NS=CV_INPAINT_NS, INPAINT_TELEA=CV_INPAINT_TELEA };
 
 CV_EXPORTS void inpaint( const Mat& src, const Mat& inpaintMask,
@@ -1006,7 +1021,6 @@ struct CvLSHOperations
   virtual int hash_lookup(lsh_hash h, int l, int* ret_i, int ret_i_max) = 0;
 };
 
-
 #endif /* __cplusplus */
 
 #endif /* _CV_HPP_ */
diff --git a/opencv/src/cv/cvgcgraph.hpp b/opencv/src/cv/cvgcgraph.hpp
new file mode 100644 (file)
index 0000000..6cc9ef5
--- /dev/null
@@ -0,0 +1,333 @@
#pragma once
#include <cassert>
#include <climits>
#include <cstring>
#include <vector>
using namespace std;
+template <class TWeight> class GCGraph\r
+{\r
+public:\r
+    GCGraph();\r
+    GCGraph( unsigned int vtxCount, unsigned int edgeCount );\r
+    ~GCGraph();\r
+    int addVtx();\r
+    void addEdges( int i, int j, TWeight w, TWeight revw );\r
+    void addTermWeights( int i, TWeight sourceW, TWeight sinkW );\r
+    TWeight maxFlow();\r
+    bool inSourceSegment( int i );\r
+private:\r
+    class Vtx\r
+    {\r
+    public:\r
+        Vtx *next; // initialized and used in maxFlow() only
+        int parent;
+        int first;
+        int ts;
+        int dist;
+        TWeight weight;
+        uchar t; \r
+    };\r
+    class Edge\r
+    {\r
+    public:\r
+        int dst;
+        int next;
+        TWeight weight;\r
+    };\r
+\r
+    vector<Vtx> vtcs;\r
+    vector<Edge> edges;\r
+    TWeight flow;\r
+};\r
+\r
+template <class TWeight>\r
+GCGraph<TWeight>::GCGraph()\r
+{\r
+    flow = 0;\r
+}\r
+template <class TWeight>\r
+GCGraph<TWeight>::GCGraph( unsigned int vtxCount, unsigned int edgeCount )\r
+{\r
+    vtcs.reserve( vtxCount );\r
+    edges.reserve( edgeCount );\r
+    flow = 0;\r
+}\r
+template <class TWeight>\r
+GCGraph<TWeight>::~GCGraph()\r
+{\r
+}\r
+\r
+template <class TWeight>\r
+int GCGraph<TWeight>::addVtx()\r
+{\r
+    Vtx v;\r
+    memset( &v, 0, sizeof(Vtx));\r
+    vtcs.push_back(v);\r
+    return (int)vtcs.size() - 1;\r
+}\r
+\r
+template <class TWeight>\r
+void GCGraph<TWeight>::addEdges( int i, int j, TWeight w, TWeight revw )\r
+{\r
+    CV_Assert( i>=0 && i<(int)vtcs.size() );\r
+    CV_Assert( j>=0 && j<(int)vtcs.size() );\r
+    CV_Assert( w>=0 && revw>=0 );\r
+    CV_Assert( i != j );\r
+\r
+    Edge fromI, toI;\r
+    fromI.dst = j;\r
+    fromI.next = vtcs[i].first;\r
+    fromI.weight = w;\r
+    vtcs[i].first = (int)edges.size();\r
+    edges.push_back( fromI );\r
+\r
+    toI.dst = i;\r
+    toI.next = vtcs[j].first;\r
+    toI.weight = revw;\r
+    vtcs[j].first = (int)edges.size();\r
+    edges.push_back( toI );\r
+}\r
+\r
+template <class TWeight>\r
+void GCGraph<TWeight>::addTermWeights( int i, TWeight sourceW, TWeight sinkW )\r
+{\r
+    CV_Assert( i>=0 && i<(int)vtcs.size() );\r
+\r
+    TWeight dw = vtcs[i].weight;\r
+    if( dw > 0 )\r
+        sourceW += dw;\r
+    else\r
+        sinkW -= dw;\r
+    flow += (sourceW < sinkW) ? sourceW : sinkW;\r
+    vtcs[i].weight = sourceW - sinkW;\r
+}\r
+\r
+template <class TWeight>\r
+TWeight GCGraph<TWeight>::maxFlow()
+{
+    const int TERMINAL = -1, ORPHAN = -2;
+    Vtx stub, *nilNode = &stub, *first = nilNode, *last = nilNode;
+    int curr_ts = 0;
+    stub.next = nilNode;
+    Vtx *vtxPtr = &vtcs[0];
+    Edge *edgePtr = &edges[0];
+
+    vector<Vtx*> orphans;
+
+    // initialize the active queue and the graph vertices
+    for( int i = 0; i < (int)vtcs.size(); i++ )
+    {
+        Vtx* v = vtxPtr + i;
+        v->ts = 0;
+        if( v->weight != 0 )
+        {
+            last = last->next = v;
+            v->dist = 1;
+            v->parent = TERMINAL;
+            v->t = v->weight < 0;
+        }
+        else
+            v->parent = 0;        
+    }
+    first = first->next;
+    last->next = nilNode;
+    nilNode->next = 0;
+
+    // run the search-path -> augment-graph -> restore-trees loop
+    for(;;)
+    {
+        Vtx* v, *u;
+        int e0 = -1, ei = 0, ej = 0;
+        TWeight minWeight, weight;
+        uchar vt;
+
+        // grow S & T search trees, find an edge connecting them
+        while( first != nilNode )
+        {
+            v = first;
+            if( v->parent )
+            {
+                vt = v->t;
+                for( ei = v->first; ei != 0; ei = edgePtr[ei].next )
+                {
+                    if( edgePtr[ei^vt].weight == 0 )
+                        continue;
+                    int aN = (int)edges.size();
+                    u = vtxPtr+edgePtr[ei].dst;
+                    if( !u->parent )
+                    {
+                        u->t = vt;
+                        u->parent = ei ^ 1;
+                        u->ts = v->ts;
+                        u->dist = v->dist + 1;
+                        if( !u->next )
+                        {
+                            u->next = nilNode;
+                            last = last->next = u;
+                        }
+                        continue;
+                    }
+
+                    if( u->t != vt )
+                    {
+                        e0 = ei ^ vt;
+                        break;
+                    }
+
+                    if( u->dist > v->dist+1 && u->ts <= v->ts )
+                    {
+                        // reassign the parent
+                        u->parent = ei ^ 1;
+                        u->ts = v->ts;
+                        u->dist = v->dist + 1;
+                    }
+                }
+                if( e0 > 0 )
+                    break;
+            }
+            // exclude the vertex from the active list
+            first = first->next;
+            v->next = 0;
+        }
+
+        if( e0 <= 0 )
+            break;
+
+        // find the minimum edge weight along the path
+        minWeight = edgePtr[e0].weight;
+        assert( minWeight > 0 );
+        // k = 1: source tree, k = 0: destination tree
+        for( int k = 1; k >= 0; k-- )
+        {
+            for( v = vtxPtr+edgePtr[e0^k].dst;; v = vtxPtr+edgePtr[ei].dst )
+            {
+                if( (ei = v->parent) < 0 )
+                    break;
+                weight = edgePtr[ei^k].weight;
+                minWeight = MIN(minWeight, weight);
+                assert( minWeight > 0 );
+            }
+            weight = fabs(v->weight);
+            minWeight = MIN(minWeight, weight);
+            assert( minWeight > 0 );
+        }
+
+        // modify weights of the edges along the path and collect orphans
+        edgePtr[e0].weight -= minWeight;
+        edgePtr[e0^1].weight += minWeight;
+        flow += minWeight;
+
+        // k = 1: source tree, k = 0: destination tree
+        for( int k = 1; k >= 0; k-- )
+        {
+            for( v = vtxPtr+edgePtr[e0^k].dst;; v = vtxPtr+edgePtr[ei].dst )
+            {
+                if( (ei = v->parent) < 0 )
+                    break;
+                edgePtr[ei^(k^1)].weight += minWeight;
+                if( (edgePtr[ei^k].weight -= minWeight) == 0 )
+                {
+                    orphans.push_back(v);
+                    v->parent = ORPHAN;
+                }
+            }
+
+            v->weight = v->weight + minWeight*(1-k*2);
+            if( v->weight == 0 )
+            {
+               orphans.push_back(v);
+               v->parent = ORPHAN;
+            }
+        }
+
+        // restore the search trees by finding new parents for the orphans
+        curr_ts++;
+        while( !orphans.empty() )
+        {
+            Vtx* v = orphans.back();
+            orphans.pop_back();
+
+            int d, minDist = INT_MAX;
+            e0 = 0;
+            vt = v->t;
+
+            for( ei = v->first; ei != 0; ei = edgePtr[ei].next )
+            {
+                if( edgePtr[ei^(vt^1)].weight == 0 )
+                    continue;
+                u = vtxPtr+edgePtr[ei].dst;
+                if( u->t != vt || u->parent == 0 )
+                    continue;
+                // compute the distance to the tree root
+                for( d = 0;; )
+                {
+                    if( u->ts == curr_ts )
+                    {
+                        d += u->dist;
+                        break;
+                    }
+                    ej = u->parent;
+                    d++;
+                    if( ej < 0 )
+                    {
+                        if( ej == ORPHAN )
+                            d = INT_MAX-1;
+                        else
+                        {
+                            u->ts = curr_ts;
+                            u->dist = 1;
+                        }
+                        break;
+                    }
+                    u = vtxPtr+edgePtr[ej].dst;
+                }
+
+                // update the distance
+                if( ++d < INT_MAX )
+                {
+                    if( d < minDist )
+                    {
+                        minDist = d;
+                        e0 = ei;
+                    }
+                    for( u = vtxPtr+edgePtr[ei].dst; u->ts != curr_ts; u = vtxPtr+edgePtr[u->parent].dst )
+                    {
+                        u->ts = curr_ts;
+                        u->dist = --d;
+                    }
+                }
+            }
+
+            if( (v->parent = e0) > 0 )
+            {
+                v->ts = curr_ts;
+                v->dist = minDist;
+                continue;
+            }
+
+            /* no parent is found */
+            v->ts = 0;
+            for( ei = v->first; ei != 0; ei = edgePtr[ei].next )
+            {
+                u = vtxPtr+edgePtr[ei].dst;
+                ej = u->parent;
+                if( u->t != vt || !ej )
+                    continue;
+                if( edgePtr[ei^(vt^1)].weight && !u->next )
+                {
+                    u->next = nilNode;
+                    last = last->next = u;
+                }
+                if( ej > 0 && vtxPtr+edgePtr[ej].dst == v )
+                {
+                    orphans.push_back(u);
+                    u->parent = ORPHAN;
+                }
+            }
+        }
+    }
+    return flow;
+}\r
+
+template <class TWeight>
+bool GCGraph<TWeight>::inSourceSegment( int i )
+{
+    CV_Assert( i>=0 && i<(int)vtcs.size() );
+    return vtcs[i].t == 0;
+};
\ No newline at end of file
diff --git a/opencv/src/cv/cvgrabcut.cpp b/opencv/src/cv/cvgrabcut.cpp
new file mode 100644 (file)
index 0000000..db93845
--- /dev/null
@@ -0,0 +1,502 @@
+/*M///////////////////////////////////////////////////////////////////////////////////////\r
+//\r
+//  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.\r
+//\r
+//  By downloading, copying, installing or using the software you agree to this license.\r
+//  If you do not agree to this license, do not download, install,\r
+//  copy or use the software.\r
+//\r
+//\r
+//                        Intel License Agreement\r
+//                For Open Source Computer Vision Library\r
+//\r
+// Copyright (C) 2000, Intel Corporation, all rights reserved.\r
+// Third party copyrights are property of their respective owners.\r
+//\r
+// Redistribution and use in source and binary forms, with or without modification,\r
+// are permitted provided that the following conditions are met:\r
+//\r
+//   * Redistribution's of source code must retain the above copyright notice,\r
+//     this list of conditions and the following disclaimer.\r
+//\r
+//   * Redistribution's in binary form must reproduce the above copyright notice,\r
+//     this list of conditions and the following disclaimer in the documentation\r
+//     and/or other materials provided with the distribution.\r
+//\r
+//   * The name of Intel Corporation may not be used to endorse or promote products\r
+//     derived from this software without specific prior written permission.\r
+//\r
+// This software is provided by the copyright holders and contributors "as is" and\r
+// any express or implied warranties, including, but not limited to, the implied\r
+// warranties of merchantability and fitness for a particular purpose are disclaimed.\r
+// In no event shall the Intel Corporation or contributors be liable for any direct,\r
+// indirect, incidental, special, exemplary, or consequential damages\r
+// (including, but not limited to, procurement of substitute goods or services;\r
+// loss of use, data, or profits; or business interruption) however caused\r
+// and on any theory of liability, whether in contract, strict liability,\r
+// or tort (including negligence or otherwise) arising in any way out of\r
+// the use of this software, even if advised of the possibility of such damage.\r
+//\r
+//M*/\r
+\r
+#include "_cv.h"\r
+#include "cvgcgraph.hpp"\r
+\r
+using namespace cv;\r
+inline Vec3f cvtVec( const Vec3b& v3b )\r
+{\r
+    Vec3f v3f;\r
+    v3f[0] = v3b[0]; v3f[1] = v3b[1]; v3f[2] = v3b[2];\r
+    return v3f;\r
+}\r
+\r
+class GMM\r
+{\r
+public:\r
+    static const uchar K = 5;\r
+\r
+    GMM( Mat& model );\r
+    float operator()( Vec3b color ) const;\r
+    float operator()( uchar ci, Vec3b color ) const;\r
+    uchar whatComponent( Vec3b color ) const;\r
+\r
+    void startLearning();\r
+    void addSample( uchar ci, Vec3b color );\r
+    void endLearning();\r
+private:\r
+    float* coefs;\r
+    float* mean;\r
+    float* cov;\r
+\r
+    float inverseCov[K][3][3];\r
+    float covDeterm[K];\r
+\r
+    float sum[3*K];\r
+    float prod[9*K];\r
+    int count[K];\r
+    int totalCount;\r
+};\r
+\r
+GMM::GMM( Mat& model )\r
+{\r
+    if( model.empty() )\r
+    {\r
+        model.create( 1, 13*K, CV_32FC1 );\r
+        model.setTo(Scalar(0));\r
+    }\r
+    else if( (model.type() != CV_32FC1) || (model.rows != 1) || (model.cols != 13*K) )\r
+        CV_Error( CV_StsBadArg, "model must have CV_32FC1 type, rows == 1 and cols == 13*K" );\r
+    coefs = model.ptr<float>(0);\r
+    mean = coefs + K;\r
+    cov = mean + 3*K;\r
+\r
+    for( uchar ci = 0; ci < K; ci++ )\r
+    {\r
+        if( coefs[ci] > 0 )\r
+        {\r
+            float *c = cov + 9*ci;\r
+            float dtrm = covDeterm[ci] = c[0]*(c[4]*c[8] - c[5]*c[7])\r
+                - c[1]*(c[3]*c[8] - c[5]*c[6]) + c[2]*(c[3]*c[7] - c[4]*c[6]);\r
+\r
+            if( dtrm > FLT_EPSILON )\r
+            {\r
+                inverseCov[ci][0][0] =  (c[4]*c[8] - c[5]*c[7]) / dtrm;\r
+                inverseCov[ci][1][0] = -(c[3]*c[8] - c[5]*c[6]) / dtrm;\r
+                inverseCov[ci][2][0] =  (c[3]*c[7] - c[4]*c[6]) / dtrm;\r
+                inverseCov[ci][0][1] = -(c[1]*c[8] - c[2]*c[7]) / dtrm;\r
+                inverseCov[ci][1][1] =  (c[0]*c[8] - c[2]*c[6]) / dtrm;\r
+                inverseCov[ci][2][1] = -(c[0]*c[7] - c[1]*c[6]) / dtrm;\r
+                inverseCov[ci][0][2] =  (c[1]*c[5] - c[2]*c[4]) / dtrm;\r
+                inverseCov[ci][1][2] = -(c[0]*c[5] - c[2]*c[3]) / dtrm;\r
+                inverseCov[ci][2][2] =  (c[0]*c[4] - c[1]*c[3]) / dtrm;\r
+            }\r
+\r
+        }\r
+    }\r
+}\r
+\r
+float GMM::operator()( Vec3b color ) const\r
+{\r
+    float res = 0;\r
+    for( uchar ki = 0; ki < GMM::K; ki++ )\r
+        res += coefs[ki] * this->operator()(ki, color );\r
+    return res;\r
+}\r
+\r
+float GMM::operator()( uchar ci, Vec3b color ) const\r
+{\r
+    float res = 0;\r
+    if( coefs[ci] > 0 )\r
+    {\r
+        if( covDeterm[ci] > FLT_EPSILON )\r
+        {\r
+            Vec3f diff = cvtVec(color);\r
+            float* m = mean + 3*ci;\r
+            diff[0] -= m[0]; diff[1] -= m[1]; diff[2] -= m[2];\r
+            float mult = diff[0]*(diff[0]*inverseCov[ci][0][0] + diff[1]*inverseCov[ci][1][0] + diff[2]*inverseCov[ci][2][0])\r
+                + diff[1]*(diff[0]*inverseCov[ci][0][1] + diff[1]*inverseCov[ci][1][1] + diff[2]*inverseCov[ci][2][1])\r
+                + diff[2]*(diff[0]*inverseCov[ci][0][2] + diff[1]*inverseCov[ci][1][2] + diff[2]*inverseCov[ci][2][2]);\r
+            res = (1.0f/(sqrt(covDeterm[ci])) * exp(-0.5f*mult));\r
+        }\r
+    }\r
+    return res;\r
+}\r
+\r
+uchar GMM::whatComponent( Vec3b color ) const\r
+{\r
+    uchar k = 0;\r
+    float max = 0;\r
+\r
+    for( uchar i = 0; i < K; i++ )\r
+    {\r
+        float p = this->operator()( i, color );\r
+        if( p > max )\r
+        {\r
+            k = i;\r
+            max = p;\r
+        }\r
+    }\r
+    return k;\r
+}\r
+\r
+void GMM::startLearning()\r
+{\r
+    memset( &sum, 0, 3*K*sizeof(sum[0]) );\r
+    memset( &prod, 0, 9*K*sizeof(prod[0]) );\r
+    memset( &count, 0, K*sizeof(count[0]) );\r
+    totalCount = 0;\r
+}\r
+\r
+void GMM::addSample( uchar ci, Vec3b color )\r
+{\r
+    float* s = sum + 3*ci;\r
+    s[0] += color[0], s[1] += color[1], s[2] += color[2];\r
+    float* p = prod + 9*ci;\r
+    p[0] += color[0]*color[0], p[1] += color[0]*color[1], p[2] += color[0]*color[2];\r
+    p[3] += color[1]*color[0], p[4] += color[1]*color[1], p[5] += color[1]*color[2];\r
+    p[6] += color[2]*color[0], p[7] += color[2]*color[1], p[8] += color[2]*color[2];\r
+    count[ci]++;\r
+    totalCount++;\r
+}\r
+\r
+void GMM::endLearning()\r
+{\r
+    for( uchar ci = 0; ci < K; ci++ )\r
+    {\r
+        if( count[ci] == 0 )\r
+        {\r
+            coefs[ci] = 0;\r
+        }\r
+        else\r
+        {\r
+            int n = count[ci];\r
+            coefs[ci] = (float)count[ci]/totalCount;\r
+            float* m = mean + 3*ci, *s = sum + 3*ci;\r
+            m[0] = s[0]/n; m[1] = s[1]/n; m[2] = s[2]/n;\r
+            \r
+            float* c = cov + 9*ci, *p = prod + 9*ci;\r
+            c[0] = p[0]/n - m[0]*m[0], c[1] = p[1]/n - m[0]*m[1], c[2] = p[2]/n - m[0]*m[2];\r
+            c[3] = p[3]/n - m[1]*m[0], c[4] = p[4]/n - m[1]*m[1], c[5] = p[5]/n - m[1]*m[2];\r
+            c[6] = p[6]/n - m[2]*m[0], c[7] = p[7]/n - m[2]*m[1], c[8] = p[8]/n - m[2]*m[2];\r
+\r
+            float dtrm = covDeterm[ci] = c[0]*(c[4]*c[8] - c[5]*c[7])\r
+                - c[1]*(c[3]*c[8] - c[5]*c[6]) + c[2]*(c[3]*c[7] - c[4]*c[6]);\r
+\r
+            if( dtrm > FLT_EPSILON )\r
+            {\r
+                inverseCov[ci][0][0] =  (c[4]*c[8] - c[5]*c[7]) / dtrm;\r
+                inverseCov[ci][1][0] = -(c[3]*c[8] - c[5]*c[6]) / dtrm;\r
+                inverseCov[ci][2][0] =  (c[3]*c[7] - c[4]*c[6]) / dtrm;\r
+                inverseCov[ci][0][1] = -(c[1]*c[8] - c[2]*c[7]) / dtrm;\r
+                inverseCov[ci][1][1] =  (c[0]*c[8] - c[2]*c[6]) / dtrm;\r
+                inverseCov[ci][2][1] = -(c[0]*c[7] - c[1]*c[6]) / dtrm;\r
+                inverseCov[ci][0][2] =  (c[1]*c[5] - c[2]*c[4]) / dtrm;\r
+                inverseCov[ci][1][2] = -(c[0]*c[5] - c[2]*c[3]) / dtrm;\r
+                inverseCov[ci][2][2] =  (c[0]*c[4] - c[1]*c[3]) / dtrm;\r
+            }\r
+        }\r
+    }\r
+}\r
+\r
+float calcBeta( const Mat& img )\r
+{\r
+    float beta = 0;\r
+    Point p;\r
+    for( p.y = 0; p.y < img.rows; p.y++ )\r
+    {\r
+        for( p.x = 0; p.x < img.cols; p.x++ )\r
+        {\r
+            Vec3f color = cvtVec(img.at<Vec3b>(p));\r
+            if( p.x>0 ) // left\r
+            {\r
+                Vec3f diff = color - cvtVec(img.at<Vec3b>(p.y, p.x-1));\r
+                beta += diff.dot(diff);\r
+            }\r
+            if( p.y>0 && p.x>0 ) // upleft\r
+            {\r
+                Vec3f diff = color - cvtVec(img.at<Vec3b>(p.y-1, p.x-1));\r
+                beta += diff.dot(diff);\r
+            }\r
+            if( p.y>0 ) // up\r
+            {\r
+                Vec3f diff = color - cvtVec(img.at<Vec3b>(p.y-1, p.x));\r
+                beta += diff.dot(diff);\r
+            }\r
+            if( p.y>0 && p.x<img.cols-1) // upright\r
+            {\r
+                Vec3f diff = color - cvtVec(img.at<Vec3b>(p.y-1, p.x+1));\r
+                beta += diff.dot(diff);\r
+            }\r
+        }\r
+    }\r
+    beta = 0.5f*(4*img.cols*img.rows - 3*img.cols - 3*img.rows + 2)/beta;\r
+    return beta;\r
+}\r
+\r
+void calcNWeights( const Mat& img, Mat& left, Mat& upleft, Mat& up, Mat& upright, float beta, float gamma )\r
+{\r
+    const float sqrt2 = sqrt(2.0f);\r
+    left.create( img.rows, img.cols, CV_32FC1 );\r
+    upleft.create( img.rows, img.cols, CV_32FC1 );\r
+    up.create( img.rows, img.cols, CV_32FC1 );\r
+    upright.create( img.rows, img.cols, CV_32FC1 );\r
+    Point p, p2;\r
+    Vec3b c, diff;\r
+    for( p.y = 0; p.y < img.rows; p.y++ )\r
+    {\r
+        for( p.x = 0; p.x < img.cols; p.x++ )\r
+        {\r
+            c = img.at<Vec3b>(p);\r
+\r
+            p2.y = p.y; // left            \r
+            p2.x = p.x-1;\r
+            if( p2.x>=0 )\r
+            {\r
+                diff = c - img.at<Vec3b>(p2);\r
+                left.at<float>(p) = gamma * exp(-beta*diff.dot(diff));\r
+            }\r
+            else\r
+                left.at<float>(p) = 0;\r
+\r
+            p2.y = p.y-1; // upleft\r
+            p2.x = p.x-1;\r
+            if( p2.x>=0 && p2.y>=0 )\r
+            {\r
+                diff = c - img.at<Vec3b>(p2);\r
+                upleft.at<float>(p) = gamma * exp(-beta*diff.dot(diff)) / sqrt2;\r
+            }\r
+            else\r
+                upleft.at<float>(p) = 0;\r
+\r
+            p2.y = p.y-1; // up\r
+            p2.x = p.x;\r
+            if( p2.y>=0 )\r
+            {\r
+                diff = c - img.at<Vec3b>(p2);\r
+                float res = gamma * exp(-beta*diff.dot(diff));\r
+                up.at<float>(p) = gamma * exp(-beta*diff.dot(diff));\r
+            }\r
+            else\r
+                up.at<float>(p) = 0;\r
+\r
+            p2.y = p.y-1; // upright\r
+            p2.x = p.x+1;\r
+            if( p2.x<img.cols-1 && p2.y>=0 )\r
+            {\r
+                diff = c - img.at<Vec3b>(p2);\r
+                upright.at<float>(p) = gamma * exp(-beta*diff.dot(diff)) / sqrt2;\r
+            }\r
+            else\r
+                upright.at<float>(p) = 0;\r
+        }\r
+    }\r
+}\r
+\r
+void grabCut( const Mat& img, Mat& mask, Rect rect, \r
+             Mat& bgdModel, Mat& fgdModel,\r
+             int iterCount, int flag )\r
+{\r
+    if( img.empty() )\r
+        CV_Error( CV_StsBadArg, "image is empty" );\r
+    if( img.type() != CV_8UC3 )\r
+        CV_Error( CV_StsBadArg, "image mush have CV_8UC3 type" );\r
+\r
+    const int KMI = 10;\r
+    const int KMT = KMEANS_PP_CENTERS;\r
+    const float gamma = 50;\r
+    const float lambda = 9*gamma;\r
+    const float beta = calcBeta( img );\r
+\r
+    Mat left, upleft, up, upright;\r
+    calcNWeights( img, left, upleft, up, upright, beta, gamma );\r
+\r
+    GMM bgdGMM( bgdModel ), fgdGMM( fgdModel );\r
+    Mat cidx( img.size(), CV_8UC1 );\r
+\r
+    Point p;\r
+    if( flag == GC_INIT_WITH_RECT || flag == GC_INIT_WITH_MASK )\r
+    {\r
+        if( flag == GC_INIT_WITH_RECT )\r
+        {\r
+            mask.create( img.size(), CV_8UC1 );\r
+            mask.setTo( GC_BGD );\r
+\r
+            rect.x = max(0, rect.x);\r
+            rect.y = max(0, rect.y);\r
+            rect.width = min(rect.width, img.cols-rect.x);\r
+            rect.height = min(rect.height, img.rows-rect.y);\r
+\r
+            Mat maskRect = mask( Range(rect.y, rect.y + rect.height), Range(rect.x, rect.x + rect.width) );\r
+            maskRect.setTo( Scalar(GC_PR_FGD) );\r
+        }\r
+        else // flag == GC_INIT_WITH_MASK \r
+        {\r
+            if( mask.empty() )\r
+                CV_Error( CV_StsBadArg, "mask is empty" );\r
+            if( mask.type() != CV_8UC1 )\r
+                CV_Error( CV_StsBadArg, "mask mush have CV_8UC1 type" );\r
+            if( mask.cols != img.cols || mask.rows != img.rows )\r
+                CV_Error( CV_StsBadArg, "mask mush have rows and cols as img" );\r
+            // TODO check mask elements? ( GC_BGD and GC_FGD only )\r
+        }\r
+\r
+        // init GMMs\r
+        Mat bgdLabels, fgdLabels;\r
+        vector<Vec3f> bgdSamples, fgdSamples;\r
+        for( p.y = 0; p.y < img.rows; p.y++ )\r
+        {\r
+            for( p.x = 0; p.x < img.cols; p.x++ )\r
+            {\r
+                if( mask.at<uchar>(p) == GC_BGD )\r
+                    bgdSamples.push_back( cvtVec(img.at<Vec3b>(p)) );\r
+                else // GC_PR_BGD | GC_FGD | GC_PR_FGD\r
+                    fgdSamples.push_back( cvtVec(img.at<Vec3b>(p)) );\r
+            }\r
+        }\r
+        CV_Assert( !bgdSamples.empty() && !fgdSamples.empty() );\r
+        Mat _bgdSamples( (int)bgdSamples.size(), 3, CV_32FC1, &bgdSamples[0][0] );\r
+        kmeans( _bgdSamples, GMM::K, bgdLabels, TermCriteria( CV_TERMCRIT_ITER, KMI, 0.0), 0, KMT, 0 );\r
+        Mat _fgdSamples( (int)fgdSamples.size(), 3, CV_32FC1, &fgdSamples[0][0] );\r
+        kmeans( _fgdSamples, GMM::K, fgdLabels, TermCriteria( CV_TERMCRIT_ITER, KMI, 0.0), 0, KMT, 0 );\r
+\r
+        bgdGMM.startLearning();\r
+        for( int i = 0; i < (int)bgdSamples.size(); i++ )\r
+            bgdGMM.addSample( bgdLabels.at<int>(i,0), bgdSamples[i] );\r
+        bgdGMM.endLearning();\r
+\r
+        fgdGMM.startLearning();\r
+        for( int i = 0; i < (int)fgdSamples.size(); i++ )\r
+            fgdGMM.addSample( fgdLabels.at<int>(i,0), fgdSamples[i] );\r
+        fgdGMM.endLearning();\r
+    }\r
+    \r
+    // TODO check mask\r
+\r
+    for( int i = 0; i < iterCount; i++ )\r
+    {\r
+        // assign GMMs components\r
+        for( p.y = 0; p.y < img.rows; p.y++ )\r
+        {\r
+            for( p.x = 0; p.x < img.cols; p.x++ )\r
+            {\r
+                Vec3b color = img.at<Vec3b>(p);\r
+                bool b = mask.at<uchar>(p) == GC_BGD || mask.at<uchar>(p) == GC_PR_BGD;\r
+                cidx.at<uchar>(p) = mask.at<uchar>(p) == GC_BGD || mask.at<uchar>(p) == GC_PR_BGD ?\r
+                    bgdGMM.whatComponent(color) : fgdGMM.whatComponent(color);\r
+            }\r
+        }\r
+\r
+        // learn GMMs parameters\r
+        bgdGMM.startLearning();\r
+        fgdGMM.startLearning();\r
+        for( uchar ci = 0; ci < GMM::K; ci++ )\r
+        {\r
+            for( p.y = 0; p.y < img.rows; p.y++ )\r
+            {\r
+                for( p.x = 0; p.x < img.cols; p.x++ )\r
+                {\r
+                    uchar c = cidx.at<uchar>(p);\r
+                    if( cidx.at<uchar>(p) == ci )\r
+                    {\r
+                        if( mask.at<uchar>(p) == GC_BGD || mask.at<uchar>(p) == GC_PR_BGD )    \r
+                            bgdGMM.addSample( ci, img.at<Vec3b>(p) );\r
+                        else\r
+                            fgdGMM.addSample( ci, img.at<Vec3b>(p) );\r
+                    }\r
+                }\r
+            }\r
+        }\r
+        bgdGMM.endLearning();\r
+        fgdGMM.endLearning();\r
+\r
+        // estimate segmentation\r
+        GCGraph<float> graph( img.cols*img.rows, 8*img.cols*img.rows - 6*(img.cols + img.rows) + 4 );\r
+        for( p.y = 0; p.y < img.rows; p.y++ )\r
+        {\r
+            for( p.x = 0; p.x < img.cols; p.x++)\r
+            {\r
+                // add node\r
+                graph.addVtx();\r
+                int idx = p.y*img.cols+p.x;\r
+                Vec3b color = img.at<Vec3b>(p);\r
+\r
+                // set t-weights\r
+                float fromSource, toSink;\r
+                if( mask.at<uchar>(p) == GC_PR_BGD || mask.at<uchar>(p) == GC_PR_FGD )\r
+                {\r
+                    fromSource = -log( bgdGMM(color) );\r
+                    toSink = -log( fgdGMM(color) );\r
+                }\r
+                else if( mask.at<uchar>(p) == GC_BGD )\r
+                {\r
+                    fromSource = 0;\r
+                    toSink = lambda;\r
+                }\r
+                else // GC_BGD\r
+                {\r
+                    fromSource = lambda;\r
+                    toSink = 0;\r
+                }\r
+                graph.addTermWeights( idx, fromSource, toSink );\r
+\r
+                // 3. set n-weights\r
+                if( p.x>0 )\r
+                {\r
+                    float w = left.at<float>(p);\r
+                    graph.addEdges( idx, idx-1, w, w );\r
+                }\r
+                if( p.x>0 && p.y>0 )\r
+                {\r
+                    float w = upleft.at<float>(p);\r
+                    graph.addEdges( idx, idx-img.cols-1, w, w );\r
+                }\r
+                if( p.y>0 )\r
+                {\r
+                    float w = up.at<float>(p);\r
+                    graph.addEdges( idx, idx-img.cols, w, w );\r
+                }\r
+                if( p.x<img.cols-1 && p.y>0 )\r
+                {\r
+                    float w = upright.at<float>(p);\r
+                    graph.addEdges( idx, idx-img.cols+1, w, w );\r
+                }\r
+            }\r
+        }\r
+\r
+        graph.maxFlow();\r
+\r
+        for( p.y = 0; p.y < img.rows; p.y++ )\r
+        {\r
+            for( p.x = 0; p.x < img.cols; p.x++ )\r
+            {\r
+                if( mask.at<uchar>(p) == GC_PR_BGD || mask.at<uchar>(p) == GC_PR_FGD )\r
+                {\r
+                    if( graph.inSourceSegment( p.y*img.cols+p.x) )\r
+                        mask.at<uchar>(p) = GC_PR_FGD;\r
+                    else\r
+                        mask.at<uchar>(p) = GC_PR_BGD;\r
+                }\r
+            }\r
+        }\r
+    }\r
+}
\ No newline at end of file