/* $NetBSD$ */

/*-
* Copyright (c) 2011 The NetBSD Foundation, Inc.
* All rights reserved.
*
* This code is derived from software contributed to The NetBSD Foundation
* by Cherry G. Mathew <[email protected]>
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions
* are met:
* 1. Redistributions of source code must retain the above copyright
*    notice, this list of conditions and the following disclaimer.
* 2. Redistributions in binary form must reproduce the above copyright
*    notice, this list of conditions and the following disclaimer in the
*    documentation and/or other materials provided with the distribution.
*
* THIS SOFTWARE IS PROVIDED BY THE NETBSD FOUNDATION, INC. AND CONTRIBUTORS
* ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED
* TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
* PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR CONTRIBUTORS
* BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
* POSSIBILITY OF SUCH DAMAGE.
*/

#include <sys/cdefs.h>

__RCSID("$NetBSD$");

#include <assert.h>
#include <errno.h>
#include <fcntl.h>
#include <pthread.h>
#include <sched.h>
#include <semaphore.h>
#include <setjmp.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>

#include <sys/param.h>
#include <sys/sysctl.h>

#include <atf-c.h>

/* Barrier stuff */
/* We use a semaphore for barriers because it works for fork(2) as
* well as pthread(3)
*/
/*
* Spin barrier for a fixed number of participants.  Set up by
* vm_barrier_init(), waited on with vm_barrier_hold() and torn down by
* vm_barrier_destroy().
*/
struct barrier {
       char *semfile; /* File name of the semaphore */
       sem_t *lock; /* operates as a mutex */
       volatile size_t stile; /* participants yet to arrive; waiters spin until it reads 0 */
};


/*
* vm_barrier_init: prepare `bt' for use by `ncpus' participants.
*
* Opens (creating it if necessary) the named semaphore that guards
* bt->stile, keeps a heap copy of the semaphore name in bt->semfile for
* the later sem_unlink(), and sets the turnstile count to `ncpus'.
* Every failure is reported on stderr and is fatal (abort()).
*/
static void
vm_barrier_init(struct barrier *bt, size_t ncpus)
{
       /*
        * XXX: the name is a fixed string, it is not made unique, so
        * every barrier of the process opens the same named semaphore.
        */
       static const char semname[] = "/semXXXX";
       int error;

       assert(bt != NULL);

       bt->lock = sem_open(semname, O_CREAT, 0600, 1);
       if (bt->lock == SEM_FAILED) {
               error = errno; /* fprintf() may clobber errno */
               fprintf(stderr, "Unable to open semaphore\n");
               errno = error;
               perror("sem_open():");
               abort();
       }

       /*
        * Size the copy from the array itself.  sizeof applied to a
        * char pointer gives the pointer size, which is smaller than
        * the name plus its terminating NUL.
        */
       bt->semfile = malloc(sizeof semname);
       if (bt->semfile == NULL) {
               fprintf(stderr, "Unable to malloc() for filename\n");
               sem_close(bt->lock);
               sem_unlink(semname);
               abort();
       }

       memcpy(bt->semfile, semname, sizeof semname);
       bt->stile = ncpus;
}

/*
* vm_barrier_destroy: undo vm_barrier_init().
* Zeroes the turnstile, closes and unlinks the semaphore and frees the
* stored semaphore name.
*/
static void
vm_barrier_destroy(struct barrier *bt)
{
       sem_t *sem;
       char *name;

       assert(bt != NULL);

       sem = bt->lock;
       name = bt->semfile;

       bt->stile = 0;
       sem_close(sem);
       sem_unlink(name);
       free(name);
}

/*
* vm_barrier_hold: keep the caller here until every participant of the
* barrier has arrived.
*
* The caller takes the semaphore, decrements bt->stile, and then keeps
* re-reading bt->stile under the semaphore until it sees zero.  The
* semaphore is always taken by spinning on sem_trywait().  Progress is
* traced on stdout at the points a), b) and c).
*/
static void
vm_barrier_hold(struct barrier *bt)
{
       size_t remaining;

       assert(bt != NULL);
       assert(bt->lock != NULL);
       assert(bt->stile != 0);

       /* stile is a size_t, hence %zu and not %zd */
       printf("a) stile == %zu\n", bt->stile);

       /* Arrive: one participant fewer is outstanding. */
       while (sem_trywait(bt->lock) != 0) { /* spinwait */
               /* EINTR is a transient failure too; simply retry */
               assert(errno == EAGAIN || errno == EINTR);
       }
       bt->stile--;
       sem_post(bt->lock);
       printf("b) stile == %zu\n", bt->stile);

       /* Wait: sample the count under the lock until it drops to 0. */
       do {
               while (sem_trywait(bt->lock) != 0) { /* spinwait */
                       assert(errno == EAGAIN || errno == EINTR);
               }
               remaining = bt->stile;
               sem_post(bt->lock);
       } while (remaining != 0);
       printf("c) stile == %zu\n", remaining);
}


