#include <setjmp.h>
#include "diald.h"

#define TOK_LE 256
#define TOK_GE 257
#define TOK_NE 258
#define TOK_INET 259
#define TOK_STR 260
#define TOK_NUM 261
#define TOK_ERR 262
#define TOK_EOF 263
#define ADVANCE token = token->next

/*
 * Names of the protocol rules defined so far.  A rule's index in this
 * array is the value parse_prule_name() stores in FW_Filter.prule.
 * Only slots [0, nprules) are live: parse_new_prule_name() fills
 * prules[nprules].name and parse_prule() bumps nprules once the whole
 * rule has parsed.
 */
struct prule {
    char *name;		/* strdup()ed rule name */
} prules[FW_MAX_PRULES];
static int nprules = 0;	/* number of live entries in prules[] */

/*
 * A named packet field declared with a "var" statement.  offset comes from
 * parse_offset(); offset, shift and mask are copied into an FW_Term whenever
 * the variable is referenced (parse_varref()).  Variables form a singly
 * linked list headed by vars, with the newest first; flush_vars() frees it.
 */
static struct var {
   char *name;		/* strdup()ed variable name */
   int offset;		/* encoded offset as returned by parse_offset() */
   int shift;		/* bit shift from the optional "(<shift>)" suffix, default 0 */
   unsigned int mask;	/* mask from the optional "&<mask>" suffix, default 0xffffffff */
   struct var *next;
} *vars = 0;


/* One lexical token produced by tokenize(). */
typedef struct token {
    int offset;		/* byte offset of the token's text within errstr */
    int type;		/* TOK_* value, or the character itself for single-char tokens */
    char *str;		/* malloc()ed, NUL-terminated copy of the token's text */
    struct token *next;	/* next token; the list always ends with a TOK_EOF token */
} Token;

char *errstr;		/* merged argument string being parsed; echoed by parse_error() */
Token *tlist;		/* head of the token list for the current statement */
Token *token;		/* parse cursor: the token currently being examined */
char *context;		/* statement name ("prule", "accept", ...) used in error messages */

static jmp_buf unwind;	/* parse_error() longjmp()s here to abandon the current statement */

/*
 * Report a parse failure and abandon the current statement.
 *
 * Logs the statement context, the offending token and the caller's
 * message, then the whole merged input string, and finally longjmp()s
 * back to the setjmp(unwind) in the active parse_* entry point.
 * Never returns.
 */
void parse_error(char *s)
{
    char *offending = token->str;

    syslog(LOG_ERR, "%s parsing error. Got token '%s'. %s",
           context, offending, s);
    syslog(LOG_ERR, "parse string: '%s'", errstr);
    longjmp(unwind, 1);
}

void tokenize(char *cntxt, int argc, char **argv)
{
    char *s, *t;
    int i, len;
    Token *prev = 0, *new;

    context = cntxt;
    /* merge the arguments into one string */

    for (len = i = 0; i < argc; i++)
	len += strlen(argv[i])+1;
    t = errstr = malloc(len);
    if (errstr == 0) { syslog(LOG_ERR,"Out of memory! AIIEEE!"); die(1); }
    strcpy(errstr,argv[0]);
    for (i = 1; i < argc; i++) { strcat(errstr," "); strcat(errstr,argv[i]); }

    tlist = 0;

    for (s = errstr; *s;) {
	new = malloc(sizeof(Token));
	if (new == 0) { syslog(LOG_ERR,"Out of memory! AIIEEE!"); die(1); }
        if (prev == 0) tlist = new; else prev->next = new;
	prev = new;
	new->next = 0;
	new->offset = s-errstr;
	if (*s == '<' && s[1] == '=') {
	    new->type = TOK_LE; s += 2;
	} else if (*s == '>' && s[1] == '=') {
	    new->type = TOK_GE; s += 2;
	} else if (*s == '!' && s[1] == '=') {
	    new->type = TOK_NE; s += 2;
	} else if (isalpha(*s) || *s == '.' || *s == '_' || *s == '-') {
	    new->type = TOK_STR;
	    while (isalnum(*s) || *s == '.' || *s == '_' || *s == '-') s++;
	} else if (*s == '0' && s[1] == 'x' && isxdigit(s[2])) {
	    new->type = TOK_NUM;
	    s += 2;
	    while (isxdigit(*s)) s++;
	} else if (*s == '0' && isdigit(s[1])) {
	    new->type = TOK_NUM;
	    s++;
	    while (isdigit(*s)) s++;
	} else if (isdigit(*s)) {
	    while (isdigit(*s)) s++;
	    if (*s == '.') {
	        new->type = TOK_INET;
		s++;
		if (!isdigit(*s)) goto tokerr;
		while (isdigit(*s)) s++;
		if (*s != '.') goto tokerr;
		s++;
		if (!isdigit(*s)) goto tokerr;
		while (isdigit(*s)) s++;
		if (*s != '.') goto tokerr;
		s++;
		if (!isdigit(*s)) goto tokerr;
		while (isdigit(*s)) s++;
	        if (*s == '.') s++;
	        goto done;
tokerr:
		new->type = TOK_ERR;
	    } else {
		new->type = TOK_NUM;
	    }
	} else {
	    new->type = *s++;
	}
done:
	len = (s-errstr)-new->offset;
	new->str = malloc(len+1);
	if (new->str == 0) { syslog(LOG_ERR,"Out of memory! AIIEEE!"); die(1); }
	strncpy(new->str,errstr+new->offset,len);
	new->str[len] = 0;
    }
    new = malloc(sizeof(Token));
    if (new == 0) { syslog(LOG_ERR,"Out of memory! AIIEEE!"); die(1); }
    if (prev == 0) tlist = new; else prev->next = new;
    prev = new;
    new->next = 0;
    new->offset = s-errstr;
    new->type = TOK_EOF;
    new->str = strdup("");
#ifdef 0
    token = tlist;
    while (token) {
	printf("'%s', ",token->str);
	token = token->next;
    }
    printf("\n");
#endif
    token = tlist;
}

