add the CPU backend

This commit is contained in:
Ondrej Jamriska
2018-09-14 14:01:45 +02:00
parent 7742e68bd1
commit 93beddca07
21 changed files with 4195 additions and 2894 deletions

View File

@@ -28,6 +28,7 @@ ebsynth -style <style.png> -guide <source.png> <target.png> -output <output.png>
-pyramidlevels <number> -pyramidlevels <number>
-searchvoteiters <number> -searchvoteiters <number>
-patchmatchiters <number> -patchmatchiters <number>
-backend [cpu|cuda]
``` ```
## Download ## Download
@@ -129,10 +130,6 @@ equalized to match the luminance of the source painting.
-------------------------------------------------------------------------- --------------------------------------------------------------------------
## Requirements
`ebsynth` needs a CUDA-capable gpu in order to run. Besides CUDA, there are no other external dependencies. A cpu-only version that doesn't require CUDA will be released later.
## License ## License
The code is released into the public domain. You can do anything you want with it. The code is released into the public domain. You can do anything you want with it.

2
build-linux-cpu+cuda.sh Executable file
View File

@@ -0,0 +1,2 @@
#!/bin/sh
nvcc -arch compute_30 src/ebsynth.cpp src/ebsynth_cpu.cpp src/ebsynth_cuda.cu -I"include" -DNDEBUG -D__CORRECT_ISO_CPP11_MATH_H_PROTO -O6 -std=c++11 -w -Xcompiler -fopenmp -o bin/ebsynth

2
build-linux-cpu_only.sh Executable file
View File

@@ -0,0 +1,2 @@
#!/bin/sh
g++ src/ebsynth.cpp src/ebsynth_cpu.cpp src/ebsynth_nocuda.cpp -DNDEBUG -O6 -fopenmp -I"include" -std=c++11 -o bin/ebsynth

View File

@@ -1,2 +0,0 @@
#!/bin/sh
nvcc -arch compute_30 src/ebsynth.cu -o bin/ebsynth -I "include" -std=c++11 -Xcompiler "-DNDEBUG -O6 -D__CORRECT_ISO_CPP11_MATH_H_PROTO"

14
build-win32-cpu+cuda.bat Normal file
View File

@@ -0,0 +1,14 @@
@echo off
setlocal ENABLEDELAYEDEXPANSION
for %%V in (15,14,12,11) do if exist "!VS%%V0COMNTOOLS!" call "!VS%%V0COMNTOOLS!..\..\VC\vcvarsall.bat" x86 && goto compile
:compile
nvcc -m32 -arch compute_30 src\ebsynth.cpp src\ebsynth_cpu.cpp src\ebsynth_cuda.cu -DNDEBUG -O6 -I "include" -o "bin\ebsynth.exe" -Xcompiler "/openmp /fp:fast" -Xlinker "/IMPLIB:dummy.lib" -w || goto error
nvcc -m32 -arch compute_30 src\ebsynth.cpp src\ebsynth_cpu.cpp src\ebsynth_cuda.cu -DNDEBUG -O6 -I "include" -o "bin\ebsynth.dll" -Xcompiler "/openmp /fp:fast" -Xlinker "/IMPLIB:lib\ebsynth.lib" -shared -DEBSYNTH_API=__declspec(dllexport) -w || goto error
del dummy.lib;dummy.exp 2> NUL
goto :EOF
:error
echo FAILED
@%COMSPEC% /C exit 1 >nul

14
build-win32-cpu_only.bat Normal file
View File

@@ -0,0 +1,14 @@
@echo off
setlocal ENABLEDELAYEDEXPANSION
for %%V in (15,14,12,11) do if exist "!VS%%V0COMNTOOLS!" call "!VS%%V0COMNTOOLS!..\..\VC\vcvarsall.bat" x86 && goto compile
:compile
cl src\ebsynth.cpp src\ebsynth_cpu.cpp src\ebsynth_nocuda.cpp /DNDEBUG /O2 /openmp /EHsc /nologo /I"include" /Fe"bin\ebsynth.exe" || goto error
cl src\ebsynth.cpp src\ebsynth_cpu.cpp src\ebsynth_nocuda.cpp /DNDEBUG /O2 /openmp /EHsc /nologo /I"include" /Fe"bin\ebsynth.dll" /DEBSYNTH_API="__declspec(dllexport)" /link /IMPLIB:"lib\ebsynth.lib" || goto error
del ebsynth.obj;ebsynth_cpu.obj;ebsynth_nocuda.obj 2> NUL
goto :EOF
:error
echo FAILED
@%COMSPEC% /C exit 1 >nul

View File

@@ -1,12 +0,0 @@
@echo off
setlocal ENABLEDELAYEDEXPANSION
for %%V in (15,14,12,11) do if exist "!VS%%V0COMNTOOLS!" call "!VS%%V0COMNTOOLS!..\..\VC\vcvarsall.bat" x86 && goto compile
:compile
nvcc -arch compute_30 src\ebsynth.cu -m32 -O6 -w -I "include" -o "bin\ebsynth.exe" -Xcompiler "/DNDEBUG /Ox /Oy /Gy /Oi /fp:fast" -Xlinker "/IMPLIB:\"lib\ebsynth.lib\"" || goto error
goto :EOF
:error
echo FAILED
@%COMSPEC% /C exit 1 >nul

14
build-win64-cpu+cuda.bat Normal file
View File

@@ -0,0 +1,14 @@
@echo off
setlocal ENABLEDELAYEDEXPANSION
for %%V in (15,14,12,11) do if exist "!VS%%V0COMNTOOLS!" call "!VS%%V0COMNTOOLS!..\..\VC\vcvarsall.bat" amd64 && goto compile
:compile
nvcc -arch compute_30 src\ebsynth.cpp src\ebsynth_cpu.cpp src\ebsynth_cuda.cu -DNDEBUG -O6 -I "include" -o "bin\ebsynth.exe" -Xcompiler "/openmp /fp:fast" -Xlinker "/IMPLIB:dummy.lib" -w || goto error
nvcc -arch compute_30 src\ebsynth.cpp src\ebsynth_cpu.cpp src\ebsynth_cuda.cu -DNDEBUG -O6 -I "include" -o "bin\ebsynth.dll" -Xcompiler "/openmp /fp:fast" -Xlinker "/IMPLIB:lib\ebsynth.lib" -shared -DEBSYNTH_API=__declspec(dllexport) -w || goto error
del dummy.lib;dummy.exp 2> NUL
goto :EOF
:error
echo FAILED
@%COMSPEC% /C exit 1 >nul

14
build-win64-cpu_only.bat Normal file
View File

@@ -0,0 +1,14 @@
@echo off
setlocal ENABLEDELAYEDEXPANSION
for %%V in (15,14,12,11) do if exist "!VS%%V0COMNTOOLS!" call "!VS%%V0COMNTOOLS!..\..\VC\vcvarsall.bat" amd64 && goto compile
:compile
cl src\ebsynth.cpp src\ebsynth_cpu.cpp src\ebsynth_nocuda.cpp /DNDEBUG /O2 /openmp /EHsc /nologo /I"include" /Fe"bin\ebsynth.exe" || goto error
cl src\ebsynth.cpp src\ebsynth_cpu.cpp src\ebsynth_nocuda.cpp /DNDEBUG /O2 /openmp /EHsc /nologo /I"include" /Fe"bin\ebsynth.dll" /DEBSYNTH_API="__declspec(dllexport)" /link /IMPLIB:"lib\ebsynth.lib" || goto error
del ebsynth.obj;ebsynth_cpu.obj;ebsynth_nocuda.obj 2> NUL
goto :EOF
:error
echo FAILED
@%COMSPEC% /C exit 1 >nul

View File

@@ -1,12 +0,0 @@
@echo off
setlocal ENABLEDELAYEDEXPANSION
for %%V in (15,14,12,11) do if exist "!VS%%V0COMNTOOLS!" call "!VS%%V0COMNTOOLS!..\..\VC\vcvarsall.bat" amd64 && goto compile
:compile
nvcc -arch compute_30 src\ebsynth.cu -m64 -O6 -w -I "include" -o "bin\ebsynth.exe" -Xcompiler "/DNDEBUG /Ox /Oy /Gy /Oi /fp:fast" -Xlinker "/IMPLIB:\"lib\ebsynth.lib\"" || goto error
goto :EOF
:error
echo FAILED
@%COMSPEC% /C exit 1 >nul

551
src/ebsynth.cpp Normal file
View File