/*
* The goal of these tests is to stress the kernel pmap
* implementation, from userspace.
*/

/*
* Thread thrash: This test fires off one thread per CPU.
* Each thread makes synchronised, interleaved data accesses to a
* shared, locked page of memory. Each thread has a unique mapping to
* this page, obtained via mmap(). The mappings to the shared page are
* torn down and created afresh every time, in order to exercise the
* pmap routines. (XXX: break down this test into smaller tests that
* exercise identifiable areas of pmap.)
*
* This test only operates on a single pmap.
*/


/*
* thrash: payload run by each spawned thread.
* `arg' must point to an int holding the thread number.  For now the
* body only announces itself on stdout.
*/
static void
thrash(void *arg)
{
       const int *idp;

       assert(arg != NULL);
       idp = arg;

       /* XXX: binding to a cpu and the actual thrashing are still to do */
       printf("I am thread #%d\n", *idp);
}

/*
* getcpus: query the { CTL_HW, HW_NCPU } sysctl node.
* Returns the cpu count it reports, or 0 when the query fails.
*/
static cpuid_t
getcpus(void)
{
       static int mib[2] = { CTL_HW, HW_NCPU };
       int count;
       size_t countlen = sizeof count;

       if (sysctl(mib, __arraycount(mib), &count, &countlen,
                  NULL, 0) != -1) {
               return count;
       }

       return 0;
}

/*
* Set the fault handler of the current process to fault_routine().
* prep_fault(NULL) is meant to restore the default handler, and the
* return value is meant to be the handler that has been set.
*
* XXX: not implemented yet -- the argument is ignored and NULL is
* always returned.
*/

static void *
prep_fault(void *fault_routine)
{
       (void)fault_routine; /* XXX: unused until this is implemented */

       return NULL;
}

/* Thread wrappers */
/* Quick note on thread "id". Since we assume one thread per cpu, the
* cpuid is used in place of the thread "id" for all practical
* purposes.
*/

/*
* One jump buffer per cpu: setjmp_tramp() arms sequel[cid] for its
* thread, thread_pagefault() longjmp()s through slot 0 (XXX).
*/
static jmp_buf sequel[MAXCPUS];

/*
* Fault handler, to test for legitimate page faults.
* Resumes at the setjmp() point saved in sequel[], which then sees a
* return value of 1.
*/
static void
thread_pagefault(void)
{
       const size_t slot = 0; /* XXX */

       longjmp(sequel[slot], 1);

       /* Only reached if longjmp() were ever to return. */
       fprintf(stderr, "pagefault handler did not longjmp() !");
       abort();
}

/* What a spawned thread runs; filled in by thread_spawn(). */
struct thread_arg {
       void (*func)(void *); /* payload, run as func(arg) by setjmp_tramp() */
       void (*abortf)(void *); /* run as abortf(arg) by thread_abort() when not NULL */
       void *arg; /* opaque argument handed to func and abortf */
};

/*
* Per-thread bookkeeping.  Allocated by thread_spawn(), released by the
* thread itself in thread_exit() or thread_abort().
*/
struct thread_ctx {
       cpuid_t cid; /* The cpu number we are running on */
       pthread_t pth; /* filled in by pthread_create() in thread_spawn() */
       cpuset_t *cset; /* cpu set with cid set in it, used for the affinity call */
       struct barrier init_bar; /* start-up rendezvous of thread_spawn() and setjmp_tramp() */
       struct thread_arg ctx; /* payload and its argument */
};

/* Can only be called from own thread */

static void
thread_exit(struct thread_ctx *t)
{
       assert(t != NULL);
       assert(pthread_equal(t->pth, pthread_self()));

       vm_barrier_destroy(&t->init_bar);
       cpuset_destroy(t->cset);
       free(t);
       pthread_exit(NULL);
}