void free_tokens(void)
{
    Token *next;
    if (token && token->type != TOK_EOF)
	syslog(LOG_ERR,
	    "Parsing error. Got token '%s' when end of parse was expected.",
	    token->str);
    while (tlist) {
	next = tlist->next;
	free(tlist->str);
	free(tlist);
	tlist = next;
    }
    tlist = 0;
    free(errstr);
}


/* Reset a protocol rule to its default: protocol number 0. */
void init_prule(FW_ProtocolRule *rule)
{
    (*rule).protocol = 0;
}

/*
 * Reset a filter to its defaults: protocol rule 0, direction FW_DIR_BOTH,
 * logging off, no terms and a zero timeout.  The accept flag is left for
 * the caller to set.
 */
void init_filter(FW_Filter *filter)
{
    filter->timeout = 0;
    filter->count = 0;
    filter->log = 0;
    filter->dir = FW_DIR_BOTH;
    filter->prule = 0;
}

/* Consume exactly one ' ' token; anything else is a parse error. */
void parse_whitespace(void)
{
    if (token->type == ' ') {
        ADVANCE;
        return;
    }
    parse_error("Expecting whitespace");
}

/*
 * Parse the name of a new protocol rule and store a copy of it in the next
 * free slot, prules[nprules].  The caller increments nprules only after the
 * whole rule has parsed successfully.
 *
 * Errors (via parse_error): the token is not a string, the name is already
 * defined, prules[] is full, or out of memory.
 */
void parse_new_prule_name(void)
{
    int i;

    if (token->type != TOK_STR) parse_error("Expecting a string.");
    for (i = 0; i < nprules; i++)
        if (strcmp(token->str,prules[i].name) == 0)
            parse_error("Rule name already defined.");
    /* prules[] has a fixed size; never write past its last slot. */
    if (nprules >= (int)(sizeof prules / sizeof prules[0]))
        parse_error("Too many protocol rules defined.");
    prules[nprules].name = strdup(token->str);
    if (prules[nprules].name == 0) parse_error("Out of memory.");
    ADVANCE;
}

/*
 * Parse the protocol field of a "prule" statement into prule->protocol:
 *   "any"            -> 255
 *   a protocol name  -> its number from getprotobyname()
 *   a number         -> that number, which must be in 0-254
 *                       (255 is the value stored for "any")
 * Any other token is a parse error.
 */
void parse_protocol_name(FW_ProtocolRule *prule)
{
    if (token->type == TOK_STR) {
        struct protoent *proto;

        if (strcmp(token->str,"any") == 0)
            { prule->protocol = 255; ADVANCE; return; }
        if ((proto = getprotobyname(token->str)))
            { prule->protocol = proto->p_proto; ADVANCE; return; }
        parse_error("Expecting a protocol name or 'any'.");
    } else if (token->type == TOK_NUM) {
        int p = -1;

        if (sscanf(token->str,"%i",&p) != 1 || p < 0 || p > 254)
            parse_error("Expecting number from 0-254.");
        prule->protocol = p;
        ADVANCE;
    } else {
        parse_error("Expecting a string or a number.");
    }
}

