#include <u.h>
#include <libc.h>
#include <mp.h>
#include <sat.h>
#include "dat.h"
#include "fns.h"
extern int satvar;
/*
 * Return a literal equivalent to the conjunction of the n literals in a.
 * If n < 0, a is scanned up to (not including) its terminating 0.
 *
 * Literals 1 and -2 are treated as constant false, 2 and -1 as constant
 * true.  Any false input yields 1; true inputs are dropped.  If at most
 * one non-constant literal remains it is returned directly (2 if none).
 * Otherwise a fresh variable r is taken from satvar and the clauses
 *	(-r | a[i]) for every remaining a[i], and
 *	(-a[0] | -a[1] | ... | r)
 * are added to sat.
 */
int
satand1(SATSolve *sat, int *a, int n)
{
	int i, j, r;
	int *b;

	if(n < 0)
		for(n = 0; a[n] != 0; n++)
			;
	r = 2;
	for(i = j = 0; i < n; i++){
		if(a[i] == 1 || a[i] == -2)
			return 1;
		if(a[i] == 2 || a[i] == -1)
			j++;
		else
			r = a[i];
	}
	/* zero or one non-constant input: no new variable needed */
	if(j >= n - 1)
		return r;
	r = satvar++;
	/* NOTE(review): malloc result is unchecked here; pimpsat uses emalloc */
	b = malloc(sizeof(int) * (n+1));
	for(i = j = 0; i < n; i++){
		if(a[i] == 2 || a[i] == -1)
			continue;
		b[j++] = -a[i];
		sataddv(sat, -r, a[i], 0);
	}
	b[j++] = r;
	satadd1(sat, b, j);
	/* scratch clause buffer is no longer needed (pimpsat frees its buffer the same way) */
	free(b);
	return r;
}
/*
 * Variadic form of satand1: the literals follow sat in the argument list.
 * The list is handed to satand1 with n = -1, so satand1 scans it up to a
 * terminating 0.
 */
int
satandv(SATSolve *sat, ...)
{
	va_list args;
	int lit;

	va_start(args, sat);
	/* presumably prepares args to be read as an int array; confirm against satvafix */
	satvafix(args);
	lit = satand1(sat, (int*)args, -1);
	va_end(args);
	return lit;
}
/*
 * Return a literal equivalent to the disjunction of the n literals in a.
 * If n < 0, a is scanned up to (not including) its terminating 0.
 *
 * Literals 2 and -1 are treated as constant true, 1 and -2 as constant
 * false.  Any true input yields 2; false inputs are dropped.  If at most
 * one non-constant literal remains it is returned directly (1 if none).
 * Otherwise a fresh variable r is taken from satvar and the clauses
 *	(r | -a[i]) for every remaining a[i], and
 *	(a[0] | a[1] | ... | -r)
 * are added to sat.
 */
int
sator1(SATSolve *sat, int *a, int n)
{
	int i, j, r;
	int *b;

	if(n < 0)
		for(n = 0; a[n] != 0; n++)
			;
	r = 1;
	for(i = j = 0; i < n; i++){
		if(a[i] == 2 || a[i] == -1)
			return 2;
		if(a[i] == 1 || a[i] == -2)
			j++;
		else
			r = a[i];
	}
	/* zero or one non-constant input: no new variable needed */
	if(j >= n-1)
		return r;
	r = satvar++;
	/* NOTE(review): malloc result is unchecked here; pimpsat uses emalloc */
	b = malloc(sizeof(int) * (n+1));
	for(i = j = 0; i < n; i++){
		if(a[i] == 1 || a[i] == -2)
			continue;
		b[j++] = a[i];
		sataddv(sat, r, -a[i], 0);
	}
	b[j++] = -r;
	satadd1(sat, b, j);
	/* scratch clause buffer is no longer needed (pimpsat frees its buffer the same way) */
	free(b);
	return r;
}
/*
 * Variadic form of sator1: the literals follow sat in the argument list.
 * The list is handed to sator1 with n = -1, so sator1 scans it up to a
 * terminating 0.
 */
int
satorv(SATSolve *sat, ...)
{
	va_list args;
	int lit;

	va_start(args, sat);
	/* presumably prepares args to be read as an int array; confirm against satvafix */
	satvafix(args);
	lit = sator1(sat, (int*)args, -1);
	va_end(args);
	return lit;
}
/*
 * Implicant table shared by pimp, pimpmask, pimpcover and pimpsat.
 *
 * A Pi is one implicant over the input bits: m marks the bit positions
 * that have been merged away (pimpsat emits no literal for them), x holds
 * the values of the remaining positions.
 */