/* thread_equal: two contexts are equal iff they are the same object. */
static inline bool
thread_equal(struct thread_ctx *t1, struct thread_ctx *t2)
{
       const bool same = (t1 == t2);

       return same;
}

/*
* thread_abort: same teardown as thread_exit(), except that the abort
* callback, if one is registered, is run first with the thread's
* argument.  Can only be called from own thread (asserted).
*/
static void
thread_abort(struct thread_ctx *t)
{
       void (*on_abort)(void *);

       assert(t != NULL);
       assert(pthread_equal(t->pth, pthread_self()));

       on_abort = t->ctx.abortf;
       if (on_abort != NULL) {
               on_abort(t->ctx.arg);
       }

       /* From here on identical to thread_exit(). */
       vm_barrier_destroy(&t->init_bar);
       cpuset_destroy(t->cset);
       free(t);

       pthread_exit(NULL);
}

static void *
setjmp_tramp(void *arg)
{
       assert(arg != NULL);

       pthread_t pth;
       struct thread_ctx *t = arg;

       pth = pthread_self();

       printf("child addr of t->pth == %p\n", &t->pth);
       printf("child cid == %zd\n", t->cid);
       printf("child pth == %p\n", t->pth);
       sleep(1);
       vm_barrier_hold(&t->init_bar); /* Sync with thread_spawn */
       printf("child pth after  == %p\n", t->pth);
       printf("child pth self after  == %p\n", pth);
       if (!pthread_equal(pth, t->pth)) {
               printf("not the right child\n");
               while(1);
       }

       if (setjmp(sequel[t->cid])) {
               /*
                * got here via longjmp() from fault
                * routine.
                */

               printf("caught exception\n");
               prep_fault(NULL); /* XXX: reset exception handler */
               thread_abort(t);
       }
       t->ctx.func(t->ctx.arg);
       thread_exit(t);
       return NULL;
}


/*
* thread_spawn: create a thread bound to cpu `cid' that runs func(arg).
*
* cid     cpu number; it also selects the thread's slot in sequel[],
*         so it has to be below MAXCPUS.
* func    payload, must not be NULL.
* arg     handed to func (and to abortf).
* abortf  optional, run by thread_abort() before the thread goes away.
*
* Returns the context of the new thread, or NULL on failure.  The
* context belongs to the thread, which frees it when it exits; the
* caller may only pass the pointer on to thread_wait().
*
* NOTE(review): since the thread frees its own context, the returned
* pointer may already be stale when the caller gets to use it.
*/
static struct thread_ctx *
thread_spawn(cpuid_t cid,  /* cpu number */
            void (*func)(void *),
            void *arg,
            void (*abortf)(void *))
{
       struct thread_ctx *t;
       cpuset_t *cpuset;
       pthread_t pth;
       int bind_error;

       assert(func != NULL);
       assert(cid < MAXCPUS); /* sequel[] has exactly MAXCPUS slots */

       t = malloc(sizeof *t);
       if (t == NULL) {
               return NULL;
       }

       cpuset = cpuset_create();
       if (cpuset == NULL) {
               printf("Could not create cpuset\n");
               free(t);
               return NULL;
       }

       if (cpuset_set(cid, cpuset) == -1) {
               printf("Could not set cpuset affinity to cpu%lu \n", cid);
               cpuset_destroy(cpuset);
               free(t);
               return NULL;
       }

       t->cset = cpuset;
       t->cid = cid;
       t->ctx.func = func;
       t->ctx.arg = arg;
       /* Store it even when NULL: thread_abort() tests this field. */
       t->ctx.abortf = abortf;

       vm_barrier_init(&t->init_bar, 2);

       printf("creating new thread for func: %p\n", t->ctx.func);

       printf("addr of t->pth == %p\n", (void *)&t->pth);
       if (pthread_create(&t->pth, NULL,
                          setjmp_tramp, t)) {
               printf("error creating thread \n");
               /* No thread exists, so all of this is still ours. */
               vm_barrier_destroy(&t->init_bar);
               cpuset_destroy(cpuset);
               free(t);
               return NULL;
       }

       /*
        * The child cannot get past init_bar before we arrive there as
        * well, so until our own vm_barrier_hold() the context and the
        * cpuset are guaranteed to exist.  Take a copy of the thread
        * id and bind the thread now; once released, the child may
        * tear the context down at any time.
        */
       pth = t->pth;
       printf("parent cid == %lu\n", cid);
       printf("parent pth == %p\n", pth);

       /* Set affinity */
       bind_error = pthread_setaffinity_np(pth, cpuset_size(cpuset),
                                           cpuset);

       vm_barrier_hold(&t->init_bar); /* Sync with setjmp_tramp() */

       if (bind_error) {
               printf("error binding thread to CPU %lu\n",
                      cid);
               /*
                * The thread owns the context and frees it on exit, so
                * it must not be freed here.  Let the thread run to
                * completion and reap it.
                */
               (void)pthread_join(pth, NULL);
               return NULL;
       }

       return t;
}