/*
 * Parse an offset specification, "<num>" or "+<num>".
 *
 * The numeric part must be unchanged by FW_OFFSET(), or it is out of
 * range.  A plain number is encoded with FW_IP_OFFSET() and a '+'-prefixed
 * one with FW_DATA_OFFSET().  Returns the encoded offset.
 */
int parse_offset(void)
{
    int is_data = (token->type == '+');
    int v;

    if (is_data) ADVANCE;
    if (token->type != TOK_NUM)
        parse_error("Expecting an offset definition: <num> or +<num>.");
    sscanf(token->str,"%i",&v);
    ADVANCE;
    if (FW_OFFSET(v) != v) parse_error("Offset definition out of range.");
    return is_data ? FW_DATA_OFFSET(v) : FW_IP_OFFSET(v);
}

/*
 * Parse the identification spec of a protocol rule: FW_ID_LEN offsets
 * separated by ':', stored in prule->codes[0..FW_ID_LEN-1].
 */
void parse_prule_spec(FW_ProtocolRule *prule)
{
    int k = 1;

    prule->codes[0] = parse_offset();
    while (k < FW_ID_LEN) {
        if (token->type != ':') parse_error("Expecting ':'");
        ADVANCE;
        prule->codes[k++] = parse_offset();
    }
}

/*
 * Look up an already-defined protocol rule by name and store its index in
 * prules[] into filter->prule.  Unknown names and non-string tokens are
 * parse errors.
 */
void parse_prule_name(FW_Filter *filter)
{
    int idx;

    if (token->type != TOK_STR) parse_error("Expecting a string.");
    for (idx = 0; idx < nprules && strcmp(token->str,prules[idx].name) != 0; idx++)
        ;
    if (idx == nprules) parse_error("Not a known protocol rule.");
    filter->prule = idx;
    ADVANCE;
}

/*
 * Parse a filter timeout (a number) into filter->timeout.  The accepted
 * range is 0 to 2^(8*sizeof filter->timeout) - 1.
 * NOTE(review): the shift is undefined if filter->timeout is as wide as
 * int or wider; presumably the field is narrower. Confirm against FW_Filter.
 */
void parse_timeout(FW_Filter *filter)
{
    int secs;

    if (token->type != TOK_NUM)
        parse_error("Expecting a number.");
    sscanf(token->str,"%i",&secs);
    if (secs < 0 || secs > (1 << sizeof(filter->timeout)*8) - 1)
        parse_error("Out of acceptable range for a timeout.");
    ADVANCE;
    filter->timeout = secs;
}

/* <rvalue> ::= <num> | <name> | <inet> */
/*
 * Parse a constant value and return it in host byte order:
 *   <num>  - integer literal, read with %i (decimal, leading-0 octal, 0x hex)
 *   <inet> - dotted-quad address
 *   <name> - a protocol name (its protocol number), or "udp.<service>" /
 *            "tcp.<service>" (the service's port number)
 * Anything else is a parse error.
 */
int parse_rvalue(void)
{
    int v;

    if (token->type == TOK_NUM) {
        sscanf(token->str,"%i",&v);
        ADVANCE; return v;
    } else if (token->type == TOK_INET) {
        struct in_addr addr;

        /* inet_addr() returns -1 on error, which is also the value of the
           valid address 255.255.255.255; inet_aton() reports errors separately. */
        if (inet_aton(token->str, &addr) == 0)
            parse_error("Bad inet address specification.");
        v = ntohl(addr.s_addr);
        ADVANCE; return v;
    } else if (token->type == TOK_STR) {
        struct protoent *proto;
        struct servent *serv;

        if ((proto = getprotobyname(token->str))) {
            ADVANCE; return proto->p_proto;
        } else if (strncmp("udp.",token->str,4) == 0) {
            if ((serv = getservbyname(token->str+4,"udp"))) {
                /* s_port is in network byte order */
                ADVANCE; return ntohs(serv->s_port);
            }
            parse_error("Not a known udp service port.");
        } else if (strncmp("tcp.",token->str,4) == 0) {
            if ((serv = getservbyname(token->str+4,"tcp"))) {
                ADVANCE; return ntohs(serv->s_port);
            }
            parse_error("Not a known tcp service port.");
        }
        parse_error("Not a known value name.");
    } else {
        parse_error("Expecting an <rvalue> specification.");
    }
    return 0; /* NOTREACHED */
}


