/* image convolution
 * Chris Lu, 2008
 *
 * To compile: gcc -O3 -ffast-math -o convolve convolve.c -lpthread -msse -mfpmath=sse
 *
 * This program is *very* specific about its PNM file format. It must be raw
 * 24-bit RGB (one byte per channel), with a magic number of 'P6', followed by
 * a newline, then the width and height, another newline, then the max color
 * value, another newline, and then the raw image data. Comment lines are not
 * supported. The maximum color value should be 255.
 *
 * This is rough around the edges. Try turning SSE on and off with the NO_SSE
 * define, and try changing the type of color_t with the typedef. The SSE
 * gather path does not saturate results at 255, unfortunately. Also, the
 * printf format specifiers need to be changed depending on the type of
 * color_t.
 */

#include <assert.h>
#include <ctype.h>
#include <limits.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include <sys/types.h>
#include <sys/stat.h>
#include <sys/mman.h>
#include <fcntl.h>
#include <pthread.h>
#include <unistd.h>

/* Uncomment to build the portable scalar code paths instead of the GCC
 * vector-extension ones. */
//#define NO_SSE

/* Type of a single color channel sample. */
typedef float color_t;

#ifndef NO_SSE
/* Four color_t lanes packed into one 16-byte GCC vector. */
typedef color_t v4c __attribute__((vector_size(16)));
#endif /* !NO_SSE */

/* One pixel: channels 0..2 are R, G, B; channel 3 is padding so a pixel
 * fills one vector. In the vector build the same storage is also visible as
 * the vector member v. */
typedef union {
#ifndef NO_SSE
	v4c v;
#endif /* !NO_SSE */
	color_t c[4];
} pixel_t;

/* An in-memory image: h rows of w pixels stored row-major in data, so pixel
 * (row, col) is data[row * w + col]. The tag avoids a leading underscore,
 * which is reserved for the implementation at file scope. */
typedef struct image {
	int w;          /* width in pixels */
	int h;          /* height in pixels */
	pixel_t *data;  /* w * h pixels, heap-allocated */
} image_t;

/* Arguments for one convolve_helper worker thread. The tag avoids a leading
 * underscore, which is reserved for the implementation at file scope. */
typedef struct cparam {
	image_t *image;   /* source image, only read by the worker */
	image_t *kernel;  /* convolution kernel, only read by the worker */
	image_t *res;     /* result image; the worker reads only res->w */
	pixel_t *mybuf;   /* this thread's accumulation buffer, res->w * res->h pixels */
	int start;        /* first source row to process */
	int end;          /* one past the last source row to process */
} cparam_t;

/* Arguments for one gather_helper worker thread. The tag avoids a leading
 * underscore, which is reserved for the implementation at file scope. */
typedef struct gparam {
	pixel_t **data;    /* ndata per-thread accumulation buffers */
	pixel_t *resdata;  /* output pixels written by the worker */
	int ndata;         /* number of buffers in data */
	int start;         /* first flat pixel index to gather */
	int end;           /* one past the last flat pixel index */
	int scale;         /* divisor applied to every summed pixel */
} gparam_t;

/* Input parsing. */
pixel_t parse_img(unsigned char *file, image_t *img);

/* Convolution driver and its two pthread entry points. */
void do_conv(image_t *img, image_t *kern, image_t *res, int nthreads, int scale);
void *convolve_helper(void *vparam);
void *gather_helper(void *vparam);

int main(int argc, char** argv)
{
	int nthreads = 1;
	int fd1, fd2, size1, size2;
	int finalsize, i, scale;
	void *file1, *file2;
	image_t image, kernel, res;
	pixel_t kern_totals;

	if(argc < 4) {
		printf("Usage: %s <image> <kernel> <scale> [nthreads]\n", argv[0]);
		return 1;
	}

	if(argc >= 5) {
		if(sscanf(argv[4], "%d", &nthreads) != 1) {
			printf("number of threads is invalid\n");
			return 1;
		}
	}

	scale = atoi(argv[3]);

	if((fd1 = open(argv[1], O_RDONLY)) < 0) {
		printf("Image could not be opened\n");
		return 1;
	}
	if((fd2 = open(argv[2], O_RDONLY)) < 0) {
		printf("Kernel could not be opened\n");
		return 1;
	}
	size1 = lseek(fd1, 0, SEEK_END);
	size2 = lseek(fd2, 0, SEEK_END);
	if((file1 = mmap((void*)0x80000000, size1, PROT_READ, MAP_SHARED, fd1, 0)) <= 0) {
		perror("mmap sucks");
		return 1;
	}
	if((file2 = mmap((void*)0x80000000, size2, PROT_READ, MAP_SHARED, fd2, 0)) <= 0) {
		perror("mmap sucks");
		return 1;
	}

	parse_img((unsigned char*)file1, &image);
	kern_totals = parse_img((unsigned char*)file2, &kernel);
	fprintf(stderr, "Normalized scale = %.1f\nred no-saturate = %.1f\ngreen no-saturate = %.1f\nblue no-saturate = %.1f\n",
	        (kern_totals.c[0] + kern_totals.c[1] + kern_totals.c[2])/3,
	        kern_totals.c[0],
	        kern_totals.c[1],
	        kern_totals.c[2]);

	do_conv(&image, &kernel, &res, nthreads, scale);
	finalsize = res.w * res.h;
	printf("P6\n%d %d\n255\n", res.w, res.h);
	for(i = 0; i < finalsize; i++) {
		printf("%c%c%c", (char)res.data[i].c[0], (char)res.data[i].c[1], (char)res.data[i].c[2]);
	}

	return 0;
}