/*
* thread_wait: join the thread that belongs to `ctx' and return what
* pthread_join() returned.
*
* The context memory is not released here; the thread frees its own
* context in thread_exit().
* NOTE(review): ctx->pth is therefore read from memory the thread may
* already have freed -- confirm the intended ownership.
*
* This is slightly lame, but we're a testing framework, not a
* threading library.
*/

static int
thread_wait(struct thread_ctx *ctx)
{
       assert(ctx != NULL);

       /* ctx is free()d by the thread on thread_exit() */
       return pthread_join(ctx->pth, NULL);
}

/* Record for the plain pthread test; made by spawn(), read by thread(). */
struct tt {
       int tid; /* sequence number handed out by spawn() */
       pthread_t pth; /* filled in by pthread_create() in spawn() */
};

/*
* thread: start routine of the threads made by spawn().
* Reads `arg' as a struct tt, reports its tid, says so if the recorded
* pthread_t equals pthread_self(), sleeps for 3 seconds and exits.
*/
static void *
thread(void *arg)
{
       struct tt *rec = arg;
       int id = rec->tid;
       pthread_t recorded = rec->pth;

       printf("I am thread %d\n", id);

       if (pthread_equal(pthread_self(), recorded)) {
               printf("pthread_self() matches\n");
       }

       sleep(3);
       pthread_exit(NULL);
}

static struct tt *
spawn(void)
{
       static int tid = 0;
       struct tt *ttp;

       printf("spawn entered at tid == %d\n", tid);
       ttp = malloc(sizeof *ttp);
       if (ttp == NULL) {
               fprintf(stderr, "malloc() failed\n");
               abort();
       }

       ttp->tid = tid;
       pthread_create(&ttp->pth, NULL, thread, &ttp->tid);
       printf("spawn finished at tid == %d\n", tid);

       tid++;
       return ttp;
}


ATF_TC(test_thread);
ATF_TC_HEAD(test_thread, tc)
{
       atf_tc_set_md_var(tc, "descr",
                         "test pthreads");
}
ATF_TC_BODY(test_thread, tc)
{

       pthread_join(spawn()->pth, NULL);
       pthread_join(spawn()->pth, NULL);
}

ATF_TC(thread_thrash);
ATF_TC_HEAD(thread_thrash, tc)
{
       atf_tc_set_md_var(tc, "descr",
                         "Thrash the TLB from within a single Address Space");
}

ATF_TC_BODY(thread_thrash, tc)
{
       cpuid_t cpuno, i;
       struct thread_ctx *t[MAXCPUS];

       /* 1) Detect no. of cpus via cpuset(3) */
       cpuno = getcpus();

       printf("Detected %lu cpus\n", cpuno);
       ATF_REQUIRE(cpuno > 0);

       /* 2) Fire off threads */
       printf("new threads\n");
       (void) thread_pagefault;
       for (i = 0; i < cpuno; i++) {
               t[i] = thread_spawn(i, thrash, NULL, NULL);
               if (t[i] == NULL) {
                       printf("thread spawn failed for cpu%lu\n", i);
               }

               /* XXX: destroy the other threads ? */
               ATF_REQUIRE(t[i] != NULL);
       }

       /* Wait for threads to join */
       for (i = 0; i < cpuno; i++) {
               thread_wait(t[i]);
               printf("joined\n");
       }

}

/* Add the test cases defined above to the test program `tp'. */
ATF_TP_ADD_TCS(tp)
{
       ATF_TP_ADD_TC(tp, test_thread);
       ATF_TP_ADD_TC(tp, thread_thrash);
       return atf_no_error();
}