#include "utils/network.h"
#include "utils/hostname.h"
#ifdef __linux__
#include <time.h>
#else
#include <sys/timers.h>
#endif
#include <errno.h>
#include <limits.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>

/*
 * Wall-clock source used by start_clock()/stop_clock().
 *
 * OSF/1 provides getclock(TIMEOFDAY, ...) but its headers do not declare it
 * (or TIMEOFDAY) by default, so both are supplied here.  glibc has no
 * getclock() at all, so on Linux it is mapped onto POSIX clock_gettime()
 * with CLOCK_REALTIME.  Either way it returns 0 on success and non-zero
 * (with errno set) on failure, which is what the perror() callers expect.
 */
#ifdef __linux__
#define TIMEOFDAY CLOCK_REALTIME
static int getclock(int clock_type, struct timespec *tp) {
  return clock_gettime((clockid_t) clock_type, tp);
}
#else
#define TIMEOFDAY 1
extern "C" {
extern int getclock(int clock_type, struct timespec *tp);
}
#endif

/* Timestamps written by start_clock()/stop_clock() and read by print_elapsed(). */
static struct timespec start_time;
static struct timespec end_time;

/*
 * Opens a TCP connection to port 4747 on host "note" and wraps the socket in
 * a heap-allocated Network, which the caller owns.
 *
 * Returns 0 (after printing the reason via perror) if the socket cannot be
 * created or the connect fails; in that case the socket is closed rather
 * than leaked.  If the host name cannot be resolved, the program exits with
 * EXIT_FAILURE.
 */
static Network* set_up_connection () {
  int sock = socket(AF_INET, SOCK_STREAM, 0);
  if (sock < 0) {
    perror("opening stream socket");
    return 0;
  }

  struct sockaddr_in sa;
  memset(&sa, 0, sizeof(sa));  /* clear sin_zero and padding before connect() */
  sa.sin_family = AF_INET;
  sa.sin_addr.s_addr = INADDR_ANY;
  sa.sin_port = htons(4747);
  if (!findhost("note", &(sa.sin_addr))) {
    fprintf(stderr, "Couldn't find host\n");
    close(sock);
    exit(EXIT_FAILURE);
  }
  if (connect(sock, (struct sockaddr*) &sa, sizeof(sa)) < 0) {
    perror("connecting to OR");  /* report before close() can clobber errno */
    close(sock);
    return 0;
  }

  return new Network(sock);
}

/* Samples the TIMEOFDAY clock into start_time; terminates the program with
   EXIT_FAILURE if the clock cannot be read. */
void start_clock () {
  int status = getclock(TIMEOFDAY, &start_time);
  if (status == 0) {
    return;
  }
  perror("setting start time");
  exit(EXIT_FAILURE);
}

/* Samples the TIMEOFDAY clock into end_time; terminates the program with
   EXIT_FAILURE if the clock cannot be read. */
void stop_clock () {
  int status = getclock(TIMEOFDAY, &end_time);
  if (status == 0) {
    return;
  }
  perror("setting end time");
  exit(EXIT_FAILURE);
}

/*
 * Prints the time between start_time and end_time, then the average time
 * per roundtrip over `trials` roundtrips.
 *
 * All arithmetic is done in long long, so 32-bit time_t/long cannot overflow
 * and the printf specifiers match the argument types.  When trials <= 0 there
 * is no meaningful average (and dividing would be UB), so only the elapsed
 * line is printed.
 */
void print_elapsed (int trials) {
  const long long NSEC_PER_SEC = 1000000000LL;
  long long sec = (long long) end_time.tv_sec - (long long) start_time.tv_sec;
  long long nsec = (long long) end_time.tv_nsec - (long long) start_time.tv_nsec;
  if (nsec < 0) {  /* borrow one second for the nanosecond field */
    sec -= 1;
    nsec += NSEC_PER_SEC;
  }
  printf("Elapsed: %lld.%.3lld sec\n", sec, nsec / 1000000);
  if (trials <= 0) {
    return;
  }
  /* The fractional part folds the leftover seconds into nanoseconds first, so
     the quotient stays below NSEC_PER_SEC and fits the 9-digit field. */
  printf("Average: %lld.%.9lld sec per roundtrip\n",
         sec / trials,
         ((sec % trials) * NSEC_PER_SEC + nsec) / trials);
}

/* Runs a single roundtrip: sends the requested byte count (a raw int) to the
   server, flushes, then reads `size` bytes of reply into buf.  Any send or
   receive failure is reported on stderr and ends the program. */
void trial (Network* net, int size, char *buf) {
  if (!net->send_buffer(&size, sizeof size)) {
    fprintf(stderr, "Couldn't send message to server\n");
    exit(EXIT_FAILURE);
  }
  net->flush();
  if (!net->recv_buffer(buf, size)) {
    fprintf(stderr, "Couldn't get message from server\n");
    exit(EXIT_FAILURE);
  }
}

void main (int argc, char* argv[]) {
  if (argc == 3) {
    int trials = atoi(argv[1]);
    int size = atoi(argv[2]);
    Network* net = set_up_connection();
    char* buf = new char[size];
    start_clock();
    for (int i = 0; i < trials; i++) {
      trial(net, size, buf);
    }
    stop_clock();
    print_elapsed(trials);
    delete buf;
    exit(EXIT_SUCCESS);
  }
  fprintf(stderr, "Usage: %s trials bytes\n"
	          "trials is the number of roundtrips;\n"
	          "bytes is the number of bytes sent "
                  "from the server each time.\n", 
	  argv[0]);
}

