#include <stdlib.h>
#include <stdio.h>
#include <string.h>
#include <limits.h>
#include <math.h>
#include <time.h>

/*
 * Scene geometry: a flat table of spheres, four floats per sphere.
 * Each row is { center x, center y, center z, radius }; the intersection
 * code reads elements 0..2 as the center and element 3 as the radius.
 * Row k of this table is paired with row k of the `col` color table.
 */
const float spheres[] = {
    /* center x    center y     center z     radius */
     0.0f,        -6.95149f,   -0.17298f,    6.932f,
    -1.64317f,    -0.408679f,  -3.07338f,    1.0f,
     1.66359f,    -0.256469f,  -5.76149f,    1.0f,
     1.62878f,     0.840152f,  -5.60949f,    1.0f,
     1.59027f,     1.74296f,   -5.4756f,     1.0f,
    -3.78282f,     1.27461f,   -11.7239f,    1.0f
};
/*
 * Per-sphere base colors, three floats per sphere: { red, green, blue }.
 * Row k colors the sphere stored in row k of `spheres`; the shading code
 * multiplies each component by a light factor before accumulating it.
 */
const float col[] = {
    /* red   green  blue */
    0.2f,  0.4f,  0.1f,
    0.5f,  0.2f,  0.6f,
    0.7f,  0.4f,  0.4f,
    0.2f,  0.6f,  0.7f,
    0.1f,  0.8f,  0.5f,
    0.5f,  0.3f,  0.7f
};

/* Output image size in pixels. */
#define WIDTH 640
#define HEIGHT 480

/* Half of the image width/height as floats (used to map pixels to rays). */
#define HWIDTHF ((float)((WIDTH)/2.0f))
#define HHEIGHTF ((float)((HEIGHT)/2.0f))

/*
 * Square root with a float result: evaluates sqrt() in double and narrows.
 * Function-like and fully parenthesized so it behaves as a single
 * expression wherever it is used (e.g. as a sizeof operand).
 */
#define SQRTF(x) ((float)sqrt(x))

/* Supersampling: each pixel is sampled on a SUBGRID x SUBGRID grid of rays. */
#define SUBGRID 16

/* One output pixel: red, green and blue channels, one byte each. */
typedef struct rgb {
    unsigned char r;
    unsigned char g;
    unsigned char b;
} rgb;

/*
 * Frame buffer holding the rendered image, WIDTH*HEIGHT pixels stored row
 * by row (pixel (x, y) lives at index y*WIDTH + x). Filled by calc_image()
 * and printed by dump_pixels().
 */
rgb pixels[WIDTH*HEIGHT];

/*
 * Euclidean length of a 3-component vector.
 *
 * v: points to at least three floats; only v[0], v[1] and v[2] are read.
 * Returns sqrt(x*x + y*y + z*z) as a float.
 */
float length(const float* v)
{
    const float x = v[0];
    const float y = v[1];
    const float z = v[2];
    return SQRTF(x*x + y*y + z*z);
}

/*
 * Scales a 3-component vector in place so that its length becomes 1.
 *
 * v: points to three floats that are overwritten with the scaled values.
 * There is no guard for a zero-length input: each component is divided by
 * the vector's length as returned by length().
 */
void normalize(float* v)
{
    const float norm = length(v);
    v[0] /= norm;
    v[1] /= norm;
    v[2] /= norm;
}

/*
 * Intersects a ray with a single sphere.
 *
 * o:  ray origin (3 floats).
 * d:  ray direction (3 floats). The quadratic solved here omits the d.d
 *     term, so the result is only exact for a unit-length direction.
 * s:  sphere as 4 floats: center x, y, z and radius.
 * ip: receives the intersection point (3 floats) on a hit; left untouched
 *     on a miss.
 *
 * Returns 1 on a hit, 0 otherwise. A ray that merely touches the sphere
 * (zero discriminant) counts as a miss, and so does a ray whose smaller
 * root lies closer than 0.0001 along the ray (this includes roots behind
 * the origin).
 */
int ray_hit_sphere(const float* o, const float* d, const float* s, float* ip)
{
    /* Vector from the sphere center to the ray origin. */
    const float ocx = o[0] - s[0];
    const float ocy = o[1] - s[1];
    const float ocz = o[2] - s[2];
    const float half_b = ocx*d[0] + ocy*d[1] + ocz*d[2];
    const float c = (ocx*ocx + ocy*ocy + ocz*ocz) - s[3]*s[3];
    const float disc = half_b*half_b - c;
    float root, t_near, t_far, t;
    int k;

    if (!(disc > 0.0f)) return 0;

    root = SQRTF(disc);
    t_near = -half_b - root;
    t_far = -half_b + root;

    /* Use the smaller of the two ray parameters. */
    t = t_near;
    if (t_far < t_near) t = t_far;

    if (t < 0.0001) return 0;

    for (k = 0; k < 3; k++) ip[k] = o[k] + d[k]*t;
    return 1;
}