@@ -0,0 +1,551 @@
// This software is in the public domain. Where that dedication is not
// recognized, you are granted a perpetual, irrevocable license to copy
// and modify this file as you see fit.
#include "ebsynth.h"
#include "ebsynth_cpu.h"
#include "ebsynth_cuda.h"
#include <cstdio>
#include <cmath>
EBSYNTH_API
void ebsynthRun(int ebsynthBackend,
int numStyleChannels,
int numGuideChannels,
int sourceWidth,
int sourceHeight,
void* sourceStyleData,
void* sourceGuideData,
int targetWidth,
int targetHeight,
void* targetGuideData,
void* targetModulationData,
float* styleWeights,
float* guideWeights,
float uniformityWeight,
int patchSize,
int voteMode,
int numPyramidLevels,
int* numSearchVoteItersPerLevel,
int* numPatchMatchItersPerLevel,
int* stopThresholdPerLevel,
void* outputNnfData,
void* outputImageData)
{
void (*backendDispatch)(int,int,int,int,void*,void*,int,int,void*,void*,float*,float*,float,int,int,int,int*,int*,int*,void*,void*) = 0;
if (ebsynthBackend==EBSYNTH_BACKEND_CPU ) { backendDispatch = ebsynthRunCpu; }
else if (ebsynthBackend==EBSYNTH_BACKEND_CUDA) { backendDispatch = ebsynthRunCuda; }
else if (ebsynthBackend==EBSYNTH_BACKEND_AUTO) { backendDispatch = ebsynthBackendAvailableCuda() ? ebsynthRunCuda : ebsynthRunCpu; }
if (backendDispatch!=0)
{
backendDispatch(numStyleChannels,
numGuideChannels,
sourceWidth,
sourceHeight,
sourceStyleData,
sourceGuideData,
targetWidth,
targetHeight,
targetGuideData,
targetModulationData,
styleWeights,
guideWeights,
uniformityWeight,
patchSize,
voteMode,
numPyramidLevels,
numSearchVoteItersPerLevel,
numPatchMatchItersPerLevel,
stopThresholdPerLevel,
outputNnfData,
outputImageData);
}
}
EBSYNTH_API
int ebsynthBackendAvailable(int ebsynthBackend)
{
if (ebsynthBackend==EBSYNTH_BACKEND_CPU ) { return ebsynthBackendAvailableCpu(); }
else if (ebsynthBackend==EBSYNTH_BACKEND_CUDA) { return ebsynthBackendAvailableCuda(); }
else if (ebsynthBackend==EBSYNTH_BACKEND_AUTO) { return ebsynthBackendAvailableCpu() || ebsynthBackendAvailableCuda(); }
return 0;
}
//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
#include <cstdio>
#include <cmath>
#include <vector>
#include <string>
#include <algorithm>
#include "jzq.h"
template<typename FUNC>
bool tryToParseArg(const std::vector<std::string>& args,int* inout_argi,const char* name,bool* out_fail,FUNC handler)
{
int& argi = *inout_argi;
bool& fail = *out_fail;
if (argi<0 || argi>=args.size()) { fail = true; return false; }
if (args[argi]==name)
{
argi++;
fail = !handler();
return true;
}
fail = false; return false;
}
bool tryToParseIntArg(const std::vector<std::string>& args,int* inout_argi,const char* name,int* out_value,bool* out_fail)
{
return tryToParseArg(args,inout_argi,name,out_fail,[&]
{
int& argi = *inout_argi;
if (argi<args.size())
{
const std::string& arg = args[argi];
try
{
std::size_t pos = 0;
*out_value = std::stoi(arg,&pos);
if (pos!=arg.size()) { printf("error: bad %s argument '%s'\n",name,arg.c_str()); return false; }
return true;
}
catch(...)
{
printf("error: bad %s argument '%s'\n",name,arg.c_str());
return false;
}
}
printf("error: missing argument for the %s option\n",name);
return false;
});
}
bool tryToParseFloatArg(const std::vector<std::string>& args,int* inout_argi,const char* name,float* out_value,bool* out_fail)
{
return tryToParseArg(args,inout_argi,name,out_fail,[&]
{
int& argi = *inout_argi;
if (argi<args.size())
{
const std::string& arg = args[argi];
try
{
std::size_t pos = 0;
*out_value = std::stof(arg,&pos);
if (pos!=arg.size()) { printf("error: bad %s argument '%s'\n",name,arg.c_str()); return false; }
return true;
}
catch(...)
{
printf("error: bad %s argument '%s'\n",name,args[argi].c_str());
return false;
}
}
printf("error: missing argument for the %s option\n",name);
return false;
});
}
bool tryToParseStringArg(const std::vector<std::string>& args,int* inout_argi,const char* name,std::string* out_value,bool* out_fail)
{
return tryToParseArg(args,inout_argi,name,out_fail,[&]
{
int& argi = *inout_argi;
if (argi<args.size())
{
*out_value = args[argi];
return true;
}
printf("error: missing argument for the %s option\n",name);
return false;
});
}
bool tryToParseStringPairArg(const std::vector<std::string>& args,int* inout_argi,const char* name,std::pair<std::string,std::string>* out_value,bool* out_fail)
{
return tryToParseArg(args,inout_argi,name,out_fail,[&]
{
int& argi = *inout_argi;
if ((argi+1)<args.size())
{
*out_value = std::make_pair(args[argi],args[argi+1]);
argi++;
return true;
}
printf("error: missing argument for the %s option\n",name);
return false;
});
}
#define STB_IMAGE_IMPLEMENTATION
#include "stb_image.h"
#define STB_IMAGE_WRITE_IMPLEMENTATION
#include "stb_image_write.h"
unsigned char* tryLoad(const std::string& fileName,int* width,int* height)
{
unsigned char* data = stbi_load(fileName.c_str(),width,height,NULL,4);
if (data==NULL)
{
printf("error: failed to load '%s'\n",fileName.c_str());
printf("%s\n",stbi_failure_reason());
exit(1);
}
return data;
}
int evalNumChannels(const unsigned char* data,const int numPixels)
{
bool isGray = true;
bool hasAlpha = false;
for(int xy=0;xy<numPixels;xy++)
{
const unsigned char r = data[xy*4+0];
const unsigned char g = data[xy*4+1];
const unsigned char b = data[xy*4+2];
const unsigned char a = data[xy*4+3];
if (!(r==g && g==b)) { isGray = false; }
if (a<255) { hasAlpha = true; }
}
const int numChannels = (isGray ? 1 : 3) + (hasAlpha ? 1 : 0);
return numChannels;
}
V2i pyramidLevelSize(const V2i& sizeBase,const int level)
{
return V2i(V2f(sizeBase)*std::pow(2.0f,-float(level)));
}
std::string backendToString(const int ebsynthBackend)
{
if (ebsynthBackend==EBSYNTH_BACKEND_CPU) { return "cpu"; }
else if (ebsynthBackend==EBSYNTH_BACKEND_CUDA) { return "cuda"; }
else if (ebsynthBackend==EBSYNTH_BACKEND_AUTO) { return "auto"; }
return "unknown";
}
int main(int argc,char** argv)
{
if (argc<2)
{
printf("usage: %s [options]\n",argv[0]);
printf("\n");
printf("options:\n");
printf(" -style <style.png>\n");
printf(" -guide <source.png> <target.png>\n");
printf(" -output <output.png>\n");
printf(" -weight <value>\n");
printf(" -uniformity <value>\n");
printf(" -patchsize <size>\n");
printf(" -pyramidlevels <number>\n");
printf(" -searchvoteiters <number>\n");
printf(" -patchmatchiters <number>\n");
printf(" -stopthreshold <value>\n");
printf(" -backend [cpu|cuda]\n");
printf("\n");
return 1;
}
std::string styleFileName;
float styleWeight = NAN;
std::string outputFileName = "output.png";
struct Guide
{
std::string sourceFileName;
std::string targetFileName;
float weight;
int sourceWidth;
int sourceHeight;
unsigned char* sourceData;
int targetWidth;
int targetHeight;
unsigned char* targetData;
int numChannels;
};
std::vector<Guide> guides;
float uniformityWeight = 3500;
int patchSize = 5;
int numPyramidLevels = -1;
int numSearchVoteIters = 6;
int numPatchMatchIters = 4;
int stopThreshold = 5;
int backend = ebsynthBackendAvailable(EBSYNTH_BACKEND_CUDA) ? EBSYNTH_BACKEND_CUDA : EBSYNTH_BACKEND_CPU;
{
std::vector<std::string> args(argc);
for(int i=0;i<argc;i++) { args[i] = argv[i]; }
bool fail = false;
int argi = 1;
float* precedingStyleOrGuideWeight = 0;
while(argi<argc && !fail)
{
float weight;
std::pair<std::string,std::string> guidePair;
std::string backendName;
if (tryToParseStringArg(args,&argi,"-style",&styleFileName,&fail))
{
styleWeight = NAN;
precedingStyleOrGuideWeight = &styleWeight;
argi++;
}
else if (tryToParseStringPairArg(args,&argi,"-guide",&guidePair,&fail))
{
Guide guide;
guide.sourceFileName = guidePair.first;
guide.targetFileName = guidePair.second;
guide.weight = NAN;
guides.push_back(guide);
precedingStyleOrGuideWeight = &guides[guides.size()-1].weight;
argi++;
}
else if (tryToParseStringArg(args,&argi,"-output",&outputFileName,&fail))
{
argi++;
}
else if (tryToParseFloatArg(args,&argi,"-weight",&weight,&fail))
{
if (precedingStyleOrGuideWeight!=0) { *precedingStyleOrGuideWeight = weight; }
else { printf("error: at least one -style or -guide option must precede the -weight option!\n"); return 1; }
argi++;
}
else if (tryToParseFloatArg(args,&argi,"-uniformity",&uniformityWeight,&fail)) { argi++; }
else if (tryToParseIntArg(args,&argi,"-patchsize",&patchSize,&fail))
{
if (patchSize<3) { printf("error: patchsize is too small!\n"); return 1; }
if (patchSize%2==0) { printf("error: patchsize must be an odd number!\n"); return 1; }
argi++;
}
else if (tryToParseIntArg(args,&argi,"-pyramidlevels",&numPyramidLevels,&fail))
{
if (numPyramidLevels<1) { printf("error: bad argument for -pyramidlevels!\n"); return 1; }
argi++;
}
else if (tryToParseIntArg(args,&argi,"-searchvoteiters",&numSearchVoteIters,&fail))
{
if (numSearchVoteIters<0) { printf("error: bad argument for -searchvoteiters!\n"); return 1; }
argi++;
}
else if (tryToParseIntArg(args,&argi,"-patchmatchiters",&numPatchMatchIters,&fail))
{
if (numPatchMatchIters<0) { printf("error: bad argument for -patchmatchiters!\n"); return 1; }
argi++;
}
else if (tryToParseIntArg(args,&argi,"-stopthreshold",&stopThreshold,&fail))
{
if (stopThreshold<0) { printf("error: bad argument for -stopthreshold!\n"); return 1; }
argi++;
}
else if (tryToParseStringArg(args,&argi,"-backend",&backendName,&fail))
{
if (backendName=="cpu" ) { backend = EBSYNTH_BACKEND_CPU; }
else if (backendName=="cuda") { backend = EBSYNTH_BACKEND_CUDA; }
else { printf("error: unrecognized backend '%s'\n",backendName.c_str()); return 1; }
if (!ebsynthBackendAvailable(backend)) { printf("error: the %s backend is not available!\n",backendToString(backend).c_str()); return 1; }
argi++;
}
else
{
printf("error: unrecognized option '%s'\n",args[argi].c_str());
fail = true;
}
}
if (fail) { return 1; }
}
const int numGuides = guides.size();
int sourceWidth = 0;
int sourceHeight = 0;
unsigned char* sourceStyleData = tryLoad(styleFileName,&sourceWidth,&sourceHeight);
const int numStyleChannelsTotal = evalNumChannels(sourceStyleData,sourceWidth*sourceHeight);
std::vector<unsigned char> sourceStyle(sourceWidth*sourceHeight*numStyleChannelsTotal);
for(int xy=0;xy<sourceWidth*sourceHeight;xy++)
{
if (numStyleChannelsTotal>0) { sourceStyle[xy*numStyleChannelsTotal+0] = sourceStyleData[xy*4+0]; }
if (numStyleChannelsTotal==2) { sourceStyle[xy*numStyleChannelsTotal+1] = sourceStyleData[xy*4+3]; }
else if (numStyleChannelsTotal>1) { sourceStyle[xy*numStyleChannelsTotal+1] = sourceStyleData[xy*4+1]; }
if (numStyleChannelsTotal>2) { sourceStyle[xy*numStyleChannelsTotal+2] = sourceStyleData[xy*4+2]; }
if (numStyleChannelsTotal>3) { sourceStyle[xy*numStyleChannelsTotal+3] = sourceStyleData[xy*4+3]; }
}
int targetWidth = 0;
int targetHeight = 0;
int numGuideChannelsTotal = 0;
for(int i=0;i<numGuides;i++)
{
Guide& guide = guides[i];
guide.sourceData = tryLoad(guide.sourceFileName,&guide.sourceWidth,&guide.sourceHeight);
guide.targetData = tryLoad(guide.targetFileName,&guide.targetWidth,&guide.targetHeight);
if (guide.sourceWidth!=sourceWidth || guide.sourceHeight!=sourceHeight) { printf("error: source guide '%s' doesn't match the resolution of '%s'\n",guide.sourceFileName.c_str(),styleFileName.c_str()); return 1; }
if (i>0 && (guide.targetWidth!=targetWidth || guide.targetHeight!=targetHeight)) { printf("error: target guide '%s' doesn't match the resolution of '%s'\n",guide.targetFileName.c_str(),guides[0].targetFileName.c_str()); return 1; }
else if (i==0) { targetWidth = guide.targetWidth; targetHeight = guide.targetHeight; }
guide.numChannels = std::max(evalNumChannels(guide.sourceData,sourceWidth*sourceHeight),
evalNumChannels(guide.targetData,targetWidth*targetHeight));
numGuideChannelsTotal += guide.numChannels;
}
if (numStyleChannelsTotal>EBSYNTH_MAX_STYLE_CHANNELS) { printf("error: too many style channels (%d), maximum number is %d\n",numStyleChannelsTotal,EBSYNTH_MAX_STYLE_CHANNELS); return 1; }
if (numGuideChannelsTotal>EBSYNTH_MAX_GUIDE_CHANNELS) { printf("error: too many guide channels (%d), maximum number is %d\n",numGuideChannelsTotal,EBSYNTH_MAX_GUIDE_CHANNELS); return 1; }
std::vector<unsigned char> sourceGuides(sourceWidth*sourceHeight*numGuideChannelsTotal);
for(int xy=0;xy<sourceWidth*sourceHeight;xy++)
{
int c = 0;
for(int i=0;i<numGuides;i++)
{
const int numChannels = guides[i].numChannels;
if (numChannels>0) { sourceGuides[xy*numGuideChannelsTotal+c+0] = guides[i].sourceData[xy*4+0]; }
if (numChannels==2) { sourceGuides[xy*numGuideChannelsTotal+c+1] = guides[i].sourceData[xy*4+3]; }
else if (numChannels>1) { sourceGuides[xy*numGuideChannelsTotal+c+1] = guides[i].sourceData[xy*4+1]; }
if (numChannels>2) { sourceGuides[xy*numGuideChannelsTotal+c+2] = guides[i].sourceData[xy*4+2]; }
if (numChannels>3) { sourceGuides[xy*numGuideChannelsTotal+c+3] = guides[i].sourceData[xy*4+3]; }
c += numChannels;
}
}
std::vector<unsigned char> targetGuides(targetWidth*targetHeight*numGuideChannelsTotal);
for(int xy=0;xy<targetWidth*targetHeight;xy++)
{
int c = 0;
for(int i=0;i<numGuides;i++)
{
const int numChannels = guides[i].numChannels;
if (numChannels>0) { targetGuides[xy*numGuideChannelsTotal+c+0] = guides[i].targetData[xy*4+0]; }
if (numChannels==2) { targetGuides[xy*numGuideChannelsTotal+c+1] = guides[i].targetData[xy*4+3]; }
else if (numChannels>1) { targetGuides[xy*numGuideChannelsTotal+c+1] = guides[i].targetData[xy*4+1]; }
if (numChannels>2) { targetGuides[xy*numGuideChannelsTotal+c+2] = guides[i].targetData[xy*4+2]; }
if (numChannels>3) { targetGuides[xy*numGuideChannelsTotal+c+3] = guides[i].targetData[xy*4+3]; }
c += numChannels;
}
}
std::vector<float> styleWeights(numStyleChannelsTotal);
if (isnan(styleWeight)) { styleWeight = 1.0f; }
for(int i=0;i<numStyleChannelsTotal;i++) { styleWeights[i] = styleWeight / float(numStyleChannelsTotal); }
for(int i=0;i<numGuides;i++) { if (isnan(guides[i].weight)) { guides[i].weight = 1.0f/float(numGuides); } }
std::vector<float> guideWeights(numGuideChannelsTotal);
{
int c = 0;
for(int i=0;i<numGuides;i++)
{
const int numChannels = guides[i].numChannels;
for(int j=0;j<numChannels;j++)
{
guideWeights[c+j] = guides[i].weight / float(numChannels);
}
c += numChannels;
}
}
int maxPyramidLevels = 0;
for(int level=32;level>=0;level--)
{
if (min(pyramidLevelSize(std::min(V2i(sourceWidth,sourceHeight),V2i(targetWidth,targetHeight)),level)) >= (2*patchSize+1))
{
maxPyramidLevels = level+1;
break;
}
}
if (numPyramidLevels==-1) { numPyramidLevels = maxPyramidLevels; }
numPyramidLevels = std::min(numPyramidLevels,maxPyramidLevels);
std::vector<int> numSearchVoteItersPerLevel(numPyramidLevels);
std::vector<int> numPatchMatchItersPerLevel(numPyramidLevels);
std::vector<int> stopThresholdPerLevel(numPyramidLevels);
for(int i=0;i<numPyramidLevels;i++)
{
numSearchVoteItersPerLevel[i] = numSearchVoteIters;
numPatchMatchItersPerLevel[i] = numPatchMatchIters;
stopThresholdPerLevel[i] = stopThreshold;
}
std::vector<unsigned char> output(targetWidth*targetHeight*numStyleChannelsTotal);
printf("uniformity: %.0f\n",uniformityWeight);
printf("patchsize: %d\n",patchSize);
printf("pyramidlevels: %d\n",numPyramidLevels);
printf("searchvoteiters: %d\n",numSearchVoteIters);
printf("patchmatchiters: %d\n",numPatchMatchIters);
printf("stopthreshold: %d\n",stopThreshold);
printf("backend: %s\n",backendToString(backend).c_str());
ebsynthRun(backend,
numStyleChannelsTotal,
numGuideChannelsTotal,
sourceWidth,
sourceHeight,
sourceStyle.data(),
sourceGuides.data(),
targetWidth,
targetHeight,
targetGuides.data(),
NULL,
styleWeights.data(),
guideWeights.data(),
uniformityWeight,
patchSize,
EBSYNTH_VOTEMODE_PLAIN,
numPyramidLevels,
numSearchVoteItersPerLevel.data(),
numPatchMatchItersPerLevel.data(),
stopThresholdPerLevel.data(),
NULL,
output.data());
stbi_write_png(outputFileName.c_str(),targetWidth,targetHeight,numStyleChannelsTotal,output.data(),numStyleChannelsTotal*targetWidth);
printf("result was written to %s\n",outputFileName.c_str());
stbi_image_free(sourceStyleData);
for(int i=0;i<numGuides;i++)
{
stbi_image_free(guides[i].sourceData);
stbi_image_free(guides[i].targetData);
}
return 0;
}

1037
src/ebsynth_cpu.cpp Normal file

File diff suppressed because it is too large Load Diff

32
src/ebsynth_cpu.h Normal file
View File

@@ -0,0 +1,32 @@
// This software is in the public domain. Where that dedication is not
// recognized, you are granted a perpetual, irrevocable license to copy
// and modify this file as you see fit.
#ifndef EBSYNTH_CPU_H_
#define EBSYNTH_CPU_H_
void ebsynthRunCpu(int numStyleChannels,
int numGuideChannels,
int sourceWidth,
int sourceHeight,
void* sourceStyleData,
void* sourceGuideData,
int targetWidth,
int targetHeight,
void* targetGuideData,
void* targetModulationData,
float* styleWeights,
float* guideWeights,
float uniformityWeight,
int patchSize,
int voteMode,
int numPyramidLevels,
int* numSearchVoteItersPerLevel,
int* numPatchMatchItersPerLevel,
int* stopThresholdPerLevel,
void* outputNnfData,
void* outputImageData);
int ebsynthBackendAvailableCpu();
#endif

View File