/* <varspec> ::= <offset> [(<shift>)] [&<mask>] */
/*
 * Fill in a variable's offset, bit shift and mask from a "var" statement.
 * The shift defaults to 0 and must lie in [0,31].  The mask defaults to
 * 0xffffffff.
 */
void parse_varspec(struct var *variable)
{
    int shift = 0;

    variable->offset = parse_offset();
    if (token->type == '(') {
        ADVANCE;
        if (token->type != TOK_NUM)
            parse_error("Expecting a bit shift value.");
        /* An out-of-range %i input (e.g. 0xffffffff) is not guaranteed to
           yield a non-negative int, so check both ends of the range. */
        if (sscanf(token->str,"%i",&shift) != 1 || shift < 0 || shift > 31)
            parse_error("Shift value must be in [0,31].");
        ADVANCE;
        if (token->type != ')') parse_error("Expecting a ')'.");
        ADVANCE;
    }
    variable->shift = shift;
    if (token->type == '&') {
        ADVANCE;
        variable->mask = parse_rvalue();
    } else {
        variable->mask = 0xffffffffU;
    }
}

/*
 * Parse the name of a new variable and store a strdup()ed copy in
 * variable->name.  The token must be a string that does not already name
 * an entry in the vars list.
 */
void parse_var_name(struct var *variable)
{
    struct var *v;

    if (token->type != TOK_STR)
        parse_error("Expecting a variable name.");
    for (v = vars; v != 0; v = v->next)
        if (strcmp(v->name,token->str) == 0)
            parse_error("Expecting a new variable name");
    variable->name = strdup(token->str);
    ADVANCE;
}

/* <varref> ::= <name> */
/*
 * Resolve a reference to a previously declared variable and copy its
 * offset, shift and mask into *term.
 */
void parse_varref(FW_Term *term)
{
    struct var *v;

    if (token->type != TOK_STR)
        parse_error("Expecting a variable name.");
    for (v = vars; v != 0 && strcmp(v->name,token->str) != 0; v = v->next)
        ;
    if (v == 0)
        parse_error("Not a known variable name.");
    term->offset = v->offset;
    term->shift = v->shift;
    term->mask = v->mask;
    ADVANCE;
}

/* <lvalue> ::= <varref> | <varref>&<rvalue> */
/*
 * Parse a variable reference into *term.  An optional "&<rvalue>" suffix
 * is ANDed into the variable's own mask.
 */
void parse_lvalue(FW_Term *term)
{
    parse_varref(term);
    if (token->type != '&')
        return;
    ADVANCE;
    term->mask &= parse_rvalue();
}

/*
 * If the current token is a comparison operator ("!=", "=", ">=", "<="),
 * store the matching FW_* op in term->op, consume the token and return 1.
 * Otherwise leave everything untouched and return 0.
 */
int parse_op(FW_Term *term)
{
    switch (token->type) {
    case TOK_NE: term->op = FW_NE; break;
    case '=':    term->op = FW_EQ; break;
    case TOK_GE: term->op = FW_GE; break;
    case TOK_LE: term->op = FW_LE; break;
    default:     return 0;
    }
    ADVANCE;
    return 1;
}

/* <term> ::= <lvalue> | !<lvalue> | <lvalue> <op> <rvalue> */
/*
 * Parse one term and append it to filter->terms[], incrementing
 * filter->count.  A bare <lvalue> is stored as "lvalue != 0" and
 * "!<lvalue>" as "lvalue == 0".  It is a parse error if terms[] is
 * already full.
 */
void parse_term(FW_Filter *filter)
{
    FW_Term *term;

    /* terms[] has a fixed size; refuse to write past its end. */
    if (filter->count >= sizeof filter->terms / sizeof filter->terms[0])
        parse_error("Too many terms in filter.");
    term = &filter->terms[filter->count];
    if (token->type == '!') {
        ADVANCE;
        parse_lvalue(term);
        term->op = FW_EQ;
        term->test = 0;
    } else {
        parse_lvalue(term);
        if (parse_op(term)) {
            term->test = parse_rvalue();
        } else {
            term->op = FW_NE;
            term->test = 0;
        }
    }
    filter->count++;
}

/*
 * Parse the term list of a filter: either the word "any" (no terms) or
 * one or more terms separated by ','.
 */
void parse_terms(FW_Filter *filter)
{
    if (token->type == TOK_STR && strcmp(token->str,"any") == 0) {
        ADVANCE;
        return;
    }
    for (;;) {
        parse_term(filter);
        if (token->type != ',')
            break;
        ADVANCE;
    }
}