void do_conv(image_t *img, image_t *kern, image_t *res, int nthreads, int scale)
{
	int i;
	pthread_t *threads = malloc(nthreads * sizeof(pthread_t));
	cparam_t *params = malloc(nthreads * sizeof(cparam_t));
	gparam_t *gparams;
	pixel_t **data = malloc(nthreads * sizeof(pixel_t*));

	res->h = img->h + kern->h - 1;
	res->w = img->w + kern->w - 1;
	posix_memalign(&(res->data), sizeof(pixel_t), res->h * res->w * sizeof(pixel_t));

	for(i = 0; i < nthreads; i++) {
		params[i].start = img->h/nthreads * i;
		params[i].end = img->h/nthreads * (i + 1);
		if(i == (nthreads - 1))
			params[i].end = img->h;
		params[i].res = res;
		params[i].image = img;
		params[i].kernel = kern;
		posix_memalign(&(params[i].mybuf), sizeof(pixel_t), res->h * res->w * sizeof(pixel_t));
		data[i] = params[i].mybuf;
		pthread_create(&threads[i], NULL, &convolve_helper, &params[i]);
		fprintf(stderr, "convolve thread %d created...\n", i);
	}

	for(i = 0; i < nthreads; i++) {
		fprintf(stderr, "waiting for thread %d... ", i);
		pthread_join(threads[i], NULL);
		fprintf(stderr, "done\n");
	}

	gparams = malloc(nthreads * sizeof(gparam_t));
	for(i = 0; i < nthreads; i++) {
		gparams[i].data = data;
		gparams[i].ndata = nthreads;
		gparams[i].resdata = res->data;
		gparams[i].scale = scale;
		gparams[i].start = (res->h / nthreads * i) * res->w;
		gparams[i].end = (res->h / nthreads * (i + 1)) * res->w;
		if(i == (nthreads - 1))
			gparams[i].end = res->h * res->w;
		pthread_create(&threads[i], NULL, &gather_helper, &gparams[i]);
		fprintf(stderr, "gather thread %d created...\n", i);
	}

	for(i = 0; i < nthreads; i++) {
		fprintf(stderr, "waiting for thread %d... ", i);
		pthread_join(threads[i], NULL);
		fprintf(stderr, "done\n");
	}

	free(threads);
	free(gparams);
	free(params);
	for(i = 0; i < nthreads; i++) {
		free(data[i]);
	}
	free(data);

	return;
}

/* Phase-1 worker for do_conv (vparam is a cparam_t *).
 *
 * For every source pixel in rows [param->start, param->end) of param->image,
 * adds kernel[i][j] * image[row][col] into mybuf[row + i][col + j]. mybuf
 * has res->w pixels per row and must start zeroed, because this function
 * only accumulates into it. Only mybuf is written. Always returns NULL. */
void *convolve_helper(void *vparam)
{
	const cparam_t *param = vparam;
	const int imgw = param->image->w;
	const int resw = param->res->w;
	const int kw = param->kernel->w, kh = param->kernel->h;
	const int start = param->start, end = param->end;
	pixel_t *buf = param->mybuf;
	const pixel_t *kern = param->kernel->data;
	const pixel_t *img = param->image->data;
	int row, col, i, j;

	for(row = start; row < end; row++) {
		for(col = 0; col < imgw; col++) {
			/* The source pixel is invariant across the whole kernel sweep. */
			const pixel_t src = img[row * imgw + col];

			for(i = 0; i < kh; i++) {
				pixel_t *dst = &buf[(row + i) * resw + col];
				const pixel_t *krow = &kern[i * kw];

				for(j = 0; j < kw; j++) {
#ifdef NO_SSE
					dst[j].c[0] += krow[j].c[0] * src.c[0];
					dst[j].c[1] += krow[j].c[1] * src.c[1];
					dst[j].c[2] += krow[j].c[2] * src.c[2];
#else
					dst[j].v += krow[j].v * src.v;
#endif /* NO_SSE */
				}
			}
		}
	}

	return NULL;
}

/* Phase-2 worker for do_conv (vparam is a gparam_t *).
 *
 * For each flat pixel index in [param->start, param->end), sums that pixel
 * across the param->ndata buffers in param->data and stores sum / scale in
 * param->resdata. Assumes ndata >= 1, since data[0] is read unconditionally.
 *
 * The scalar (NO_SSE) build saturates channels 0..2 to [0, 255] and
 * truncates them to whole values. The vector build stores the raw quotient
 * for all four lanes. Always returns NULL. */