@@ -3,11 +3,369 @@
// and modify this file as you see fit. // and modify this file as you see fit.
#include "ebsynth.h" #include "ebsynth.h"
#include "patchmatch_gpu.h" #include "ebsynth_cuda_texarray2.h"
#include "ebsynth_cuda_memarray2.h"
#include <cmath>
#include <cfloat>
#include <stdint.h>
#define FOR(A,X,Y) for(int Y=0;Y<A.height();Y++) for(int X=0;X<A.width();X++) #define FOR(A,X,Y) for(int Y=0;Y<A.height();Y++) for(int X=0;X<A.width();X++)
A2V2i nnfInitRandom(const V2i& targetSize, typedef Vec<1,float> V1f;
typedef Array2<Vec<1,float>> A2V1f;
struct pcgState
{
uint64_t state;
uint64_t increment;
};
__device__ void pcgAdvance(pcgState* rng)
{
rng->state = rng->state * 6364136223846793005ULL + rng->increment;
}
__device__ uint32_t pcgOutput(uint64_t state)
{
return (uint32_t)(((state >> 22u) ^ state) >> ((state >> 61u) + 22u));
}
__device__ uint32_t pcgRand(pcgState* rng)
{
uint64_t oldstate = rng->state;
pcgAdvance(rng);
return pcgOutput(oldstate);
}
__device__ void pcgInit(pcgState* rng,uint64_t seed,uint64_t stream)
{
rng->state = 0U;
rng->increment = (stream << 1u) | 1u;
pcgAdvance(rng);
rng->state += seed;
pcgAdvance(rng);
}
__global__ void krnlInitRngStates(const int width,
const int height,
pcgState* rngStates)
{
const int x = blockDim.x*blockIdx.x + threadIdx.x;
const int y = blockDim.y*blockIdx.y + threadIdx.y;
if (x<width && y<height)
{
const int idx = x+y*width;
pcgInit(&rngStates[idx],1337,idx);
}
}
pcgState* initGpuRng(const int width,
const int height)
{
pcgState* gpuRngStates;
cudaMalloc(&gpuRngStates,width*height*sizeof(pcgState));
const dim3 threadsPerBlock(16,16);
const dim3 numBlocks((width+threadsPerBlock.x)/threadsPerBlock.x,
(height+threadsPerBlock.y)/threadsPerBlock.y);
krnlInitRngStates<<<numBlocks,threadsPerBlock>>>(width,height,gpuRngStates);
return gpuRngStates;
}
template<typename FUNC>
__global__ void krnlEvalErrorPass(const int patchWidth,
FUNC patchError,
const TexArray2<2,int> NNF,
TexArray2<1,float> E)
{
const int x = blockDim.x*blockIdx.x + threadIdx.x;
const int y = blockDim.y*blockIdx.y + threadIdx.y;
if (x<NNF.width && y<NNF.height)
{
const V2i n = NNF(x,y);
E.write(x,y,V1f(patchError(patchWidth,x,y,n[0],n[1],FLT_MAX)));
}
}
void __device__ updateOmega(MemArray2<int>& Omega,const int patchWidth,const int bx,const int by,const int incdec)
{
const int r = patchWidth/2;
for(int oy=-r;oy<=+r;oy++)
for(int ox=-r;ox<=+r;ox++)
{
const int x = bx+ox;
const int y = by+oy;
atomicAdd(&Omega.data[x+y*Omega.width],incdec);
//Omega.data[x+y*Omega.width] += incdec;
}
}
int __device__ patchOmega(const int patchWidth,const int bx,const int by,const MemArray2<int>& Omega)
{
const int r = patchWidth/2;
int sum = 0;
for(int oy=-r;oy<=+r;oy++)
for(int ox=-r;ox<=+r;ox++)
{
const int x = bx+ox;
const int y = by+oy;
sum += Omega.data[x+y*Omega.width]; /// XXX: atomic read instead ??
}
return sum;
}
template<typename FUNC>
__device__ void tryPatch(const V2i& sizeA,
const V2i& sizeB,
MemArray2<int>& Omega,
const int patchWidth,
FUNC patchError,
const float lambda,
const int ax,
const int ay,
const int bx,
const int by,
V2i& nbest,
float& ebest)
{
const float omegaBest = (float(sizeA(0)*sizeA(1)) /
float(sizeB(0)*sizeB(1))) * float(patchWidth*patchWidth);
const float curOcc = (float(patchOmega(patchWidth,nbest(0),nbest(1),Omega))/float(patchWidth*patchWidth))/omegaBest;
const float newOcc = (float(patchOmega(patchWidth, bx, by,Omega))/float(patchWidth*patchWidth))/omegaBest;
const float curErr = ebest;
const float newErr = patchError(patchWidth,ax,ay,bx,by,curErr+lambda*curOcc);
if ((newErr+lambda*newOcc) < (curErr+lambda*curOcc))
{
updateOmega(Omega,patchWidth, bx, by,+1);
updateOmega(Omega,patchWidth,nbest(0),nbest(1),-1);
nbest = V2i(bx,by);
ebest = newErr;
}
}
template<typename FUNC>
__device__ void tryNeighborsOffset(const int x,
const int y,
const int ox,
const int oy,
V2i& nbest,
float& ebest,
const V2i& sizeA,
const V2i& sizeB,
MemArray2<int>& Omega,
const int patchWidth,
FUNC patchError,
const float lambda,
const TexArray2<2,int>& NNF)
{
const int hpw = patchWidth/2;
const V2i on = NNF(x+ox,y+oy);
const int nx = on(0)-ox;
const int ny = on(1)-oy;
if (nx>=hpw && nx<sizeB(0)-hpw &&
ny>=hpw && ny<sizeB(1)-hpw)
{
tryPatch(sizeA,sizeB,Omega,patchWidth,patchError,lambda,x,y,nx,ny,nbest,ebest);
}
}
template<typename FUNC>
__global__ void krnlPropagationPass(const V2i sizeA,
const V2i sizeB,
MemArray2<int> Omega,
const int patchWidth,
FUNC patchError,
const float lambda,
const int r,
const TexArray2<2,int> NNF,
TexArray2<2,int> NNF2,
TexArray2<1,float> E,
TexArray2<1,unsigned char> mask)
{
const int x = blockDim.x*blockIdx.x + threadIdx.x;
const int y = blockDim.y*blockIdx.y + threadIdx.y;
if (x<sizeA(0) && y<sizeA(1))
{
V2i nbest = NNF(x,y);
float ebest = E(x,y)(0);
if (mask(x,y)[0]==255)
{
tryNeighborsOffset(x,y,-r,0,nbest,ebest,sizeA,sizeB,Omega,patchWidth,patchError,lambda,NNF);
tryNeighborsOffset(x,y,+r,0,nbest,ebest,sizeA,sizeB,Omega,patchWidth,patchError,lambda,NNF);
tryNeighborsOffset(x,y,0,-r,nbest,ebest,sizeA,sizeB,Omega,patchWidth,patchError,lambda,NNF);
tryNeighborsOffset(x,y,0,+r,nbest,ebest,sizeA,sizeB,Omega,patchWidth,patchError,lambda,NNF);
}
E.write(x,y,V1f(ebest));
NNF2.write(x,y,nbest);
}
}
template<typename FUNC>
__device__ void tryRandomOffsetInRadius(const int r,
const V2i& sizeA,
const V2i& sizeB,
MemArray2<int>& Omega,
const int patchWidth,
FUNC patchError,
const float lambda,
const int x,
const int y,
const V2i& norg,
V2i& nbest,
float& ebest,
pcgState* rngState)
{
const int hpw = patchWidth/2;
const int xmin = max(norg(0)-r,hpw);
const int xmax = min(norg(0)+r,sizeB(0)-1-hpw);
const int ymin = max(norg(1)-r,hpw);
const int ymax = min(norg(1)+r,sizeB(1)-1-hpw);
const int nx = xmin+(pcgRand(rngState)%(xmax-xmin+1));
const int ny = ymin+(pcgRand(rngState)%(ymax-ymin+1));
tryPatch(sizeA,sizeB,Omega,patchWidth,patchError,lambda,x,y,nx,ny,nbest,ebest);
}
/*
template<typename FUNC>
__global__ void krnlRandomSearchPass(const V2i sizeA,
const V2i sizeB,
MemArray2<int> Omega,
const int patchWidth,
FUNC patchError,
const float lambda,
TexArray2<2,int> NNF,
TexArray2<1,float> E,
TexArray2<1,unsigned char> mask,
pcgState* rngStates)
{
const int x = blockDim.x*blockIdx.x + threadIdx.x;
const int y = blockDim.y*blockIdx.y + threadIdx.y;
if (x<sizeA(0) && y<sizeA(1))
{
if (mask(x,y)[0]==255)
{
V2i nbest = NNF(x,y);
float ebest = E(x,y)(0);
const V2i norg = nbest;
for(int r=1;r<max(sizeB(0),sizeB(1))/2;r=r*2)
{
tryRandomOffsetInRadius(r,sizeA,sizeB,Omega,patchWidth,patchError,lambda,x,y,norg,nbest,ebest,&rngStates[x+y*NNF.width]);
}
E.write(x,y,V1f(ebest));
NNF.write(x,y,nbest);
}
}
}
*/
template<typename FUNC>
__global__ void krnlRandomSearchPass(const V2i sizeA,
const V2i sizeB,
MemArray2<int> Omega,
const int patchWidth,
FUNC patchError,
const float lambda,
const int radius,
TexArray2<2,int> NNF,
TexArray2<1,float> E,
TexArray2<1,unsigned char> mask,
pcgState* rngStates)
{
const int x = blockDim.x*blockIdx.x + threadIdx.x;
const int y = blockDim.y*blockIdx.y + threadIdx.y;
if (x<sizeA(0) && y<sizeA(1))
{
if (mask(x,y)[0]==255)
{
V2i nbest = NNF(x,y);
float ebest = E(x,y)(0);
const V2i norg = nbest;
tryRandomOffsetInRadius(radius,sizeA,sizeB,Omega,patchWidth,patchError,lambda,x,y,norg,nbest,ebest,&rngStates[x+y*NNF.width]);
E.write(x,y,V1f(ebest));
NNF.write(x,y,nbest);
}
}
}
template<typename FUNC>
void patchmatchGPU(const V2i sizeA,
const V2i sizeB,
MemArray2<int>& Omega,
const int patchWidth,
FUNC patchError,
const float lambda,
const int numIters,
const int numThreadsPerBlock,
TexArray2<2,int>& NNF,
TexArray2<2,int>& NNF2,
TexArray2<1,float>& E,
TexArray2<1,unsigned char>& mask,
pcgState* rngStates)
{
const dim3 threadsPerBlock = dim3(numThreadsPerBlock,numThreadsPerBlock);
const dim3 numBlocks = dim3((NNF.width+threadsPerBlock.x)/threadsPerBlock.x,
(NNF.height+threadsPerBlock.y)/threadsPerBlock.y);
krnlEvalErrorPass<<<numBlocks,threadsPerBlock>>>(patchWidth,patchError,NNF,E);
checkCudaError(cudaDeviceSynchronize());
for(int i=0;i<numIters;i++)
{
krnlPropagationPass<<<numBlocks,threadsPerBlock>>>(sizeA,sizeB,Omega,patchWidth,patchError,lambda,4,NNF,NNF2,E,mask); std::swap(NNF,NNF2);
checkCudaError(cudaDeviceSynchronize());
krnlPropagationPass<<<numBlocks,threadsPerBlock>>>(sizeA,sizeB,Omega,patchWidth,patchError,lambda,2,NNF,NNF2,E,mask); std::swap(NNF,NNF2);
checkCudaError(cudaDeviceSynchronize());
krnlPropagationPass<<<numBlocks,threadsPerBlock>>>(sizeA,sizeB,Omega,patchWidth,patchError,lambda,1,NNF,NNF2,E,mask); std::swap(NNF,NNF2);
checkCudaError(cudaDeviceSynchronize());
for(int r=1;r<max(sizeB(0),sizeB(1))/2;r=r*2)
{
krnlRandomSearchPass<<<numBlocks,threadsPerBlock>>>(sizeA,sizeB,Omega,patchWidth,patchError,lambda,r,NNF,E,mask,rngStates);
}
checkCudaError(cudaDeviceSynchronize());
}
krnlEvalErrorPass<<<numBlocks,threadsPerBlock>>>(patchWidth,patchError,NNF,E);
checkCudaError(cudaDeviceSynchronize());
}
static A2V2i nnfInitRandom(const V2i& targetSize,
const V2i& sourceSize, const V2i& sourceSize,
const int patchSize) const int patchSize)
{ {
@@ -26,7 +384,7 @@ A2V2i nnfInitRandom(const V2i& targetSize,
return NNF; return NNF;
} }
A2V2i nnfUpscale(const A2V2i& NNF, static A2V2i nnfUpscale(const A2V2i& NNF,
const int patchSize, const int patchSize,
const V2i& targetSize, const V2i& targetSize,
const V2i& sourceSize) const V2i& sourceSize)
@@ -381,14 +739,13 @@ struct PatchSSD_Split_Modulation
} }
}; };
V2i pyramidLevelSize(const V2i& sizeBase,const int numLevels,const int level) static V2i pyramidLevelSize(const V2i& sizeBase,const int numLevels,const int level)
{ {
return V2i(V2f(sizeBase)*pow(2.0f,-float(numLevels-1-level))); return V2i(V2f(sizeBase)*std::pow(2.0f,-float(numLevels-1-level)));
} }
template<int NS,int NG> template<int NS,int NG>
void runEbsynth(int ebsynthBackend, void ebsynthCuda(int numStyleChannels,
int numStyleChannels,
int numGuideChannels, int numGuideChannels,
int sourceWidth, int sourceWidth,
int sourceHeight, int sourceHeight,
@@ -407,7 +764,8 @@ void runEbsynth(int ebsynthBackend,
int* numSearchVoteItersPerLevel, int* numSearchVoteItersPerLevel,
int* numPatchMatchItersPerLevel, int* numPatchMatchItersPerLevel,
int* stopThresholdPerLevel, int* stopThresholdPerLevel,
void* outputData) void* outputNnfData,
void* outputImageData)
{ {
const int levelCount = numPyramidLevels; const int levelCount = numPyramidLevels;
@@ -706,7 +1064,11 @@ void runEbsynth(int ebsynthBackend,
} }
} }
if (level==levelCount-1) { copy(&outputData,pyramid[pyramid.size()-1].targetStyle); } if (level==levelCount-1)
{
if (outputNnfData!=NULL) { copy(&outputNnfData,pyramid[level].NNF); }
copy(&outputImageData,pyramid[level].targetStyle);
}
pyramid[level].sourceStyle.destroy(); pyramid[level].sourceStyle.destroy();
pyramid[level].sourceGuide.destroy(); pyramid[level].sourceGuide.destroy();
@@ -726,8 +1088,7 @@ void runEbsynth(int ebsynthBackend,
checkCudaError( cudaFree(rngStates) ); checkCudaError( cudaFree(rngStates) );
} }
EBSYNTH_API void ebsynthRun(int ebsynthBackend, void ebsynthRunCuda(int numStyleChannels,
int numStyleChannels,
int numGuideChannels, int numGuideChannels,
int sourceWidth, int sourceWidth,
int sourceHeight, int sourceHeight,
@@ -746,42 +1107,41 @@ EBSYNTH_API void ebsynthRun(int ebsynthBackend,
int* numSearchVoteItersPerLevel, int* numSearchVoteItersPerLevel,
int* numPatchMatchItersPerLevel, int* numPatchMatchItersPerLevel,
int* stopThresholdPerLevel, int* stopThresholdPerLevel,
void* outputData void* outputNnfData,
) void* outputImageData)
{ {
void (*const dispatchEbsynth[EBSYNTH_MAX_GUIDE_CHANNELS][EBSYNTH_MAX_STYLE_CHANNELS])(int,int,int,int,int,void*,void*,int,int,void*,void*,float*,float*,float,int,int,int,int*,int*,int*,void*) = void (*const dispatchEbsynth[EBSYNTH_MAX_GUIDE_CHANNELS][EBSYNTH_MAX_STYLE_CHANNELS])(int,int,int,int,void*,void*,int,int,void*,void*,float*,float*,float,int,int,int,int*,int*,int*,void*,void*) =
{ {
{ runEbsynth<1, 1>, runEbsynth<2, 1>, runEbsynth<3, 1>, runEbsynth<4, 1>, runEbsynth<5, 1>, runEbsynth<6, 1>, runEbsynth<7, 1>, runEbsynth<8, 1> }, { ebsynthCuda<1, 1>, ebsynthCuda<2, 1>, ebsynthCuda<3, 1>, ebsynthCuda<4, 1>, ebsynthCuda<5, 1>, ebsynthCuda<6, 1>, ebsynthCuda<7, 1>, ebsynthCuda<8, 1> },
{ runEbsynth<1, 2>, runEbsynth<2, 2>, runEbsynth<3, 2>, runEbsynth<4, 2>, runEbsynth<5, 2>, runEbsynth<6, 2>, runEbsynth<7, 2>, runEbsynth<8, 2> }, { ebsynthCuda<1, 2>, ebsynthCuda<2, 2>, ebsynthCuda<3, 2>, ebsynthCuda<4, 2>, ebsynthCuda<5, 2>, ebsynthCuda<6, 2>, ebsynthCuda<7, 2>, ebsynthCuda<8, 2> },
{ runEbsynth<1, 3>, runEbsynth<2, 3>, runEbsynth<3, 3>, runEbsynth<4, 3>, runEbsynth<5, 3>, runEbsynth<6, 3>, runEbsynth<7, 3>, runEbsynth<8, 3> }, { ebsynthCuda<1, 3>, ebsynthCuda<2, 3>, ebsynthCuda<3, 3>, ebsynthCuda<4, 3>, ebsynthCuda<5, 3>, ebsynthCuda<6, 3>, ebsynthCuda<7, 3>, ebsynthCuda<8, 3> },
{ runEbsynth<1, 4>, runEbsynth<2, 4>, runEbsynth<3, 4>, runEbsynth<4, 4>, runEbsynth<5, 4>, runEbsynth<6, 4>, runEbsynth<7, 4>, runEbsynth<8, 4> }, { ebsynthCuda<1, 4>, ebsynthCuda<2, 4>, ebsynthCuda<3, 4>, ebsynthCuda<4, 4>, ebsynthCuda<5, 4>, ebsynthCuda<6, 4>, ebsynthCuda<7, 4>, ebsynthCuda<8, 4> },
{ runEbsynth<1, 5>, runEbsynth<2, 5>, runEbsynth<3, 5>, runEbsynth<4, 5>, runEbsynth<5, 5>, runEbsynth<6, 5>, runEbsynth<7, 5>, runEbsynth<8, 5> }, { ebsynthCuda<1, 5>, ebsynthCuda<2, 5>, ebsynthCuda<3, 5>, ebsynthCuda<4, 5>, ebsynthCuda<5, 5>, ebsynthCuda<6, 5>, ebsynthCuda<7, 5>, ebsynthCuda<8, 5> },
{ runEbsynth<1, 6>, runEbsynth<2, 6>, runEbsynth<3, 6>, runEbsynth<4, 6>, runEbsynth<5, 6>, runEbsynth<6, 6>, runEbsynth<7, 6>, runEbsynth<8, 6> }, { ebsynthCuda<1, 6>, ebsynthCuda<2, 6>, ebsynthCuda<3, 6>, ebsynthCuda<4, 6>, ebsynthCuda<5, 6>, ebsynthCuda<6, 6>, ebsynthCuda<7, 6>, ebsynthCuda<8, 6> },
{ runEbsynth<1, 7>, runEbsynth<2, 7>, runEbsynth<3, 7>, runEbsynth<4, 7>, runEbsynth<5, 7>, runEbsynth<6, 7>, runEbsynth<7, 7>, runEbsynth<8, 7> }, { ebsynthCuda<1, 7>, ebsynthCuda<2, 7>, ebsynthCuda<3, 7>, ebsynthCuda<4, 7>, ebsynthCuda<5, 7>, ebsynthCuda<6, 7>, ebsynthCuda<7, 7>, ebsynthCuda<8, 7> },
{ runEbsynth<1, 8>, runEbsynth<2, 8>, runEbsynth<3, 8>, runEbsynth<4, 8>, runEbsynth<5, 8>, runEbsynth<6, 8>, runEbsynth<7, 8>, runEbsynth<8, 8> }, { ebsynthCuda<1, 8>, ebsynthCuda<2, 8>, ebsynthCuda<3, 8>, ebsynthCuda<4, 8>, ebsynthCuda<5, 8>, ebsynthCuda<6, 8>, ebsynthCuda<7, 8>, ebsynthCuda<8, 8> },
{ runEbsynth<1, 9>, runEbsynth<2, 9>, runEbsynth<3, 9>, runEbsynth<4, 9>, runEbsynth<5, 9>, runEbsynth<6, 9>, runEbsynth<7, 9>, runEbsynth<8, 9> }, { ebsynthCuda<1, 9>, ebsynthCuda<2, 9>, ebsynthCuda<3, 9>, ebsynthCuda<4, 9>, ebsynthCuda<5, 9>, ebsynthCuda<6, 9>, ebsynthCuda<7, 9>, ebsynthCuda<8, 9> },
{ runEbsynth<1,10>, runEbsynth<2,10>, runEbsynth<3,10>, runEbsynth<4,10>, runEbsynth<5,10>, runEbsynth<6,10>, runEbsynth<7,10>, runEbsynth<8,10> }, { ebsynthCuda<1,10>, ebsynthCuda<2,10>, ebsynthCuda<3,10>, ebsynthCuda<4,10>, ebsynthCuda<5,10>, ebsynthCuda<6,10>, ebsynthCuda<7,10>, ebsynthCuda<8,10> },
{ runEbsynth<1,11>, runEbsynth<2,11>, runEbsynth<3,11>, runEbsynth<4,11>, runEbsynth<5,11>, runEbsynth<6,11>, runEbsynth<7,11>, runEbsynth<8,11> }, { ebsynthCuda<1,11>, ebsynthCuda<2,11>, ebsynthCuda<3,11>, ebsynthCuda<4,11>, ebsynthCuda<5,11>, ebsynthCuda<6,11>, ebsynthCuda<7,11>, ebsynthCuda<8,11> },
{ runEbsynth<1,12>, runEbsynth<2,12>, runEbsynth<3,12>, runEbsynth<4,12>, runEbsynth<5,12>, runEbsynth<6,12>, runEbsynth<7,12>, runEbsynth<8,12> }, { ebsynthCuda<1,12>, ebsynthCuda<2,12>, ebsynthCuda<3,12>, ebsynthCuda<4,12>, ebsynthCuda<5,12>, ebsynthCuda<6,12>, ebsynthCuda<7,12>, ebsynthCuda<8,12> },
{ runEbsynth<1,13>, runEbsynth<2,13>, runEbsynth<3,13>, runEbsynth<4,13>, runEbsynth<5,13>, runEbsynth<6,13>, runEbsynth<7,13>, runEbsynth<8,13> }, { ebsynthCuda<1,13>, ebsynthCuda<2,13>, ebsynthCuda<3,13>, ebsynthCuda<4,13>, ebsynthCuda<5,13>, ebsynthCuda<6,13>, ebsynthCuda<7,13>, ebsynthCuda<8,13> },
{ runEbsynth<1,14>, runEbsynth<2,14>, runEbsynth<3,14>, runEbsynth<4,14>, runEbsynth<5,14>, runEbsynth<6,14>, runEbsynth<7,14>, runEbsynth<8,14> }, { ebsynthCuda<1,14>, ebsynthCuda<2,14>, ebsynthCuda<3,14>, ebsynthCuda<4,14>, ebsynthCuda<5,14>, ebsynthCuda<6,14>, ebsynthCuda<7,14>, ebsynthCuda<8,14> },
{ runEbsynth<1,15>, runEbsynth<2,15>, runEbsynth<3,15>, runEbsynth<4,15>, runEbsynth<5,15>, runEbsynth<6,15>, runEbsynth<7,15>, runEbsynth<8,15> }, { ebsynthCuda<1,15>, ebsynthCuda<2,15>, ebsynthCuda<3,15>, ebsynthCuda<4,15>, ebsynthCuda<5,15>, ebsynthCuda<6,15>, ebsynthCuda<7,15>, ebsynthCuda<8,15> },
{ runEbsynth<1,16>, runEbsynth<2,16>, runEbsynth<3,16>, runEbsynth<4,16>, runEbsynth<5,16>, runEbsynth<6,16>, runEbsynth<7,16>, runEbsynth<8,16> }, { ebsynthCuda<1,16>, ebsynthCuda<2,16>, ebsynthCuda<3,16>, ebsynthCuda<4,16>, ebsynthCuda<5,16>, ebsynthCuda<6,16>, ebsynthCuda<7,16>, ebsynthCuda<8,16> },
{ runEbsynth<1,17>, runEbsynth<2,17>, runEbsynth<3,17>, runEbsynth<4,17>, runEbsynth<5,17>, runEbsynth<6,17>, runEbsynth<7,17>, runEbsynth<8,17> }, { ebsynthCuda<1,17>, ebsynthCuda<2,17>, ebsynthCuda<3,17>, ebsynthCuda<4,17>, ebsynthCuda<5,17>, ebsynthCuda<6,17>, ebsynthCuda<7,17>, ebsynthCuda<8,17> },
{ runEbsynth<1,18>, runEbsynth<2,18>, runEbsynth<3,18>, runEbsynth<4,18>, runEbsynth<5,18>, runEbsynth<6,18>, runEbsynth<7,18>, runEbsynth<8,18> }, { ebsynthCuda<1,18>, ebsynthCuda<2,18>, ebsynthCuda<3,18>, ebsynthCuda<4,18>, ebsynthCuda<5,18>, ebsynthCuda<6,18>, ebsynthCuda<7,18>, ebsynthCuda<8,18> },
{ runEbsynth<1,19>, runEbsynth<2,19>, runEbsynth<3,19>, runEbsynth<4,19>, runEbsynth<5,19>, runEbsynth<6,19>, runEbsynth<7,19>, runEbsynth<8,19> }, { ebsynthCuda<1,19>, ebsynthCuda<2,19>, ebsynthCuda<3,19>, ebsynthCuda<4,19>, ebsynthCuda<5,19>, ebsynthCuda<6,19>, ebsynthCuda<7,19>, ebsynthCuda<8,19> },
{ runEbsynth<1,20>, runEbsynth<2,20>, runEbsynth<3,20>, runEbsynth<4,20>, runEbsynth<5,20>, runEbsynth<6,20>, runEbsynth<7,20>, runEbsynth<8,20> }, { ebsynthCuda<1,20>, ebsynthCuda<2,20>, ebsynthCuda<3,20>, ebsynthCuda<4,20>, ebsynthCuda<5,20>, ebsynthCuda<6,20>, ebsynthCuda<7,20>, ebsynthCuda<8,20> },
{ runEbsynth<1,21>, runEbsynth<2,21>, runEbsynth<3,21>, runEbsynth<4,21>, runEbsynth<5,21>, runEbsynth<6,21>, runEbsynth<7,21>, runEbsynth<8,21> }, { ebsynthCuda<1,21>, ebsynthCuda<2,21>, ebsynthCuda<3,21>, ebsynthCuda<4,21>, ebsynthCuda<5,21>, ebsynthCuda<6,21>, ebsynthCuda<7,21>, ebsynthCuda<8,21> },
{ runEbsynth<1,22>, runEbsynth<2,22>, runEbsynth<3,22>, runEbsynth<4,22>, runEbsynth<5,22>, runEbsynth<6,22>, runEbsynth<7,22>, runEbsynth<8,22> }, { ebsynthCuda<1,22>, ebsynthCuda<2,22>, ebsynthCuda<3,22>, ebsynthCuda<4,22>, ebsynthCuda<5,22>, ebsynthCuda<6,22>, ebsynthCuda<7,22>, ebsynthCuda<8,22> },
{ runEbsynth<1,23>, runEbsynth<2,23>, runEbsynth<3,23>, runEbsynth<4,23>, runEbsynth<5,23>, runEbsynth<6,23>, runEbsynth<7,23>, runEbsynth<8,23> }, { ebsynthCuda<1,23>, ebsynthCuda<2,23>, ebsynthCuda<3,23>, ebsynthCuda<4,23>, ebsynthCuda<5,23>, ebsynthCuda<6,23>, ebsynthCuda<7,23>, ebsynthCuda<8,23> },
{ runEbsynth<1,24>, runEbsynth<2,24>, runEbsynth<3,24>, runEbsynth<4,24>, runEbsynth<5,24>, runEbsynth<6,24>, runEbsynth<7,24>, runEbsynth<8,24> } { ebsynthCuda<1,24>, ebsynthCuda<2,24>, ebsynthCuda<3,24>, ebsynthCuda<4,24>, ebsynthCuda<5,24>, ebsynthCuda<6,24>, ebsynthCuda<7,24>, ebsynthCuda<8,24> }
}; };
if (numStyleChannels>=1 && numStyleChannels<=EBSYNTH_MAX_STYLE_CHANNELS && if (numStyleChannels>=1 && numStyleChannels<=EBSYNTH_MAX_STYLE_CHANNELS &&
numGuideChannels>=1 && numGuideChannels<=EBSYNTH_MAX_GUIDE_CHANNELS) numGuideChannels>=1 && numGuideChannels<=EBSYNTH_MAX_GUIDE_CHANNELS)
{ {
dispatchEbsynth[numGuideChannels-1][numStyleChannels-1](ebsynthBackend, dispatchEbsynth[numGuideChannels-1][numStyleChannels-1](numStyleChannels,
numStyleChannels,
numGuideChannels, numGuideChannels,
sourceWidth, sourceWidth,
sourceHeight, sourceHeight,
@@ -800,15 +1160,13 @@ EBSYNTH_API void ebsynthRun(int ebsynthBackend,
numSearchVoteItersPerLevel, numSearchVoteItersPerLevel,
numPatchMatchItersPerLevel, numPatchMatchItersPerLevel,
stopThresholdPerLevel, stopThresholdPerLevel,
outputData); outputNnfData,
outputImageData);
} }
} }
EBSYNTH_API int ebsynthBackendAvailableCuda()
int ebsynthBackendAvailable(int ebsynthBackend)
{ {
if (ebsynthBackend==EBSYNTH_BACKEND_CUDA)
{
int deviceCount = -1; int deviceCount = -1;
if (cudaGetDeviceCount(&deviceCount)!=cudaSuccess) { return 0; } if (cudaGetDeviceCount(&deviceCount)!=cudaSuccess) { return 0; }
@@ -823,462 +1181,6 @@ int ebsynthBackendAvailable(int ebsynthBackend)
} }
} }
} }
}
return 0;
}
//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
#include <cstdio>
#include <cmath>
#include <vector>
#include <string>
#include <algorithm>
#include "jzq.h"
template<typename FUNC>
bool tryToParseArg(const std::vector<std::string>& args,int* inout_argi,const char* name,bool* out_fail,FUNC handler)
{
int& argi = *inout_argi;
bool& fail = *out_fail;
if (argi<0 || argi>=args.size()) { fail = true; return false; }
if (args[argi]==name)
{
argi++;
fail = !handler();
return true;
}
fail = false; return false;
}
bool tryToParseIntArg(const std::vector<std::string>& args,int* inout_argi,const char* name,int* out_value,bool* out_fail)
{
return tryToParseArg(args,inout_argi,name,out_fail,[&]
{
int& argi = *inout_argi;
if (argi<args.size())
{
const std::string& arg = args[argi];
try
{
std::size_t pos = 0;
*out_value = std::stoi(arg,&pos);
if (pos!=arg.size()) { printf("error: bad %s argument '%s'\n",name,arg.c_str()); return false; }
return true;
}
catch(...)
{
printf("error: bad %s argument '%s'\n",name,arg.c_str());
return false;
}
}
printf("error: missing argument for the %s option\n",name);
return false;
});
}
bool tryToParseFloatArg(const std::vector<std::string>& args,int* inout_argi,const char* name,float* out_value,bool* out_fail)
{
return tryToParseArg(args,inout_argi,name,out_fail,[&]
{
int& argi = *inout_argi;
if (argi<args.size())
{
const std::string& arg = args[argi];
try
{
std::size_t pos = 0;
*out_value = std::stof(arg,&pos);
if (pos!=arg.size()) { printf("error: bad %s argument '%s'\n",name,arg.c_str()); return false; }
return true;
}
catch(...)
{
printf("error: bad %s argument '%s'\n",name,args[argi].c_str());
return false;
}
}
printf("error: missing argument for the %s option\n",name);
return false;
});
}
bool tryToParseStringArg(const std::vector<std::string>& args,int* inout_argi,const char* name,std::string* out_value,bool* out_fail)
{
return tryToParseArg(args,inout_argi,name,out_fail,[&]
{
int& argi = *inout_argi;
if (argi<args.size())
{
*out_value = args[argi];
return true;
}
printf("error: missing argument for the %s option\n",name);
return false;
});
}
bool tryToParseStringPairArg(const std::vector<std::string>& args,int* inout_argi,const char* name,std::pair<std::string,std::string>* out_value,bool* out_fail)
{
return tryToParseArg(args,inout_argi,name,out_fail,[&]
{
int& argi = *inout_argi;
if ((argi+1)<args.size())
{
*out_value = std::make_pair(args[argi],args[argi+1]);
argi++;
return true;
}
printf("error: missing argument for the %s option\n",name);
return false;
});
}
#define STB_IMAGE_IMPLEMENTATION
#include "stb_image.h"
#define STB_IMAGE_WRITE_IMPLEMENTATION
#include "stb_image_write.h"
unsigned char* tryLoad(const std::string& fileName,int* width,int* height)
{
unsigned char* data = stbi_load(fileName.c_str(),width,height,NULL,4);
if (data==NULL)
{
printf("error: failed to load '%s'\n",fileName.c_str());
printf("%s\n",stbi_failure_reason());
exit(1);
}
return data;
}
int evalNumChannels(const unsigned char* data,const int numPixels)
{
bool isGray = true;
bool hasAlpha = false;
for(int xy=0;xy<numPixels;xy++)
{
const unsigned char r = data[xy*4+0];
const unsigned char g = data[xy*4+1];
const unsigned char b = data[xy*4+2];
const unsigned char a = data[xy*4+3];
if (!(r==g && g==b)) { isGray = false; }
if (a<255) { hasAlpha = true; }
}
const int numChannels = (isGray ? 1 : 3) + (hasAlpha ? 1 : 0);
return numChannels;
}
V2i pyramidLevelSize(const V2i& sizeBase,const int level)
{
return V2i(V2f(sizeBase)*pow(2.0f,-float(level)));
}
int main(int argc,char** argv)
{
if (argc<2)
{
printf("usage: %s [options]\n",argv[0]);
printf("\n");
printf("options:\n");
printf(" -style <style.png>\n");
printf(" -guide <source.png> <target.png>\n");
printf(" -output <output.png>\n");
printf(" -weight <value>\n");
printf(" -uniformity <value>\n");
printf(" -patchsize <size>\n");
printf(" -pyramidlevels <number>\n");
printf(" -searchvoteiters <number>\n");
printf(" -patchmatchiters <number>\n");
printf(" -stopthreshold <value>\n");
printf("\n");
return 1;
}
std::string styleFileName;
float styleWeight = NAN;
std::string outputFileName = "output.png";
struct Guide
{
std::string sourceFileName;
std::string targetFileName;
float weight;
int sourceWidth;
int sourceHeight;
unsigned char* sourceData;
int targetWidth;
int targetHeight;
unsigned char* targetData;
int numChannels;
};
std::vector<Guide> guides;
float uniformityWeight = 3500;
int patchSize = 5;
int numPyramidLevels = -1;
int numSearchVoteIters = 6;
int numPatchMatchIters = 4;
int stopThreshold = 5;
std::string backend;
{
std::vector<std::string> args(argc);
for(int i=0;i<argc;i++) { args[i] = argv[i]; }
bool fail = false;
int argi = 1;
float* precedingStyleOrGuideWeight = 0;
while(argi<argc && !fail)
{
float weight;
std::pair<std::string,std::string> guidePair;
if (tryToParseStringArg(args,&argi,"-style",&styleFileName,&fail))
{
styleWeight = NAN;
precedingStyleOrGuideWeight = &styleWeight;
argi++;
}
else if (tryToParseStringPairArg(args,&argi,"-guide",&guidePair,&fail))
{
Guide guide;
guide.sourceFileName = guidePair.first;
guide.targetFileName = guidePair.second;
guide.weight = NAN;
guides.push_back(guide);
precedingStyleOrGuideWeight = &guides[guides.size()-1].weight;
argi++;
}
else if (tryToParseStringArg(args,&argi,"-output",&outputFileName,&fail))
{
argi++;
}
else if (tryToParseFloatArg(args,&argi,"-weight",&weight,&fail))
{
if (precedingStyleOrGuideWeight!=0) { *precedingStyleOrGuideWeight = weight; }
else { printf("error: at least one -style or -guide option must precede the -weight option!\n"); return 1; }
argi++;
}
else if (tryToParseFloatArg(args,&argi,"-uniformity",&uniformityWeight,&fail)) { argi++; }
else if (tryToParseIntArg(args,&argi,"-patchsize",&patchSize,&fail))
{
if (patchSize<3) { printf("error: patchsize is too small!\n"); return 1; }
if (patchSize%2==0) { printf("error: patchsize must be an odd number!\n"); return 1; }
argi++;
}
else if (tryToParseIntArg(args,&argi,"-pyramidlevels",&numPyramidLevels,&fail))
{
if (numPyramidLevels<1) { printf("error: bad argument for -pyramidlevels!\n"); return 1; }
argi++;
}
else if (tryToParseIntArg(args,&argi,"-searchvoteiters",&numSearchVoteIters,&fail))
{
if (numSearchVoteIters<0) { printf("error: bad argument for -searchvoteiters!\n"); return 1; }
argi++;
}
else if (tryToParseIntArg(args,&argi,"-patchmatchiters",&numPatchMatchIters,&fail))
{
if (numPatchMatchIters<0) { printf("error: bad argument for -patchmatchiters!\n"); return 1; }
argi++;
}
else if (tryToParseIntArg(args,&argi,"-stopthreshold",&stopThreshold,&fail))
{
if (stopThreshold<0) { printf("error: bad argument for -stopthreshold!\n"); return 1; }
argi++;
}
else
{
printf("error: unrecognized option '%s'\n",args[argi].c_str());
fail = true;
}
}
if (fail) { return 1; }
}
const int numGuides = guides.size();
int sourceWidth = 0;
int sourceHeight = 0;
unsigned char* sourceStyleData = tryLoad(styleFileName,&sourceWidth,&sourceHeight);
const int numStyleChannelsTotal = evalNumChannels(sourceStyleData,sourceWidth*sourceHeight);
std::vector<unsigned char> sourceStyle(sourceWidth*sourceHeight*numStyleChannelsTotal);
for(int xy=0;xy<sourceWidth*sourceHeight;xy++)
{
if (numStyleChannelsTotal>0) { sourceStyle[xy*numStyleChannelsTotal+0] = sourceStyleData[xy*4+0]; }
if (numStyleChannelsTotal==2) { sourceStyle[xy*numStyleChannelsTotal+1] = sourceStyleData[xy*4+3]; }
else if (numStyleChannelsTotal>1) { sourceStyle[xy*numStyleChannelsTotal+1] = sourceStyleData[xy*4+1]; }
if (numStyleChannelsTotal>2) { sourceStyle[xy*numStyleChannelsTotal+2] = sourceStyleData[xy*4+2]; }
if (numStyleChannelsTotal>3) { sourceStyle[xy*numStyleChannelsTotal+3] = sourceStyleData[xy*4+3]; }
}
int targetWidth = 0;
int targetHeight = 0;
int numGuideChannelsTotal = 0;
for(int i=0;i<numGuides;i++)
{
Guide& guide = guides[i];
guide.sourceData = tryLoad(guide.sourceFileName,&guide.sourceWidth,&guide.sourceHeight);
guide.targetData = tryLoad(guide.targetFileName,&guide.targetWidth,&guide.targetHeight);
if (guide.sourceWidth!=sourceWidth || guide.sourceHeight!=sourceHeight) { printf("error: source guide '%s' doesn't match the resolution of '%s'\n",guide.sourceFileName.c_str(),styleFileName.c_str()); return 1; }
if (i>0 && (guide.targetWidth!=targetWidth || guide.targetHeight!=targetHeight)) { printf("error: target guide '%s' doesn't match the resolution of '%s'\n",guide.targetFileName.c_str(),guides[0].targetFileName.c_str()); return 1; }
else if (i==0) { targetWidth = guide.targetWidth; targetHeight = guide.targetHeight; }
guide.numChannels = std::max(evalNumChannels(guide.sourceData,sourceWidth*sourceHeight),
evalNumChannels(guide.targetData,targetWidth*targetHeight));
numGuideChannelsTotal += guide.numChannels;
}
if (numStyleChannelsTotal>EBSYNTH_MAX_STYLE_CHANNELS) { printf("error: too many style channels (%d), maximum number is %d\n",numStyleChannelsTotal,EBSYNTH_MAX_STYLE_CHANNELS); return 1; }
if (numGuideChannelsTotal>EBSYNTH_MAX_GUIDE_CHANNELS) { printf("error: too many guide channels (%d), maximum number is %d\n",numGuideChannelsTotal,EBSYNTH_MAX_GUIDE_CHANNELS); return 1; }
std::vector<unsigned char> sourceGuides(sourceWidth*sourceHeight*numGuideChannelsTotal);
for(int xy=0;xy<sourceWidth*sourceHeight;xy++)
{
int c = 0;
for(int i=0;i<numGuides;i++)
{
const int numChannels = guides[i].numChannels;
if (numChannels>0) { sourceGuides[xy*numGuideChannelsTotal+c+0] = guides[i].sourceData[xy*4+0]; }
if (numChannels==2) { sourceGuides[xy*numGuideChannelsTotal+c+1] = guides[i].sourceData[xy*4+3]; }
else if (numChannels>1) { sourceGuides[xy*numGuideChannelsTotal+c+1] = guides[i].sourceData[xy*4+1]; }
if (numChannels>2) { sourceGuides[xy*numGuideChannelsTotal+c+2] = guides[i].sourceData[xy*4+2]; }
if (numChannels>3) { sourceGuides[xy*numGuideChannelsTotal+c+3] = guides[i].sourceData[xy*4+3]; }
c += numChannels;
}
}
std::vector<unsigned char> targetGuides(targetWidth*targetHeight*numGuideChannelsTotal);
for(int xy=0;xy<targetWidth*targetHeight;xy++)
{
int c = 0;
for(int i=0;i<numGuides;i++)
{
const int numChannels = guides[i].numChannels;
if (numChannels>0) { targetGuides[xy*numGuideChannelsTotal+c+0] = guides[i].targetData[xy*4+0]; }
if (numChannels==2) { targetGuides[xy*numGuideChannelsTotal+c+1] = guides[i].targetData[xy*4+3]; }
else if (numChannels>1) { targetGuides[xy*numGuideChannelsTotal+c+1] = guides[i].targetData[xy*4+1]; }
if (numChannels>2) { targetGuides[xy*numGuideChannelsTotal+c+2] = guides[i].targetData[xy*4+2]; }
if (numChannels>3) { targetGuides[xy*numGuideChannelsTotal+c+3] = guides[i].targetData[xy*4+3]; }
c += numChannels;
}
}
std::vector<float> styleWeights(numStyleChannelsTotal);
if (isnan(styleWeight)) { styleWeight = 1.0f; }
for(int i=0;i<numStyleChannelsTotal;i++) { styleWeights[i] = styleWeight / float(numStyleChannelsTotal); }
for(int i=0;i<numGuides;i++) { if (isnan(guides[i].weight)) { guides[i].weight = 1.0f/float(numGuides); } }
std::vector<float> guideWeights(numGuideChannelsTotal);
{
int c = 0;
for(int i=0;i<numGuides;i++)
{
const int numChannels = guides[i].numChannels;
for(int j=0;j<numChannels;j++)
{
guideWeights[c+j] = guides[i].weight / float(numChannels);
}
c += numChannels;
}
}
int maxPyramidLevels = 0;
for(int level=32;level>=0;level--)
{
if (min(pyramidLevelSize(std::min(V2i(sourceWidth,sourceHeight),V2i(targetWidth,targetHeight)),level)) >= (2*patchSize+1))
{
maxPyramidLevels = level+1;
break;
}
}
if (numPyramidLevels==-1) { numPyramidLevels = maxPyramidLevels; }
numPyramidLevels = std::min(numPyramidLevels,maxPyramidLevels);
std::vector<int> numSearchVoteItersPerLevel(numPyramidLevels);
std::vector<int> numPatchMatchItersPerLevel(numPyramidLevels);
std::vector<int> stopThresholdPerLevel(numPyramidLevels);
for(int i=0;i<numPyramidLevels;i++)
{
numSearchVoteItersPerLevel[i] = numSearchVoteIters;
numPatchMatchItersPerLevel[i] = numPatchMatchIters;
stopThresholdPerLevel[i] = stopThreshold;
}
std::vector<unsigned char> output(targetWidth*targetHeight*numStyleChannelsTotal);
printf("uniformity: %.0f\n",uniformityWeight);
printf("patchsize: %d\n",patchSize);
printf("pyramidlevels: %d\n",numPyramidLevels);
printf("searchvoteiters: %d\n",numSearchVoteIters);
printf("patchmatchiters: %d\n",numPatchMatchIters);
printf("stopthreshold: %d\n",stopThreshold);
if (!ebsynthBackendAvailable(EBSYNTH_BACKEND_CUDA)) { printf("error: the CUDA backend is not available!\n"); return 1; }
ebsynthRun(EBSYNTH_BACKEND_CUDA,
numStyleChannelsTotal,
numGuideChannelsTotal,
sourceWidth,
sourceHeight,
sourceStyle.data(),
sourceGuides.data(),
targetWidth,
targetHeight,
targetGuides.data(),
NULL,
styleWeights.data(),
guideWeights.data(),
uniformityWeight,
patchSize,
EBSYNTH_VOTEMODE_PLAIN,
numPyramidLevels,
numSearchVoteItersPerLevel.data(),
numPatchMatchItersPerLevel.data(),
stopThresholdPerLevel.data(),
output.data());
stbi_write_png(outputFileName.c_str(),targetWidth,targetHeight,numStyleChannelsTotal,output.data(),numStyleChannelsTotal*targetWidth);
printf("result was written to %s\n",outputFileName.c_str());
stbi_image_free(sourceStyleData);
for(int i=0;i<numGuides;i++)
{
stbi_image_free(guides[i].sourceData);
stbi_image_free(guides[i].targetData);
}
return 0; return 0;
} }