/*
 * Handle a "prule" statement.  argv[0..2] are joined with spaces and
 * parsed as "<new-name> <protocol> <spec>".  On success the rule name is
 * kept in prules[], nprules is incremented and the rule is passed to
 * ctl_firewall(IP_FW_APRULE, ...).  On a parse error (already logged by
 * parse_error) the statement is discarded.
 */
void parse_prule(void *var, char **argv)
{
    struct firewall_req req;
    FW_ProtocolRule prule;

    tokenize("prule",3,argv);
    if (setjmp(unwind) != 0) {
        token = 0;		/* skip free_tokens()' leftover-token message */
        free_tokens();
        return;
    }
    parse_new_prule_name();
    parse_whitespace();
    parse_protocol_name(&prule);
    parse_whitespace();
    parse_prule_spec(&prule);
    free_tokens();
    nprules++;

    /* Save the prule in the kernel */
    req.fw_arg.rule = prule;
    req.unit = fwunit;
    ctl_firewall(IP_FW_APRULE,&req);
}

/*
 * Handle an "accept" statement.  argv[0..2] are joined with spaces and
 * parsed as "<prule-name> <timeout> <terms>".  On success the filter
 * (accept = 1) is passed to ctl_firewall(IP_FW_AFILT, ...).  On a parse
 * error (already logged by parse_error) the statement is discarded.
 */
void parse_accept(void *var, char **argv)
{
    struct firewall_req req;
    FW_Filter filter;

    init_filter(&filter);
    filter.accept = 1;
    tokenize("accept",3,argv);
    if (setjmp(unwind) != 0) {
        token = 0;		/* skip free_tokens()' leftover-token message */
        free_tokens();
        return;
    }
    parse_prule_name(&filter);
    parse_whitespace();
    parse_timeout(&filter);
    parse_whitespace();
    parse_terms(&filter);
    free_tokens();

    /* Save the filter the kernel */
    req.fw_arg.filter = filter;
    req.unit = fwunit;
    ctl_firewall(IP_FW_AFILT,&req);
}

/*
 * Handle a "reject" statement.  argv[0..1] are joined with a space and
 * parsed as "<prule-name> <terms>".  On success the filter (accept = 0,
 * timeout left at 0) is passed to ctl_firewall(IP_FW_AFILT, ...).  On a
 * parse error (already logged by parse_error) the statement is discarded.
 */
void parse_reject(void *var, char **argv)
{
    struct firewall_req req;
    FW_Filter filter;

    init_filter(&filter);
    filter.accept = 0;
    tokenize("reject",2,argv);
    if (setjmp(unwind) != 0) {
        token = 0;		/* skip free_tokens()' leftover-token message */
        free_tokens();
        return;
    }
    parse_prule_name(&filter);
    parse_whitespace();
    parse_terms(&filter);
    free_tokens();

    /* Save the filter the kernel */
    req.fw_arg.filter = filter;
    req.unit = fwunit;
    ctl_firewall(IP_FW_AFILT,&req);
}

/*
 * Handle a "var" statement.  argv[0..1] are joined with a space and parsed
 * as "<new-name> <varspec>".  On success the new variable is pushed onto
 * the front of the vars list.  On a parse error (already logged by
 * parse_error) the partially built variable is freed and nothing is added.
 */
void parse_var(void *var, char **argv)
{
    struct var *variable = malloc(sizeof *variable);
    if (variable == 0) { syslog(LOG_ERR,"Out of memory! AIIEEE!"); die(1); }
    /* parse_var_name() sets name only on success; start from NULL so the
       error path can free it unconditionally. */
    variable->name = 0;
    tokenize("var",2,argv);
    if (setjmp(unwind)) {
        token = 0;		/* skip free_tokens()' leftover-token message */
        free_tokens();
        /* `variable` itself is not modified after setjmp(), so it is
           still valid here; release the half-built entry. */
        free(variable->name);
        free(variable);
        return;
    }
    parse_var_name(variable);
    parse_whitespace();
    parse_varspec(variable);
    free_tokens();
    /* add the new variable to the linked list */
    variable->next = vars;
    vars = variable;
}

void flush_prules(void)
{
    struct firewall_req req;
    req.unit = fwunit;
    ctl_firewall(IP_FW_PFLUSH,&req);
    nprules = 0;
}

void flush_vars(void)
{
    struct var *next;
    for (; vars; vars = next) {
	next = vars->next;
	free(vars->name);
	free(vars);
    }
    vars = 0;
}

void flush_filters(void)
{
    struct firewall_req req;
    req.unit = fwunit;
    ctl_firewall(IP_FW_FFLUSH,&req);
}