typedef struct { u8int x, m; } Pi;
/* implicant array, grown with realloc in pimp and reused across calls */
static Pi *π;
/* number of valid entries in π (and, after pimpmask, in πm) */
static int nπ;
/* πm[i] has bit k set for every truth-table index k covered by π[i]; filled by pimpmask */
static u64int *πm;
/*
 * Rebuild the implicant table π/nπ for the n-input truth table op.
 *
 * Phase 1 seeds the table with one entry per set bit of op.
 * Phase 2 walks the (growing) table and, for every earlier entry that has
 * the same mask and differs in exactly one value bit, appends the merged
 * entry with that bit added to the mask.  Only a repeat of the most
 * recently appended entry is suppressed at this point.
 * Phase 3 compacts the table, dropping every entry that is contained in
 * (or equal to) some later entry.
 */
static void
pimp(u64int op, int n)
{
	int a, b, keep;
	u8int diff;
	Pi cand;

	nπ = 0;
	for(a = 0; a < 1<<n; a++){
		if((op >> a & 1) == 0)
			continue;
		π = realloc(π, sizeof(Pi) * (nπ + 1));
		π[nπ].x = a;
		π[nπ].m = 0;
		nπ++;
	}

	/* nπ grows inside this loop, so merged entries are themselves revisited */
	for(a = 0; a < nπ; a++)
		for(b = 0; b < a; b++){
			diff = π[a].x ^ π[b].x;
			/* need exactly one differing value bit */
			if(diff == 0 || (diff & diff - 1) != 0)
				continue;
			if(π[a].m != π[b].m)
				continue;
			if(((π[a].m | π[b].m) & diff) != 0)
				continue;
			cand.x = π[a].x & π[b].x;
			cand.m = π[a].m | diff;
			if(π[nπ-1].x == cand.x && π[nπ-1].m == cand.m)
				continue;
			π = realloc(π, sizeof(Pi) * (nπ + 1));
			π[nπ++] = cand;
		}

	keep = 0;
	for(a = 0; a < nπ; a++){
		/* look for a later entry that subsumes π[a] */
		b = a + 1;
		while(b < nπ){
			if((π[a].m & ~π[b].m) == 0 && (π[a].x & ~π[b].m) == π[b].x)
				break;
			b++;
		}
		if(b == nπ)
			π[keep++] = π[a];
	}
	nπ = keep;
	assert(nπ <= 1<<n);
}
static void
pimpmask(void)
{
int i, j;
u64int m;
πm = realloc(πm, sizeof(u64int) * nπ);
for(i = 0; i < nπ; i++){
m = 0;
for(j = π[i].m; ; j = j - 1 & π[i].m){
m |= 1ULL<<(π[i].x | j);
if(j == 0) break;
}
πm[i] = m;
}
}
/*
 * Number of set bits in m, computed by summing adjacent bit fields of
 * width 1, 2, 4, 8 and 16 in parallel and finally adding the two
 * 32-bit halves.
 */
static int
popcnt(u64int m)
{
	static u64int lanes[] = {
		0x5555555555555555ULL,
		0x3333333333333333ULL,
		0x0F0F0F0F0F0F0F0FULL,
		0x00FF00FF00FF00FFULL,
		0x0000FFFF0000FFFFULL,
	};
	int k;

	/* step k adds neighbouring fields that are 1<<k bits wide */
	for(k = 0; k < 5; k++)
		m = (m & lanes[k]) + (m >> (1<<k) & lanes[k]);
	return (u32int)m + (u32int)(m >> 32);
}
/*
 * Choose a subset of the implicants in π (with coverage masks πm) that
 * together cover every set bit of op.  The result is a bitmask over
 * implicant indices: bit i set means π[i] is selected.
 *
 * First every implicant that covers some index no other implicant covers
 * is selected.  Then, while indices remain uncovered, the lowest such
 * index is taken and, among the implicants covering it, the one covering
 * the most still-uncovered indices is selected.
 *
 * All single-bit masks are built with 1ULL: both implicant indices and
 * truth-table indices can reach 63, which an int shift cannot express.
 */