32
src/ebsynth_cuda.h Normal file
View File

@@ -0,0 +1,32 @@
// This software is in the public domain. Where that dedication is not
// recognized, you are granted a perpetual, irrevocable license to copy
// and modify this file as you see fit.
#ifndef EBSYNTH_CUDA_H_
#define EBSYNTH_CUDA_H_
void ebsynthRunCuda(int numStyleChannels,
int numGuideChannels,
int sourceWidth,
int sourceHeight,
void* sourceStyleData,
void* sourceGuideData,
int targetWidth,
int targetHeight,
void* targetGuideData,
void* targetModulationData,
float* styleWeights,
float* guideWeights,
float uniformityWeight,
int patchSize,
int voteMode,
int numPyramidLevels,
int* numSearchVoteItersPerLevel,
int* numPatchMatchItersPerLevel,
int* stopThresholdPerLevel,
void* outputNnfData,
void* outputImageData);
int ebsynthBackendAvailableCuda();
#endif

View File

@@ -1,5 +1,5 @@
#ifndef CUDACHECK_H_ #ifndef EBSYNTH_CUDA_CHECK_H_
#define CUDACHECK_H_ #define EBSYNTH_CUDA_CHECK_H_
template<typename T> template<typename T>
bool checkCudaError_(T result,char const* const func,const char* const file,int const line) bool checkCudaError_(T result,char const* const func,const char* const file,int const line)