/*
 * Finds the sphere in the global `spheres` table that a ray hits first.
 *
 * o:  ray origin (3 floats).
 * d:  ray direction (3 floats), passed through to ray_hit_sphere().
 * ip: receives the intersection point with the chosen sphere (3 floats);
 *     left untouched when nothing is hit.
 *
 * Returns the index of the closest sphere (0-based, in units of spheres,
 * not floats), or -1 when the ray hits none of them. On equal distances
 * the sphere that appears first in the table wins.
 *
 * "Closest" is measured from the ray origin `o` to the intersection point,
 * so the choice is also correct for rays that do not start at (0, 0, 0).
 */
int ray_hit_spheres(const float* o, const float* d, float* ip)
{
    const int nfloats = (int)(sizeof(spheres)/sizeof(spheres[0]));
    float cdist = 0.0f;   /* only meaningful once closest != -1 */
    float rip[3], delta[3];
    int i, k, closest = -1;

    for (i = 0; i < nfloats; i += 4) {
        float dist;

        if (!ray_hit_sphere(o, d, spheres + i, rip)) continue;

        /* Distance travelled along the ray: hit point relative to origin. */
        for (k = 0; k < 3; k++) delta[k] = rip[k] - o[k];
        dist = length(delta);

        if (closest == -1 || cdist > dist) {
            cdist = dist;
            closest = i/4;
            memcpy(ip, rip, sizeof(float)*3);
        }
    }
    return closest;
}

void calc_subpixel(float* rgb, float x, float y, float sx, float sy)
{
    float dl, o[3], d[3], n[3], ip[3];
    int sidx, i;
    o[0] = o[1] = o[2] = 0.0f;
    d[0] = x/HWIDTHF - 1 + sx;
    d[1] = 1 - y/HHEIGHTF/(HWIDTHF/HHEIGHTF) + sy;
    d[2] = -1.0f;
    normalize(d);
    sidx = ray_hit_spheres(o, d, ip);
    if (sidx == -1) {
        rgb[0] += 0.062f;
        rgb[1] += 0.125f;
        rgb[2] += 0.392f;
    } else {
        for (i=0; i < 3; i++) n[i] = ip[i] - spheres[sidx*4 + i];
        normalize(n);
        dl = 0.2f + n[2];
        if (dl < 0.2f) dl = 0.2f;
        if (dl > 1) dl = 1;
        rgb[0] += col[sidx*3]*dl;
        rgb[1] += col[sidx*3 + 1]*dl;
        rgb[2] += col[sidx*3 + 2]*dl;
    }
}

/*
 * Computes the final color of one pixel by supersampling.
 *
 * rgb:  destination pixel; its r, g and b bytes are overwritten.
 * x, y: pixel coordinates.
 *
 * Traces SUBGRID*SUBGRID sample rays through calc_subpixel(), averages the
 * accumulated color and scales it by 255 before truncating to a byte.
 * Note that both the horizontal and the vertical sample offsets are
 * divided by WIDTH.
 */
void calc_pixel(rgb* rgb, int x, int y)
{
    const float samples = (float)(SUBGRID*SUBGRID);
    const float px = (float)x;
    const float py = (float)y;
    float sum[3] = { 0, 0, 0 };
    int row, column;

    for (row = 0; row < SUBGRID; row++) {
        const float off_y = (float)row/(float)WIDTH/(float)SUBGRID;
        for (column = 0; column < SUBGRID; column++) {
            const float off_x = (float)column/(float)WIDTH/(float)SUBGRID;
            calc_subpixel(sum, px, py, off_x, off_y);
        }
    }

    rgb->r = (unsigned char)(sum[0]/samples*255.0f);
    rgb->g = (unsigned char)(sum[1]/samples*255.0f);
    rgb->b = (unsigned char)(sum[2]/samples*255.0f);
}

void calc_image(void)
{
    int x, y;
    for (y=0; y < HEIGHT; y++)
        for (x=0; x < WIDTH; x++)
            calc_pixel(pixels + y*WIDTH + x, x, y);
}

void dump_pixels(void)
{
    int x, y;
    printf("P3\n%i %i\n255\n", WIDTH, HEIGHT);
    for (y=0; y < HEIGHT; y++) {
        for (x=0; x < WIDTH; x++) {
            if (x) printf(" ");
            printf("%i %i %i", pixels[y*WIDTH + x].r, pixels[y*WIDTH + x].g, pixels[y*WIDTH + x].b);
        }
        printf("\n");
    }
}

/*
 * Program entry point.
 *
 * Renders the image twice: the first pass is an untimed warm-up, the
 * second pass is timed with clock(). Progress and the elapsed processor
 * time go to stderr; the rendered image is then written to stdout by
 * dump_pixels(). Always returns 0.
 *
 * Declared as standard `int main(void)`: the compiler-specific `__cdecl`
 * keyword is not needed here and prevents the file from compiling on
 * toolchains that do not provide it.
 */
int main(void)
{
    clock_t start;

    fprintf(stderr, "Warmup pass...\n");
    calc_image();

    /* Only the second pass is measured. */
    start = clock();
    calc_image();
    fprintf(stderr, "Elapsed time = %0.3f seconds\n",
        (float)(clock() - start)/(float)CLOCKS_PER_SEC);

    dump_pixels();
    return 0;
}