static u64int
pimpcover(u64int op, int)
{
	int i, j, maxi, p, maxp;
	u64int cov, yes, m;

	yes = 0;
	cov = op;
	for(i = 0; i < nπ; i++){
		if((yes & 1ULL<<i) != 0) continue;
		/* m = indices covered by π[i] and by no other implicant */
		m = πm[i];
		for(j = 0; j < nπ; j++){
			if(j == i) continue;
			m &= ~πm[j];
			if(m == 0) break;
		}
		if(j == nπ){
			yes |= 1ULL<<i;
			cov &= ~πm[i];
		}
	}
	while(cov != 0){
		/* index of the lowest set bit of cov */
		j = popcnt(~cov & cov - 1);
		maxi = -1;
		maxp = 0;
		for(i = 0; i < nπ; i++){
			if((πm[i] & 1ULL<<j) == 0) continue;
			if((p = popcnt(πm[i] & cov)) > maxp)
				maxi = i, maxp = p;
		}
		assert(maxi >= 0);
		yes |= 1ULL<<maxi;
		cov &= ~πm[maxi];
	}
	return yes;
}
/*
 * Add one clause to sat for every implicant selected in yes (bit i of yes
 * selects π[i]).  Each clause starts with r, followed by one literal per
 * input position not in the implicant's mask: -a[j] where the implicant
 * has bit j of x set, a[j] otherwise.
 */
static void
pimpsat(SATSolve *sat, u64int yes, int *a, int n, int r)
{
	int idx, bit, len;
	int *clause;

	clause = emalloc(sizeof(int) * (n + 1));
	/* each pass handles the lowest set bit of yes, then clears it */
	for(; yes != 0; yes &= yes - 1){
		idx = popcnt(~yes & yes - 1);
		clause[0] = r;
		len = 1;
		for(bit = 0; bit < n; bit++){
			if((π[idx].m & 1<<bit) != 0)
				continue;
			if((π[idx].x >> bit & 1) != 0)
				clause[len++] = -a[bit];
			else
				clause[len++] = a[bit];
		}
		satadd1(sat, clause, len);
	}
	free(clause);
}
/*
 * Return a literal equivalent to the boolean function op applied to the
 * n literals in a.  If n < 0, a is scanned up to its terminating 0.
 *
 * op is a truth table: bit k of op is the result when, for every i,
 * input a[i] has the value of bit i of k.  At most 6 inputs fit in the
 * 64-bit table, and bits of op at or above 1<<n must be zero.
 *
 * Inputs with values in -2..2 are treated as constants (2 and -1 true,
 * the rest false).  If all inputs are constant the result is 1 (false)
 * or 2 (true); if exactly one input a[j] is non-constant the result is
 * 1, 2, a[j] or -a[j].  Otherwise a fresh variable r is taken from
 * satvar, clauses derived from the implicants of op (with r) and of its
 * complement (with -r) are added to sat, and r is returned.
 */
int
satlogic1(SATSolve *sat, u64int op, int *a, int n)
{
	int i, j, o, r;
	int s;

	if(n < 0)
		for(n = 0; a[n] != 0; n++)
			;
	assert(n <= 6);
	/* for n == 6 the table fills all 64 bits; shifting by 64 would be undefined */
	assert(n == 6 || op >> (1<<n) == 0);
	s = 0;
	j = -1;
	for(i = n; --i >= 0; ){
		/* a[i] outside -2..2: a real variable; stop at the second one */
		if((uint)(a[i] + 2) > 4){
			if(j >= 0) break;
			j = i;
		}
		/* bit i of s ends up set iff a[i] is constant true */
		s = s << 1 | a[i] == 2 | a[i] == -1;
	}
	if(i < 0){
		if(j < 0) return 1 + (op >> s & 1);
		/* bit 0: result with a[j] false; bit 1: result with a[j] true */
		o = op >> s & 1 | op >> s + (1<<j) - 1 & 2;
		switch(o){
		case 0: return 1;
		case 1: return -a[j];
		case 2: return a[j];
		case 3: return 2;
		}
	}
	r = satvar++;
	pimp(op, n);
	pimpmask();
	pimpsat(sat, pimpcover(op, n), a, n, r);
	/* complement the table within its 1<<n valid bits */
	op ^= (u64int)-1 >> 64-(1<<n);
	pimp(op, n);
	pimpmask();
	pimpsat(sat, pimpcover(op, n), a, n, -r);
	return r;
}
/*
 * Variadic form of satlogic1: the input literals follow op in the
 * argument list.  The list is handed to satlogic1 with n = -1, so
 * satlogic1 scans it up to a terminating 0.
 */
int
satlogicv(SATSolve *sat, u64int op, ...)
{
	va_list args;
	int lit;

	va_start(args, op);
	/* presumably prepares args to be read as an int array; confirm against satvafix */
	satvafix(args);
	lit = satlogic1(sat, op, (int*)args, -1);
	va_end(args);
	return lit;
}