View File

@@ -2,11 +2,11 @@
// recognized, you are granted a perpetual, irrevocable license to copy // recognized, you are granted a perpetual, irrevocable license to copy
// and modify this file as you see fit. // and modify this file as you see fit.
#ifndef MEMARRAY2_H_ #ifndef EBSYNTH_CUDA_MEMARRAY2_H_
#define MEMARRAY2_H_ #define EBSYNTH_CUDA_MEMARRAY2_H_
#include "jzq.h" #include "jzq.h"
//#include "cudacheck.h" #include "ebsynth_cuda_check.h"
template<typename T> template<typename T>
struct MemArray2 struct MemArray2

View File

@@ -2,11 +2,11 @@
// recognized, you are granted a perpetual, irrevocable license to copy // recognized, you are granted a perpetual, irrevocable license to copy
// and modify this file as you see fit. // and modify this file as you see fit.
#ifndef TEXARRAY2_H_ #ifndef EBSYNTH_CUDA_TEXARRAY2_H_
#define TEXARRAY2_H_ #define EBSYNTH_CUDA_TEXARRAY2_H_
#include "jzq.h" #include "jzq.h"
#include "cudacheck.h" #include "ebsynth_cuda_check.h"
#include <cuda_runtime.h> #include <cuda_runtime.h>