void *gather_helper(void *vparam)
{
	const gparam_t *param = vparam;
	pixel_t *gbuf = param->resdata;
	pixel_t **data = param->data;
	const int ndata = param->ndata, end = param->end;
	float scale = (color_t)param->scale;
	int i, j;
#ifdef NO_SSE
	int k;
	pixel_t temp;
#else
	pixel_t scales;
	v4c tv, sv;

	scales.c[0] = scales.c[1] = scales.c[2] = scales.c[3] = scale;
	sv = scales.v;
#endif /* NO_SSE */

	for(i = param->start; i < end; i++) {
#ifdef NO_SSE
		temp = data[0][i];
		for(j = 1; j < ndata; j++) {
			temp.c[0] += data[j][i].c[0];
			temp.c[1] += data[j][i].c[1];
			temp.c[2] += data[j][i].c[2];
		}
		for(k = 0; k < 3; k++) {
			float q = temp.c[k] / scale;

			/* Clamp in floating point *before* converting to int. Converting a
			 * value outside int's range (e.g. inf when scale is 0) is undefined
			 * behavior. !(q > 0) also catches NaN. */
			if(!(q > 0))
				q = 0;
			else if(q > 255)
				q = 255;
			gbuf[i].c[k] = (int)q;
		}
#else
		tv = data[0][i].v;
		for(j = 1; j < ndata; j++)
			tv += data[j][i].v;
		gbuf[i].v = tv / sv;
#endif /* NO_SSE */
	}

	return NULL;
}

/* Parse a raw P6 PNM that is already in memory (e.g. mmap'ed) into img.
 *
 * Expected layout, with no comment lines:
 *   "P6\n" [blanks] <width> <height> ... "\n" <max-value line> "\n" <RGB bytes>
 * The max-value line is skipped, not validated; each byte becomes one sample.
 * img->data is allocated here, aligned to sizeof(pixel_t), and owned by the
 * caller. Lane 3 of every pixel is set to 0 so whole-pixel vector arithmetic
 * never reads uninitialized memory.
 *
 * Returns the per-channel sums of all samples in c[0..2] (c[3] = 0). Exits
 * the process on a malformed or oversized header, or on allocation failure.
 * NOTE(review): the buffer length is not passed in, so a truncated file is
 * read past its end. Callers must supply a complete image. */
pixel_t parse_img(unsigned char *file, image_t *img)
{
	int pos = 3;
	long width, height;
	size_t npix, k;
	size_t align = sizeof(pixel_t) < sizeof(void *) ? sizeof(void *) : sizeof(pixel_t);
	const char *numstart;
	char *numend;
	const unsigned char *px;
	void *mem = NULL;
	color_t rtotal = 0, gtotal = 0, btotal = 0;
	pixel_t ptotal;

	if(file[0] != 'P' || file[1] != '6' || file[2] != '\n') {
		fprintf(stderr, "Invalid PNM image\n");
		exit(1);
	}

	while(isblank(file[pos]))
		pos++;

	/* Parse "<width> <height>" with strtol rather than sscanf: sscanf may
	 * scan ahead for a terminating NUL, which a raw mapped file lacks.
	 * strtol returns 0 when it finds no digits, which the checks reject. */
	numstart = (const char *)&file[pos];
	width = strtol(numstart, &numend, 10);
	numstart = numend;
	height = strtol(numstart, &numend, 10);
	if(numend == numstart || width < 1 || height < 1 ||
	   width > INT_MAX || height > INT_MAX ||
	   (size_t)width > (size_t)-1 / sizeof(pixel_t) / (size_t)height) {
		fprintf(stderr, "Invalid PNM image\n");
		exit(1);
	}
	npix = (size_t)width * (size_t)height;

	/* pos still points at the width: skip the rest of the width/height line,
	 * then the whole max-value line. */
	while(file[pos++] != '\n')
		;
	while(file[pos++] != '\n')
		;

	/* posix_memalign takes a void **; &img->data would be an incompatible
	 * pointer type. */
	if(posix_memalign(&mem, align, npix * sizeof(pixel_t)) != 0) {
		fprintf(stderr, "out of memory\n");
		exit(1);
	}
	img->w = (int)width;
	img->h = (int)height;
	img->data = mem;

	/* Should be in the image data now: 3 bytes (R, G, B) per pixel. */
	px = file + pos;
	for(k = 0; k < npix; k++, px += 3) {
		pixel_t *d = &img->data[k];

		rtotal += (d->c[0] = (color_t)px[0]);
		gtotal += (d->c[1] = (color_t)px[1]);
		btotal += (d->c[2] = (color_t)px[2]);
		d->c[3] = 0;
	}

	ptotal.c[0] = rtotal;
	ptotal.c[1] = gtotal;
	ptotal.c[2] = btotal;
	ptotal.c[3] = 0;
	return ptotal;
}