33
src/ebsynth_nocuda.cpp Normal file
View File

@@ -0,0 +1,33 @@
// This software is in the public domain. Where that dedication is not
// recognized, you are granted a perpetual, irrevocable license to copy
// and modify this file as you see fit.
void ebsynthRunCuda(int numStyleChannels,
int numGuideChannels,
int sourceWidth,
int sourceHeight,
void* sourceStyleData,
void* sourceGuideData,
int targetWidth,
int targetHeight,
void* targetGuideData,
void* targetModulationData,
float* styleWeights,
float* guideWeights,
float uniformityWeight,
int patchSize,
int voteMode,
int numPyramidLevels,
int* numSearchVoteItersPerLevel,
int* numPatchMatchItersPerLevel,
int* stopThresholdPerLevel,
void* outputNnfData,
void* outputImageData)
{
}
int ebsynthBackendAvailableCuda()
{
return 0;
}

287
src/jzq.h
View File

@@ -13,10 +13,16 @@
#include <string> #include <string>
#include <algorithm> #include <algorithm>
template<typename T> struct zero { static __host__ __device__ T value(); }; #ifdef __CUDACC__
#define JZQ_DECORATOR __host__ __device__
#else
#define JZQ_DECORATOR
#endif
template<typename T> inline T clamp(const T& x,const T& xmin,const T& xmax); template<typename T> struct zero { static JZQ_DECORATOR T value(); };
template<typename T> inline T lerp(const T& a,const T& b,const T& t);
template<typename T> JZQ_DECORATOR inline T clamp(const T& x,const T& xmin,const T& xmax);
template<typename T> JZQ_DECORATOR inline T lerp(const T& a,const T& b,float t);
inline std::string spf(const std::string fmt,...); inline std::string spf(const std::string fmt,...);
@@ -25,53 +31,53 @@ struct Vec
{ {
T v[N]; T v[N];
__host__ __device__ Vec<N,T>(); JZQ_DECORATOR Vec<N,T>();
template<typename T2> __host__ __device__ explicit Vec<N,T>(const Vec<N,T2>& u); template<typename T2> JZQ_DECORATOR explicit Vec<N,T>(const Vec<N,T2>& u);
explicit __host__ __device__ Vec<N,T>(T v0); explicit JZQ_DECORATOR Vec<N,T>(T v0);
__host__ __device__ Vec<N,T>(T v0,T v1); JZQ_DECORATOR Vec<N,T>(T v0,T v1);
__host__ __device__ Vec<N,T>(T v0,T v1,T v2); JZQ_DECORATOR Vec<N,T>(T v0,T v1,T v2);
__host__ __device__ Vec<N,T>(T v0,T v1,T v2,T v3); JZQ_DECORATOR Vec<N,T>(T v0,T v1,T v2,T v3);
__host__ __device__ Vec<N,T>(T v0,T v1,T v2,T v3,T v4); JZQ_DECORATOR Vec<N,T>(T v0,T v1,T v2,T v3,T v4);
__host__ __device__ Vec<N,T>(T v0,T v1,T v2,T v3,T v4,T v5); JZQ_DECORATOR Vec<N,T>(T v0,T v1,T v2,T v3,T v4,T v5);
__host__ __device__ T& operator()(int i); JZQ_DECORATOR T& operator()(int i);
__host__ __device__ const T& operator()(int i) const; JZQ_DECORATOR const T& operator()(int i) const;
__host__ __device__ T& operator[](int i); JZQ_DECORATOR T& operator[](int i);
__host__ __device__ const T& operator[](int i) const; JZQ_DECORATOR const T& operator[](int i) const;
__host__ __device__ Vec<N,T> operator*=(const Vec<N,T>& u); JZQ_DECORATOR Vec<N,T> operator*=(const Vec<N,T>& u);
__host__ __device__ Vec<N,T> operator+=(const Vec<N,T>& u); JZQ_DECORATOR Vec<N,T> operator+=(const Vec<N,T>& u);
__host__ __device__ Vec<N,T> operator*=(T s); JZQ_DECORATOR Vec<N,T> operator*=(T s);
__host__ __device__ Vec<N,T> operator+=(T s); JZQ_DECORATOR Vec<N,T> operator+=(T s);
}; };
template<int N,typename T> Vec<N,T> __host__ __device__ operator-(const Vec<N,T>& u); template<int N,typename T> Vec<N,T> JZQ_DECORATOR operator-(const Vec<N,T>& u);
template<int N,typename T> Vec<N,T> __host__ __device__ operator+(const Vec<N,T>& u,const Vec<N,T>& v); template<int N,typename T> Vec<N,T> JZQ_DECORATOR operator+(const Vec<N,T>& u,const Vec<N,T>& v);
template<int N,typename T> Vec<N,T> __host__ __device__ operator-(const Vec<N,T>& u,const Vec<N,T>& v); template<int N,typename T> Vec<N,T> JZQ_DECORATOR operator-(const Vec<N,T>& u,const Vec<N,T>& v);
template<int N,typename T> Vec<N,T> __host__ __device__ operator-(const Vec<N,T>& u,const T v); template<int N,typename T> Vec<N,T> JZQ_DECORATOR operator-(const Vec<N,T>& u,const T v);
template<int N,typename T> Vec<N,T> __host__ __device__ operator*(const Vec<N,T>& u,const Vec<N,T>& v); template<int N,typename T> Vec<N,T> JZQ_DECORATOR operator*(const Vec<N,T>& u,const Vec<N,T>& v);
template<int N,typename T> Vec<N,T> __host__ __device__ operator/(const Vec<N,T>& u,const Vec<N,T>& v); template<int N,typename T> Vec<N,T> JZQ_DECORATOR operator/(const Vec<N,T>& u,const Vec<N,T>& v);
template<int N,typename T> Vec<N,T> __host__ __device__ operator*(const T s,const Vec<N,T>& u); template<int N,typename T> Vec<N,T> JZQ_DECORATOR operator*(const T s,const Vec<N,T>& u);
template<int N,typename T> Vec<N,T> __host__ __device__ operator*(const Vec<N,T>& u,const T s); template<int N,typename T> Vec<N,T> JZQ_DECORATOR operator*(const Vec<N,T>& u,const T s);
template<int N,typename T> Vec<N,T> __host__ __device__ operator/(const Vec<N,T>& u,const T s); template<int N,typename T> Vec<N,T> JZQ_DECORATOR operator/(const Vec<N,T>& u,const T s);
template<int N,typename T> Vec<N,bool> __host__ __device__ operator<(const Vec<N,T>& u,const Vec<N,T>& v); template<int N,typename T> Vec<N,bool> JZQ_DECORATOR operator<(const Vec<N,T>& u,const Vec<N,T>& v);
template<int N,typename T> Vec<N,bool> __host__ __device__ operator>(const Vec<N,T>& u,const Vec<N,T>& v); template<int N,typename T> Vec<N,bool> JZQ_DECORATOR operator>(const Vec<N,T>& u,const Vec<N,T>& v);
template<int N,typename T> Vec<N,bool> __host__ __device__ operator<=(const Vec<N,T>& u,const Vec<N,T>& v); template<int N,typename T> Vec<N,bool> JZQ_DECORATOR operator<=(const Vec<N,T>& u,const Vec<N,T>& v);
template<int N,typename T> Vec<N,bool> __host__ __device__ operator>=(const Vec<N,T>& u,const Vec<N,T>& v); template<int N,typename T> Vec<N,bool> JZQ_DECORATOR operator>=(const Vec<N,T>& u,const Vec<N,T>& v);
template<int N,typename T> Vec<N,bool> __host__ __device__ operator==(const Vec<N,T>& u,const Vec<N,T>& v); template<int N,typename T> Vec<N,bool> JZQ_DECORATOR operator==(const Vec<N,T>& u,const Vec<N,T>& v);
template<int N,typename T> Vec<N,bool> __host__ __device__ operator!=(const Vec<N,T>& u,const Vec<N,T>& v); template<int N,typename T> Vec<N,bool> JZQ_DECORATOR operator!=(const Vec<N,T>& u,const Vec<N,T>& v);
template<int N,typename T> inline T dot(const Vec<N,T>& u,const Vec<N,T>& v); template<int N,typename T> JZQ_DECORATOR inline T dot(const Vec<N,T>& u,const Vec<N,T>& v);
template<typename T> inline T cross(const Vec<2,T> &a,const Vec<2,T> &b); template<typename T> JZQ_DECORATOR inline T cross(const Vec<2,T> &a,const Vec<2,T> &b);
template<typename T> inline Vec<3,T> cross(const Vec<3,T> &a,const Vec<3,T> &b); template<typename T> JZQ_DECORATOR inline Vec<3,T> cross(const Vec<3,T> &a,const Vec<3,T> &b);
template<int N,typename T> inline T norm(const Vec<N,T>& u); template<int N,typename T> JZQ_DECORATOR inline T norm(const Vec<N,T>& u);
template<int N,typename T> inline Vec<N,T> normalize(const Vec<N,T>& u); template<int N,typename T> JZQ_DECORATOR inline Vec<N,T> normalize(const Vec<N,T>& u);
template<int N,typename T> inline T min(const Vec<N,T>& u); template<int N,typename T> JZQ_DECORATOR inline T min(const Vec<N,T>& u);
template<int N,typename T> inline T max(const Vec<N,T>& u); template<int N,typename T> JZQ_DECORATOR inline T max(const Vec<N,T>& u);
template<int N,typename T> inline T sum(const Vec<N,T>& u); template<int N,typename T> JZQ_DECORATOR inline T sum(const Vec<N,T>& u);
namespace std namespace std
{ {
template<int N,typename T> inline Vec<N,T> min(const Vec<N,T>& u,const Vec<N,T>& v); template<int N,typename T> inline Vec<N,T> min(const Vec<N,T>& u,const Vec<N,T>& v);
@@ -196,6 +202,7 @@ public:
const T* data() const; const T* data() const;
void clear(); void clear();
void swap(Array3<T>& b); void swap(Array3<T>& b);
bool empty() const;
private: private:
Vec<3,int> s; Vec<3,int> s;
@@ -542,19 +549,19 @@ typedef Array3< Vec<4,unsigned short> > A3V4us;
typedef Array3< Vec<4,char> > A3V4c; typedef Array3< Vec<4,char> > A3V4c;
typedef Array3< Vec<4,unsigned char> > A3V4uc; typedef Array3< Vec<4,unsigned char> > A3V4uc;
template<> struct zero<char > { static __host__ __device__ char value() { return 0; } }; template<> struct zero<char > { static JZQ_DECORATOR char value() { return 0; } };
template<> struct zero<unsigned char > { static __host__ __device__ unsigned char value() { return 0; } }; template<> struct zero<unsigned char > { static JZQ_DECORATOR unsigned char value() { return 0; } };
template<> struct zero<short > { static __host__ __device__ short value() { return 0; } }; template<> struct zero<short > { static JZQ_DECORATOR short value() { return 0; } };
template<> struct zero<unsigned short> { static __host__ __device__ unsigned short value() { return 0; } }; template<> struct zero<unsigned short> { static JZQ_DECORATOR unsigned short value() { return 0; } };
template<> struct zero<int > { static __host__ __device__ int value() { return 0; } }; template<> struct zero<int > { static JZQ_DECORATOR int value() { return 0; } };
template<> struct zero<unsigned int > { static __host__ __device__ unsigned int value() { return 0; } }; template<> struct zero<unsigned int > { static JZQ_DECORATOR unsigned int value() { return 0; } };
template<> struct zero<float > { static __host__ __device__ float value() { return 0.0f; } }; template<> struct zero<float > { static JZQ_DECORATOR float value() { return 0.0f; } };
template<> struct zero<double > { static __host__ __device__ double value() { return 0.0; } }; template<> struct zero<double > { static JZQ_DECORATOR double value() { return 0.0; } };
template<int N,typename T> template<int N,typename T>
struct zero<Vec<N,T>> struct zero<Vec<N,T>>
{ {
static __host__ __device__ Vec<N,T> value() static JZQ_DECORATOR Vec<N,T> value()
{ {
Vec<N,T> z; Vec<N,T> z;
for(int i=0;i<N;i++) { z[i] = zero<T>::value(); } for(int i=0;i<N;i++) { z[i] = zero<T>::value(); }
@@ -565,7 +572,7 @@ struct zero<Vec<N,T>>
template<int M,int N,typename T> template<int M,int N,typename T>
struct zero<Mat<M,N,T>> struct zero<Mat<M,N,T>>
{ {
static __host__ __device__ Mat<M,N,T> value() static JZQ_DECORATOR Mat<M,N,T> value()
{ {
Mat<M,N,T> z; Mat<M,N,T> z;
for(int i=0;i<M;i++) for(int i=0;i<M;i++)
@@ -577,16 +584,16 @@ struct zero<Mat<M,N,T>>
} }
}; };
template <typename T> inline template <typename T> JZQ_DECORATOR inline
T clamp(const T& x,const T& xmin,const T& xmax) T clamp(const T& x,const T& xmin,const T& xmax)
{ {
return std::min(std::max(x,xmin),xmax); return std::min(std::max(x,xmin),xmax);
} }
template <typename T> inline template <typename T> JZQ_DECORATOR inline
T lerp(const T& a,const T& b,const T& t) T lerp(const T& a,const T& b,float t)
{ {
return (1.0-t)*a+t*b; return (1.0f-t)*a+t*b;
} }
inline std::string spf(const std::string fmt,...) inline std::string spf(const std::string fmt,...)
@@ -623,13 +630,13 @@ inline std::string spf(const std::string fmt,...)
} }
template<int N,typename T> template<int N,typename T>
__host__ __device__ JZQ_DECORATOR
Vec<N,T>::Vec() Vec<N,T>::Vec()
{ {
} }
template<int N,typename T> template<int N,typename T>
__host__ __device__ JZQ_DECORATOR
Vec<N,T>::Vec(T v0) Vec<N,T>::Vec(T v0)
{ {
assert(N==1); assert(N==1);
@@ -637,7 +644,7 @@ Vec<N,T>::Vec(T v0)
} }
template<int N,typename T> template<int N,typename T>
__host__ __device__ JZQ_DECORATOR
Vec<N,T>::Vec(T v0,T v1) Vec<N,T>::Vec(T v0,T v1)
{ {
assert(N==2); assert(N==2);
@@ -645,7 +652,7 @@ Vec<N,T>::Vec(T v0,T v1)
} }
template<int N,typename T> template<int N,typename T>
__host__ __device__ JZQ_DECORATOR
Vec<N,T>::Vec(T v0,T v1,T v2) Vec<N,T>::Vec(T v0,T v1,T v2)
{ {
assert(N==3); assert(N==3);
@@ -653,7 +660,7 @@ Vec<N,T>::Vec(T v0,T v1,T v2)
} }
template<int N,typename T> template<int N,typename T>
__host__ __device__ JZQ_DECORATOR
Vec<N,T>::Vec(T v0,T v1,T v2,T v3) Vec<N,T>::Vec(T v0,T v1,T v2,T v3)
{ {
assert(N==4); assert(N==4);
@@ -661,7 +668,7 @@ Vec<N,T>::Vec(T v0,T v1,T v2,T v3)
} }
template<int N,typename T> template<int N,typename T>
__host__ __device__ JZQ_DECORATOR
Vec<N,T>::Vec(T v0,T v1,T v2,T v3,T v4) Vec<N,T>::Vec(T v0,T v1,T v2,T v3,T v4)
{ {
assert(N==5); assert(N==5);
@@ -669,7 +676,7 @@ Vec<N,T>::Vec(T v0,T v1,T v2,T v3,T v4)
} }
template<int N,typename T> template<int N,typename T>
__host__ __device__ JZQ_DECORATOR
Vec<N,T>::Vec(T v0,T v1,T v2,T v3,T v4,T v5) Vec<N,T>::Vec(T v0,T v1,T v2,T v3,T v4,T v5)
{ {
assert(N==6); assert(N==6);
@@ -677,7 +684,7 @@ Vec<N,T>::Vec(T v0,T v1,T v2,T v3,T v4,T v5)
} }
template<int N,typename T> template<typename T2> template<int N,typename T> template<typename T2>
__host__ __device__ JZQ_DECORATOR
Vec<N,T>::Vec(const Vec<N,T2>& u) Vec<N,T>::Vec(const Vec<N,T2>& u)
{ {
for(int i=0;i<N;i++) for(int i=0;i<N;i++)
@@ -687,7 +694,7 @@ Vec<N,T>::Vec(const Vec<N,T2>& u)
} }
template<int N,typename T> template<int N,typename T>
__host__ __device__ JZQ_DECORATOR
T& Vec<N,T>::operator()(int i) T& Vec<N,T>::operator()(int i)
{ {
assert(i>=0 && i<N); assert(i>=0 && i<N);
@@ -695,7 +702,7 @@ T& Vec<N,T>::operator()(int i)
} }
template<int N,typename T> template<int N,typename T>
__host__ __device__ JZQ_DECORATOR
const T& Vec<N,T>::operator()(int i) const const T& Vec<N,T>::operator()(int i) const
{ {
assert(i>=0 && i<N); assert(i>=0 && i<N);
@@ -703,7 +710,7 @@ const T& Vec<N,T>::operator()(int i) const
} }
template<int N,typename T> template<int N,typename T>
__host__ __device__ JZQ_DECORATOR
T& Vec<N,T>::operator[](int i) T& Vec<N,T>::operator[](int i)
{ {
assert(i>=0 && i<N); assert(i>=0 && i<N);
@@ -711,7 +718,7 @@ T& Vec<N,T>::operator[](int i)
} }
template<int N,typename T> template<int N,typename T>
__host__ __device__ JZQ_DECORATOR
const T& Vec<N,T>::operator[](int i) const const T& Vec<N,T>::operator[](int i) const
{ {
assert(i>=0 && i<N); assert(i>=0 && i<N);
@@ -719,7 +726,7 @@ const T& Vec<N,T>::operator[](int i) const
} }
template<int N,typename T> template<int N,typename T>
__host__ __device__ JZQ_DECORATOR
Vec<N,T> Vec<N,T>::operator*=(const Vec<N,T>& u) Vec<N,T> Vec<N,T>::operator*=(const Vec<N,T>& u)
{ {
for(int i=0;i<N;i++) v[i]*=u(i); for(int i=0;i<N;i++) v[i]*=u(i);
@@ -727,7 +734,7 @@ Vec<N,T> Vec<N,T>::operator*=(const Vec<N,T>& u)
} }
template<int N,typename T> template<int N,typename T>
__host__ __device__ JZQ_DECORATOR
Vec<N,T> Vec<N,T>::operator+=(const Vec<N,T>& u) Vec<N,T> Vec<N,T>::operator+=(const Vec<N,T>& u)
{ {
for(int i=0;i<N;i++) v[i]+=u(i); for(int i=0;i<N;i++) v[i]+=u(i);
@@ -735,7 +742,7 @@ Vec<N,T> Vec<N,T>::operator+=(const Vec<N,T>& u)
} }
template<int N,typename T> template<int N,typename T>
__host__ __device__ JZQ_DECORATOR
Vec<N,T> Vec<N,T>::operator*=(T s) Vec<N,T> Vec<N,T>::operator*=(T s)
{ {
for(int i=0;i<N;i++) v[i]*=s; for(int i=0;i<N;i++) v[i]*=s;
@@ -743,7 +750,7 @@ Vec<N,T> Vec<N,T>::operator*=(T s)
} }
template<int N,typename T> template<int N,typename T>
__host__ __device__ JZQ_DECORATOR
Vec<N,T> Vec<N,T>::operator+=(T s) Vec<N,T> Vec<N,T>::operator+=(T s)
{ {
for(int i=0;i<N;i++) v[i]+=s; for(int i=0;i<N;i++) v[i]+=s;
@@ -751,7 +758,7 @@ Vec<N,T> Vec<N,T>::operator+=(T s)
} }
template<int N,typename T> template<int N,typename T>
__host__ __device__ JZQ_DECORATOR
Vec<N,T> operator-(const Vec<N,T>& u) Vec<N,T> operator-(const Vec<N,T>& u)
{ {
Vec<N,T> r; Vec<N,T> r;
@@ -760,7 +767,7 @@ Vec<N,T> operator-(const Vec<N,T>& u)
} }
template<int N,typename T> template<int N,typename T>
__host__ __device__ JZQ_DECORATOR
Vec<N,T> operator+(const Vec<N,T>& u,const Vec<N,T>& v) Vec<N,T> operator+(const Vec<N,T>& u,const Vec<N,T>& v)
{ {
Vec<N,T> r; Vec<N,T> r;
@@ -769,7 +776,7 @@ Vec<N,T> operator+(const Vec<N,T>& u,const Vec<N,T>& v)
} }
template<int N,typename T> template<int N,typename T>
__host__ __device__ JZQ_DECORATOR
Vec<N,T> operator-(const Vec<N,T>& u,const Vec<N,T>& v) Vec<N,T> operator-(const Vec<N,T>& u,const Vec<N,T>& v)
{ {
Vec<N,T> r; Vec<N,T> r;
@@ -778,7 +785,7 @@ Vec<N,T> operator-(const Vec<N,T>& u,const Vec<N,T>& v)
} }
template<int N,typename T> template<int N,typename T>
__host__ __device__ JZQ_DECORATOR
Vec<N,T> operator-(const Vec<N,T>& u,const T v) Vec<N,T> operator-(const Vec<N,T>& u,const T v)
{ {
Vec<N,T> r; Vec<N,T> r;
@@ -787,7 +794,7 @@ Vec<N,T> operator-(const Vec<N,T>& u,const T v)
} }
template<int N,typename T> template<int N,typename T>
__host__ __device__ JZQ_DECORATOR
Vec<N,T> operator*(const Vec<N,T>& u,const Vec<N,T>& v) Vec<N,T> operator*(const Vec<N,T>& u,const Vec<N,T>& v)
{ {
Vec<N,T> r; Vec<N,T> r;
@@ -796,7 +803,7 @@ Vec<N,T> operator*(const Vec<N,T>& u,const Vec<N,T>& v)
} }
template<int N,typename T> template<int N,typename T>
__host__ __device__ JZQ_DECORATOR
Vec<N,T> operator/(const Vec<N,T>& u,const Vec<N,T>& v) Vec<N,T> operator/(const Vec<N,T>& u,const Vec<N,T>& v)
{ {
Vec<N,T> r; Vec<N,T> r;
@@ -805,7 +812,7 @@ Vec<N,T> operator/(const Vec<N,T>& u,const Vec<N,T>& v)
} }
template<int N,typename T> template<int N,typename T>
__host__ __device__ JZQ_DECORATOR
Vec<N,T> operator*(const T s,const Vec<N,T>& u) Vec<N,T> operator*(const T s,const Vec<N,T>& u)
{ {
Vec<N,T> r; Vec<N,T> r;
@@ -814,7 +821,7 @@ Vec<N,T> operator*(const T s,const Vec<N,T>& u)
} }
template<int N,typename T> template<int N,typename T>
__host__ __device__ JZQ_DECORATOR
Vec<N,T> operator*(const Vec<N,T>& u,const T s) Vec<N,T> operator*(const Vec<N,T>& u,const T s)
{ {
Vec<N,T> r; Vec<N,T> r;
@@ -823,7 +830,7 @@ Vec<N,T> operator*(const Vec<N,T>& u,const T s)
} }
template<int N,typename T> template<int N,typename T>
__host__ __device__ JZQ_DECORATOR
Vec<N,T> operator/(const Vec<N,T>& u,const T s) Vec<N,T> operator/(const Vec<N,T>& u,const T s)
{ {
Vec<N,T> r; Vec<N,T> r;
@@ -832,7 +839,7 @@ Vec<N,T> operator/(const Vec<N,T>& u,const T s)
} }
template<int N,typename T> template<int N,typename T>
__host__ __device__ JZQ_DECORATOR
Vec<N,bool> operator<(const Vec<N,T>& u,const Vec<N,T>& v) Vec<N,bool> operator<(const Vec<N,T>& u,const Vec<N,T>& v)
{ {
Vec<N,bool> r; Vec<N,bool> r;
@@ -841,7 +848,7 @@ Vec<N,bool> operator<(const Vec<N,T>& u,const Vec<N,T>& v)
} }
template<int N,typename T> template<int N,typename T>
__host__ __device__ JZQ_DECORATOR
Vec<N,bool> operator>(const Vec<N,T>& u,const Vec<N,T>& v) Vec<N,bool> operator>(const Vec<N,T>& u,const Vec<N,T>& v)
{ {
Vec<N,bool> r; Vec<N,bool> r;
@@ -850,7 +857,7 @@ Vec<N,bool> operator>(const Vec<N,T>& u,const Vec<N,T>& v)
} }
template<int N,typename T> template<int N,typename T>
__host__ __device__ JZQ_DECORATOR
Vec<N,bool> operator<=(const Vec<N,T>& u,const Vec<N,T>& v) Vec<N,bool> operator<=(const Vec<N,T>& u,const Vec<N,T>& v)
{ {
Vec<N,bool> r; Vec<N,bool> r;
@@ -859,7 +866,7 @@ Vec<N,bool> operator<=(const Vec<N,T>& u,const Vec<N,T>& v)
} }
template<int N,typename T> template<int N,typename T>
__host__ __device__ JZQ_DECORATOR
Vec<N,bool> operator>=(const Vec<N,T>& u,const Vec<N,T>& v) Vec<N,bool> operator>=(const Vec<N,T>& u,const Vec<N,T>& v)
{ {
Vec<N,bool> r; Vec<N,bool> r;
@@ -868,7 +875,7 @@ Vec<N,bool> operator>=(const Vec<N,T>& u,const Vec<N,T>& v)
} }
template<int N,typename T> template<int N,typename T>
__host__ __device__ JZQ_DECORATOR
Vec<N,bool> operator==(const Vec<N,T>& u,const Vec<N,T>& v) Vec<N,bool> operator==(const Vec<N,T>& u,const Vec<N,T>& v)
{ {
Vec<N,bool> r; Vec<N,bool> r;
@@ -877,7 +884,7 @@ Vec<N,bool> operator==(const Vec<N,T>& u,const Vec<N,T>& v)
} }
template<int N,typename T> template<int N,typename T>
__host__ __device__ JZQ_DECORATOR
Vec<N,bool> operator!=(const Vec<N,T>& u,const Vec<N,T>& v) Vec<N,bool> operator!=(const Vec<N,T>& u,const Vec<N,T>& v)
{ {
Vec<N,bool> r; Vec<N,bool> r;
@@ -886,7 +893,7 @@ Vec<N,bool> operator!=(const Vec<N,T>& u,const Vec<N,T>& v)
} }
template<int N,typename T> template<int N,typename T>
inline T dot(const Vec<N,T>& u,const Vec<N,T>& v) JZQ_DECORATOR inline T dot(const Vec<N,T>& u,const Vec<N,T>& v)
{ {
assert(N>0); assert(N>0);
T sumprod = u(0)*v(0); T sumprod = u(0)*v(0);
@@ -895,13 +902,13 @@ inline T dot(const Vec<N,T>& u,const Vec<N,T>& v)
} }
template<typename T> template<typename T>
inline T cross(const Vec<2,T> &a,const Vec<2,T> &b) JZQ_DECORATOR inline T cross(const Vec<2,T> &a,const Vec<2,T> &b)
{ {
return a[0]*b[1]-a[1]*b[0]; return a[0]*b[1]-a[1]*b[0];
} }
template<typename T> template<typename T>
inline Vec<3,T> cross(const Vec<3,T> &a,const Vec<3,T> &b) JZQ_DECORATOR inline Vec<3,T> cross(const Vec<3,T> &a,const Vec<3,T> &b)
{ {
return Vec<3,T>(a[1]*b[2]-a[2]*b[1], return Vec<3,T>(a[1]*b[2]-a[2]*b[1],
a[2]*b[0]-a[0]*b[2], a[2]*b[0]-a[0]*b[2],
@@ -909,19 +916,19 @@ inline Vec<3,T> cross(const Vec<3,T> &a,const Vec<3,T> &b)
} }
template<int N,typename T> template<int N,typename T>
inline T norm(const Vec<N,T>& u) JZQ_DECORATOR inline T norm(const Vec<N,T>& u)
{ {
return std::sqrt(dot(u,u)); return std::sqrt(dot(u,u));
} }
template<int N,typename T> template<int N,typename T>
inline Vec<N,T> normalize(const Vec<N,T>& u) JZQ_DECORATOR inline Vec<N,T> normalize(const Vec<N,T>& u)
{ {
return u/norm(u); return u/norm(u);
} }
template<int N> template<int N>
inline bool any(const Vec<N,bool>& u) JZQ_DECORATOR inline bool any(const Vec<N,bool>& u)
{ {
for(int i=0;i<N;i++) for(int i=0;i<N;i++)
{ {
@@ -931,7 +938,7 @@ inline bool any(const Vec<N,bool>& u)
} }
template<int N> template<int N>
inline bool all(const Vec<N,bool>& u) JZQ_DECORATOR inline bool all(const Vec<N,bool>& u)
{ {
for(int i=0;i<N;i++) for(int i=0;i<N;i++)
{ {
@@ -941,7 +948,7 @@ inline bool all(const Vec<N,bool>& u)
} }
template<int N,typename T> template<int N,typename T>
inline T min(const Vec<N,T>& u) JZQ_DECORATOR inline T min(const Vec<N,T>& u)
{ {
assert(N>0); assert(N>0);
@@ -959,7 +966,7 @@ inline T min(const Vec<N,T>& u)
} }
template<int N,typename T> template<int N,typename T>
inline T max(const Vec<N,T>& u) JZQ_DECORATOR inline T max(const Vec<N,T>& u)
{ {
assert(N>0); assert(N>0);
@@ -977,7 +984,7 @@ inline T max(const Vec<N,T>& u)
} }
template<int N,typename T> template<int N,typename T>
inline T sum(const Vec<N,T>& u) JZQ_DECORATOR inline T sum(const Vec<N,T>& u)
{ {
assert(N>0); assert(N>0);
@@ -1868,6 +1875,12 @@ void Array3<T>::swap(Array3<T>& b)
b.d = tmp_d; b.d = tmp_d;
} }
template<typename T>
bool Array3<T>::empty() const
{
return (numel()==0);
}
template<typename T> template<typename T>
Vec3i size(const Array3<T>& a) Vec3i size(const Array3<T>& a)
{ {
@@ -1898,4 +1911,84 @@ void swap(Array3<T>& a,Array3<T>& b)
a.swap(b); a.swap(b);
} }
template<typename T>
void fill(Array3<T>* a,const T& value)
{
assert(a!=0);
assert(a->numel()>0);
const int n = a->numel();
T* d = a->data();
for(int i=0;i<n;i++) d[i] = value;
}
template<typename T>
Array3<T> a3read(const std::string& fileName)
{
Array3<T> A;
if(!a3read(&A,fileName)) { return Array3<T>(); }
return A;
}
template<typename T>
bool a3read(Array3<T>* out_A,const std::string& fileName)
{
FILE* f = fopen(fileName.c_str(),"rb");
if(!f) { return false; }
int w,h,d,s;
if(fread(&w,sizeof(w),1,f)!=1 ||
fread(&h,sizeof(h),1,f)!=1 ||
fread(&d,sizeof(d),1,f)!=1 ||
fread(&s,sizeof(s),1,f)!=1 ||
((w*h*d)<1) || s!=sizeof(T))
{
fclose(f);
return false;
}
Array3<T> A(w,h,d);
if(fread(A.data(),sizeof(T)*w*h*d,1,f)!=1)
{
fclose(f);
return false;
}
if(out_A!=0) { *out_A = A; }
fclose(f);
return true;
}
template<typename T>
bool a3write(const Array3<T>& A,const std::string& fileName)
{
if(A.numel()==0) { return false; }
FILE* f = fopen(fileName.c_str(),"wb");
if(!f) { return false; }
const int w = A.width();
const int h = A.height();
const int d = A.depth();
const int s = sizeof(T);
if(fwrite(&w,sizeof(w),1,f)!=1 ||
fwrite(&h,sizeof(h),1,f)!=1 ||
fwrite(&d,sizeof(d),1,f)!=1 ||
fwrite(&s,sizeof(s),1,f)!=1 ||
fwrite(A.data(),sizeof(T)*w*h*d,1,f)!=1)
{
fclose(f);
return false;
}
fclose(f);
return true;
}
#endif #endif

View File

@@ -1,410 +0,0 @@
// This software is in the public domain. Where that dedication is not
// recognized, you are granted a perpetual, irrevocable license to copy
// and modify this file as you see fit.
#ifndef PATCHMATCH_GPU_H_
#define PATCHMATCH_GPU_H_
#include <stdint.h>
#include <cfloat>
#include "texarray2.h"
#include "memarray2.h"
struct pcgState
{
uint64_t state;
uint64_t increment;
};
__device__ void pcgAdvance(pcgState* rng)
{
rng->state = rng->state * 6364136223846793005ULL + rng->increment;
}
__device__ uint32_t pcgOutput(uint64_t state)
{
return (uint32_t)(((state >> 22u) ^ state) >> ((state >> 61u) + 22u));
}
__device__ uint32_t pcgRand(pcgState* rng)
{
uint64_t oldstate = rng->state;
pcgAdvance(rng);
return pcgOutput(oldstate);
}
__device__ void pcgInit(pcgState* rng,uint64_t seed,uint64_t stream)
{
rng->state = 0U;
rng->increment = (stream << 1u) | 1u;
pcgAdvance(rng);
rng->state += seed;
pcgAdvance(rng);
}
typedef Vec<1,float> V1f;
typedef Array2<Vec<1,float>> A2V1f;
__global__ void krnlInitRngStates(const int width,
const int height,
pcgState* rngStates)
{
const int x = blockDim.x*blockIdx.x + threadIdx.x;
const int y = blockDim.y*blockIdx.y + threadIdx.y;
if (x<width && y<height)
{
const int idx = x+y*width;
pcgInit(&rngStates[idx],1337,idx);
}
}
pcgState* initGpuRng(const int width,
const int height)
{
pcgState* gpuRngStates;
cudaMalloc(&gpuRngStates,width*height*sizeof(pcgState));
const dim3 threadsPerBlock(16,16);
const dim3 numBlocks((width+threadsPerBlock.x)/threadsPerBlock.x,
(height+threadsPerBlock.y)/threadsPerBlock.y);
krnlInitRngStates<<<numBlocks,threadsPerBlock>>>(width,height,gpuRngStates);
return gpuRngStates;
}
template<int N,typename T,int M>
struct PatchSSD
{
const TexArray2<N,T,M> A;
const TexArray2<N,T,M> B;
const Vec<N,float> weights;
PatchSSD(const TexArray2<N,T,M>& A,
const TexArray2<N,T,M>& B,
const Vec<N,float>& weights)
: A(A),B(B),weights(weights) {}
__device__ float operator()(int patchWidth,
const int ax,
const int ay,
const int bx,
const int by,
const float ebest)
{
const int hpw = patchWidth/2;
float ssd = 0;
for(int py=-hpw;py<=+hpw;py++)
{
for(int px=-hpw;px<=+hpw;px++)
{
const Vec<N,T> pixelA = A(ax + px, ay + py);
const Vec<N,T> pixelB = B(bx + px, by + py);
for(int i=0;i<N;i++)
{
const float diff = float(pixelA[i])-float(pixelB[i]);
ssd += weights[i]*diff*diff;
}
}
if (ssd>ebest) { return ssd; }
}
return ssd;
}
};
template<typename FUNC>
__global__ void krnlEvalErrorPass(const int patchWidth,
FUNC patchError,
const TexArray2<2,int> NNF,
TexArray2<1,float> E)
{
const int x = blockDim.x*blockIdx.x + threadIdx.x;
const int y = blockDim.y*blockIdx.y + threadIdx.y;
if (x<NNF.width && y<NNF.height)
{
const V2i n = NNF(x,y);
E.write(x,y,V1f(patchError(patchWidth,x,y,n[0],n[1],FLT_MAX)));
}
}
void __device__ updateOmega(MemArray2<int>& Omega,const int patchWidth,const int bx,const int by,const int incdec)
{
const int r = patchWidth/2;
for(int oy=-r;oy<=+r;oy++)
for(int ox=-r;ox<=+r;ox++)
{
const int x = bx+ox;
const int y = by+oy;
atomicAdd(&Omega.data[x+y*Omega.width],incdec);
//Omega.data[x+y*Omega.width] += incdec;
}
}
int __device__ patchOmega(const int patchWidth,const int bx,const int by,const MemArray2<int>& Omega)
{
const int r = patchWidth/2;
int sum = 0;
for(int oy=-r;oy<=+r;oy++)
for(int ox=-r;ox<=+r;ox++)
{
const int x = bx+ox;
const int y = by+oy;
sum += Omega.data[x+y*Omega.width]; /// XXX: atomic read instead ??
}
return sum;
}
template<typename FUNC>
__device__ void tryPatch(const V2i& sizeA,
const V2i& sizeB,
MemArray2<int>& Omega,
const int patchWidth,
FUNC patchError,
const float lambda,
const int ax,
const int ay,
const int bx,
const int by,
V2i& nbest,
float& ebest)
{
const float omegaBest = (float(sizeA(0)*sizeA(1)) /
float(sizeB(0)*sizeB(1))) * float(patchWidth*patchWidth);
const float curOcc = (float(patchOmega(patchWidth,nbest(0),nbest(1),Omega))/float(patchWidth*patchWidth))/omegaBest;
const float newOcc = (float(patchOmega(patchWidth, bx, by,Omega))/float(patchWidth*patchWidth))/omegaBest;
const float curErr = ebest;
const float newErr = patchError(patchWidth,ax,ay,bx,by,curErr+lambda*curOcc);
if ((newErr+lambda*newOcc) < (curErr+lambda*curOcc))
{
updateOmega(Omega,patchWidth, bx, by,+1);
updateOmega(Omega,patchWidth,nbest(0),nbest(1),-1);
nbest = V2i(bx,by);
ebest = newErr;
}
}
template<typename FUNC>
__device__ void tryNeighborsOffset(const int x,
const int y,
const int ox,
const int oy,
V2i& nbest,
float& ebest,
const V2i& sizeA,
const V2i& sizeB,
MemArray2<int>& Omega,
const int patchWidth,
FUNC patchError,
const float lambda,
const TexArray2<2,int>& NNF)
{
const int hpw = patchWidth/2;
const V2i on = NNF(x+ox,y+oy);
const int nx = on(0)-ox;
const int ny = on(1)-oy;
if (nx>=hpw && nx<sizeB(0)-hpw &&
ny>=hpw && ny<sizeB(1)-hpw)
{
tryPatch(sizeA,sizeB,Omega,patchWidth,patchError,lambda,x,y,nx,ny,nbest,ebest);
}
}
template<typename FUNC>
__global__ void krnlPropagationPass(const V2i sizeA,
const V2i sizeB,
MemArray2<int> Omega,
const int patchWidth,
FUNC patchError,
const float lambda,
const int r,
const TexArray2<2,int> NNF,
TexArray2<2,int> NNF2,
TexArray2<1,float> E,
TexArray2<1,unsigned char> mask)
{
const int x = blockDim.x*blockIdx.x + threadIdx.x;
const int y = blockDim.y*blockIdx.y + threadIdx.y;
if (x<sizeA(0) && y<sizeA(1))
{
V2i nbest = NNF(x,y);
float ebest = E(x,y)(0);
if (mask(x,y)[0]==255)
{
tryNeighborsOffset(x,y,-r,0,nbest,ebest,sizeA,sizeB,Omega,patchWidth,patchError,lambda,NNF);
tryNeighborsOffset(x,y,+r,0,nbest,ebest,sizeA,sizeB,Omega,patchWidth,patchError,lambda,NNF);
tryNeighborsOffset(x,y,0,-r,nbest,ebest,sizeA,sizeB,Omega,patchWidth,patchError,lambda,NNF);
tryNeighborsOffset(x,y,0,+r,nbest,ebest,sizeA,sizeB,Omega,patchWidth,patchError,lambda,NNF);
}
E.write(x,y,V1f(ebest));
NNF2.write(x,y,nbest);
}
}
template<typename FUNC>
__device__ void tryRandomOffsetInRadius(const int r,
const V2i& sizeA,
const V2i& sizeB,
MemArray2<int>& Omega,
const int patchWidth,
FUNC patchError,
const float lambda,
const int x,
const int y,
const V2i& norg,
V2i& nbest,
float& ebest,
pcgState* rngState)
{
const int hpw = patchWidth/2;
const int xmin = max(norg(0)-r,hpw);
const int xmax = min(norg(0)+r,sizeB(0)-1-hpw);
const int ymin = max(norg(1)-r,hpw);
const int ymax = min(norg(1)+r,sizeB(1)-1-hpw);
const int nx = xmin+(pcgRand(rngState)%(xmax-xmin+1));
const int ny = ymin+(pcgRand(rngState)%(ymax-ymin+1));
tryPatch(sizeA,sizeB,Omega,patchWidth,patchError,lambda,x,y,nx,ny,nbest,ebest);
}
/*
template<typename FUNC>
__global__ void krnlRandomSearchPass(const V2i sizeA,
const V2i sizeB,
MemArray2<int> Omega,
const int patchWidth,
FUNC patchError,
const float lambda,
TexArray2<2,int> NNF,
TexArray2<1,float> E,
TexArray2<1,unsigned char> mask,
pcgState* rngStates)
{
const int x = blockDim.x*blockIdx.x + threadIdx.x;
const int y = blockDim.y*blockIdx.y + threadIdx.y;
if (x<sizeA(0) && y<sizeA(1))
{
if (mask(x,y)[0]==255)
{
V2i nbest = NNF(x,y);
float ebest = E(x,y)(0);
const V2i norg = nbest;
for(int r=1;r<max(sizeB(0),sizeB(1))/2;r=r*2)
{
tryRandomOffsetInRadius(r,sizeA,sizeB,Omega,patchWidth,patchError,lambda,x,y,norg,nbest,ebest,&rngStates[x+y*NNF.width]);
}
E.write(x,y,V1f(ebest));
NNF.write(x,y,nbest);
}
}
}
*/
template<typename FUNC>
__global__ void krnlRandomSearchPass(const V2i sizeA,
const V2i sizeB,
MemArray2<int> Omega,
const int patchWidth,
FUNC patchError,
const float lambda,
const int radius,
TexArray2<2,int> NNF,
TexArray2<1,float> E,
TexArray2<1,unsigned char> mask,
pcgState* rngStates)
{
const int x = blockDim.x*blockIdx.x + threadIdx.x;
const int y = blockDim.y*blockIdx.y + threadIdx.y;
if (x<sizeA(0) && y<sizeA(1))
{
if (mask(x,y)[0]==255)
{
V2i nbest = NNF(x,y);
float ebest = E(x,y)(0);
const V2i norg = nbest;
tryRandomOffsetInRadius(radius,sizeA,sizeB,Omega,patchWidth,patchError,lambda,x,y,norg,nbest,ebest,&rngStates[x+y*NNF.width]);
E.write(x,y,V1f(ebest));
NNF.write(x,y,nbest);
}
}
}
template<typename FUNC>
void patchmatchGPU(const V2i sizeA,
const V2i sizeB,
MemArray2<int>& Omega,
const int patchWidth,
FUNC patchError,
const float lambda,
const int numIters,
const int numThreadsPerBlock,
TexArray2<2,int>& NNF,
TexArray2<2,int>& NNF2,
TexArray2<1,float>& E,
TexArray2<1,unsigned char>& mask,
pcgState* rngStates)
{
const dim3 threadsPerBlock = dim3(numThreadsPerBlock,numThreadsPerBlock);
const dim3 numBlocks = dim3((NNF.width+threadsPerBlock.x)/threadsPerBlock.x,
(NNF.height+threadsPerBlock.y)/threadsPerBlock.y);
krnlEvalErrorPass<<<numBlocks,threadsPerBlock>>>(patchWidth,patchError,NNF,E);
checkCudaError(cudaDeviceSynchronize());
for(int i=0;i<numIters;i++)
{
krnlPropagationPass<<<numBlocks,threadsPerBlock>>>(sizeA,sizeB,Omega,patchWidth,patchError,lambda,4,NNF,NNF2,E,mask); std::swap(NNF,NNF2);
checkCudaError(cudaDeviceSynchronize());
krnlPropagationPass<<<numBlocks,threadsPerBlock>>>(sizeA,sizeB,Omega,patchWidth,patchError,lambda,2,NNF,NNF2,E,mask); std::swap(NNF,NNF2);
checkCudaError(cudaDeviceSynchronize());
krnlPropagationPass<<<numBlocks,threadsPerBlock>>>(sizeA,sizeB,Omega,patchWidth,patchError,lambda,1,NNF,NNF2,E,mask); std::swap(NNF,NNF2);
checkCudaError(cudaDeviceSynchronize());
for(int r=1;r<max(sizeB(0),sizeB(1))/2;r=r*2)
{
krnlRandomSearchPass<<<numBlocks,threadsPerBlock>>>(sizeA,sizeB,Omega,patchWidth,patchError,lambda,r,NNF,E,mask,rngStates);
}
checkCudaError(cudaDeviceSynchronize());
}
krnlEvalErrorPass<<<numBlocks,threadsPerBlock>>>(patchWidth,patchError,NNF,E);
checkCudaError(cudaDeviceSynchronize());
}
#endif