/*
 * Filename: msql-import.c
 * Project:  msql-import
 *
 * Function: This program reads the contents of a flat file
 *           and loads it into a Mini SQL table.
 *           The program does not create the MiniSQL table,
 *           it must have been created beforehand.  msql-import
 *           simply sends INSERTs to the database server.
 *
 *           When importing a flat file that was created in
 *           the DOS world, make sure you convert the file
 *           to unix by replacing CR/LF with LF only.
 *
 * Date:     March 1996
 *
 * Author:   Pascal Forget <pascal@wsc.com>
 *
 * Copyright (C) 1995-1996 Pascal Forget.  All Rights Reserved
 *
 * PASCAL FORGET MAKES NO REPRESENTATIONS ABOUT THE SUITABILITY
 * OF THIS SOFTWARE FOR ANY PURPOSE.  IT IS SUPPLIED "AS IS"
 * WITHOUT EXPRESS OR IMPLIED WARRANTY.
 *
 * Permission to use, copy, modify, and distribute this software
 * and its documentation for any purpose and without fee is
 * hereby granted, provided that the above copyright notice
 * appear in all copies of the software.
 */

#include "msql-import.h"
#include <string.h>

#define MSQL_IMPORT_VERSION "0.0.9"
#define IMPORT_MAX_FIELDS 255

/*
 * Larger of two values.  Every argument and the whole expansion are
 * parenthesized so that MAX(x + 1, y) and 2 * MAX(a, b) expand as
 * expected.  Each argument may be evaluated twice, so do not pass
 * expressions with side effects.
 */
#ifndef MAX
#define MAX(a, b) ((a) > (b) ? (a) : (b))
#endif

#define DEBUG 1

/* Scratch string in which insert_record() builds each INSERT
 * statement; it is created by import_file(). */
SafeString *query;

/* Connection handle returned by msqlConnect() in main(); every
 * msql*() call in this file goes through it. */
int sock;

/*****************************************************************
 * append_character() adds the single character c to the end of  *
 * str, growing the buffer through set_string_capacity() and     *
 * keeping the result NUL-terminated.                            *
 *****************************************************************/

void
append_character(SafeString *str, char c)
{
    /* A string without a buffer is treated as empty. */
    int end = (str->buffer == NULL) ? 0 : (int)strlen(str->buffer);

    /* One more character; set_string_capacity() reserves the NUL byte. */
    set_string_capacity(str, end + 1);

    str->buffer[end++] = c;
    str->buffer[end] = '\0';
}

/*****************************************************************
 * append_string_buffer() appends the C string newBuffer to str. *
 *                                                               *
 * A NULL or empty newBuffer leaves str unchanged.  A str whose  *
 * buffer is NULL is treated as the empty string.                *
 *****************************************************************/

void
append_string_buffer(SafeString *str, const char *newBuffer)
{
    size_t oldLength;
    size_t addLength;

    if (newBuffer == NULL)
    {
        return;
    }

    addLength = strlen(newBuffer);
    if (addLength == 0)
    {
        /* Nothing to add; this also avoids writing "" through a NULL buffer. */
        return;
    }

    if (str[0].buffer == NULL)
    {
        /* A missing buffer holds nothing, whatever capacity says.  Reset
         * it so set_string_capacity() is guaranteed to allocate below. */
        str[0].capacity = 0;
        oldLength = 0;
    }
    else
    {
        oldLength = strlen(str[0].buffer);
    }

    set_string_capacity(str, (unsigned int)(oldLength + addLength));

    /* Copy the new text and its NUL terminator right after the old text. */
    memcpy(str[0].buffer + oldLength, newBuffer, addLength + 1);
}

/*****************************************************************
 * copy_string_buffer() replaces the contents of str with a copy *
 * of newBuffer.                                                 *
 *                                                               *
 * A NULL newBuffer releases str's buffer and resets it to the   *
 * "no buffer, capacity 0" state.  The buffer is assumed to be   *
 * heap memory owned by str, since set_string_capacity()         *
 * realloc()s it.                                                *
 *****************************************************************/

void
copy_string_buffer(SafeString *str, const char *newBuffer)
{
    size_t length;

    if (newBuffer == NULL)
    {
        /* Free the old buffer instead of leaking it, and clear the
         * capacity so a later set_string_capacity() reallocates
         * rather than trusting a stale value. */
        free(str[0].buffer);
        str[0].buffer = NULL;
        str[0].capacity = 0;
        return;
    }

    if (str[0].buffer == NULL)
    {
        str[0].capacity = 0;  /* a missing buffer has no usable capacity */
    }

    length = strlen(newBuffer);

    /* Ask for at least one character so that copying "" into a
     * string without a buffer still allocates one to write into. */
    set_string_capacity(str, (unsigned int)(length > 0 ? length : 1));
    memcpy(str[0].buffer, newBuffer, length + 1);
}

/*****************************************************************
 * create_string() allocates an empty SafeString (no buffer,     *
 * capacity 0).  The caller owns the result.  Exits the program  *
 * if memory cannot be obtained.                                 *
 *****************************************************************/

SafeString *
create_string(void)
{
    SafeString *string = (SafeString *)malloc(sizeof(*string));

    if (string == NULL)
    {
        fprintf(stderr, "msql-import: out of memory. Exiting.\n");
        exit(1);
    }

    string[0].buffer = (char *)NULL;
    string[0].capacity = 0;
    return string;
}

/*****************************************************************
 * set_string_buffer() makes str use newBuffer as its storage,   *
 * recording strlen(newBuffer) as the capacity (0 for NULL).     *
 * The pointer is stored as is, not copied.  NOTE(review): it    *
 * presumably must be heap memory, since set_string_capacity()   *
 * may later realloc() it.                                       *
 *****************************************************************/

void
set_string_buffer(SafeString *str, char *newBuffer)
{
    str->capacity = (newBuffer == NULL) ? 0 : strlen(newBuffer);
    str->buffer = newBuffer;
}

/*****************************************************************
 * set_string_capacity() makes sure str can hold at least cap    *
 * characters plus the terminating NUL (the buffer is cap + 1    *
 * bytes).  An existing buffer keeps its contents when grown.  A *
 * newly allocated buffer holds the empty string.  After this    *
 * call str's buffer is never NULL.                              *
 *                                                               *
 * Exits the program if memory cannot be obtained, because every *
 * caller writes into the buffer right after this call.          *
 *****************************************************************/

void
set_string_capacity(SafeString *str, unsigned int cap)
{
    char *newBuffer;

    if (str[0].buffer == NULL)
    {
        /* Allocate whenever there is no buffer, even if the recorded
         * capacity already looks large enough (it may be stale) or
         * cap is 0: callers still need somewhere to write. */
        newBuffer = (char *)malloc(cap + 1);
        if (newBuffer == NULL)
        {
            fprintf(stderr, "msql-import: out of memory. Exiting.\n");
            exit(1);
        }
        newBuffer[0] = '\0';
        str[0].buffer = newBuffer;
        str[0].capacity = cap;
    }
    else if (str[0].capacity < cap)
    {
        /* Use a temporary so a failed realloc() does not lose the
         * original block. */
        newBuffer = (char *)realloc(str[0].buffer, cap + 1);
        if (newBuffer == NULL)
        {
            fprintf(stderr, "msql-import: out of memory. Exiting.\n");
            exit(1);
        }
        str[0].buffer = newBuffer;
        str[0].capacity = cap;
    }
}

/*****************************************************************
 * safe_c_string_copy() returns a freshly malloc()ed duplicate   *
 * of s, or NULL when s is NULL or the allocation fails.  The    *
 * caller owns the result.                                       *
 *                                                               *
 * Used instead of strdup() because strdup() is not available on *
 * every Unix, and some implementations crash on a NULL          *
 * argument.                                                     *
 *****************************************************************/

char *
safe_c_string_copy(const char *s)
{
    char *copy = NULL;

    if (s != NULL)
    {
        size_t size = strlen(s) + 1;   /* characters plus the NUL */

        copy = (char *)malloc(size);
        if (copy != NULL)
        {
            memcpy(copy, s, size);
        }
    }

    return copy;
}

/**********************************************************************/
/* File Id:                     strparse                              */
/* Author:                      Stan Milam.                           */
/* Date Written:                20-Feb-1995.                          */
/* Description:                                                       */
/*     The str_parse() function is used to extract fields from de-    */
/*     limited ASCII records. It is designed to deal with empty fields*/
/*     in a logical manner.                                           */
/*                                                                    */
/* Arguments:                                                         */
/*     char **str       - The address of a pointer which in turn      */
/*                        points to the string being parsed. The      */
/*                        actual pointer is modified with each call to*/
/*                        point to the beginning of the next field.   */
/*     char *delimiters - The address of the string containing the    */
/*                        characters used to delimit the fields within*/
/*                        the record.                                 */
/*                                                                    */
/* Return Value:                                                      */
/*     A pointer of type char which points to the current field in the*/
/*     parsed string.  If an empty field is encountered the address   */
/*     is that of an empty string (i.e. "" ). When there are no more  */
/*     fields in the record a NULL pointer value is returned.         */
/*                                                                    */
/**********************************************************************/

char *
str_parse( char **str, char *delimiters )
{
    char *field;
    char *stop;

    /* Nothing left to split, or no delimiter set to split with.
     * Note that a trailing empty field ("a,") is therefore not
     * reported: once the remainder is "", NULL is returned. */
    if (*str == NULL || (*str)[0] == '\0' ||
        delimiters == NULL || delimiters[0] == '\0')
    {
        return NULL;
    }

    field = *str;
    stop = strpbrk(field, delimiters);

    if (stop != NULL)
    {
        /* Terminate this field in place and resume after the delimiter. */
        *stop = '\0';
        *str = stop + 1;
    }
    else
    {
        /* Last field: leave the cursor on the terminating NUL. */
        *str = field + strlen(field);
    }

    return field;
}

/*****************************************************************
 * abort_import() closes the msql connection if one is open,     *
 * then terminates the program with exitCode.                    *
 *****************************************************************/

void
abort_import(int exitCode)
{
    /* main() treats a sock of -1 from msqlConnect() as failure. */
    int connected = (sock >= 0);

    if (connected)
    {
        msqlClose(sock);
    }

    exit(exitCode);
}

/*****************************************************************
 * alarm() prints an error message then calls abort_import()     *
 *****************************************************************/

void
alarm_msql(void)
{
    fprintf(stderr, "msql-import error: %s\n", msqlErrMsg);
    abort_import(-1);
}

/*
 * alarm_msg() prints message on its own line to stderr, then
 * calls abort_import(-1).
 */
void
alarm_msg(const char *message)
{
    const int status = -1;

    (void)fprintf(stderr, "%s\n", message);
    abort_import(status);
}

/*****************************************************************
 * append_string() returns s concatenated with append.  The      *
 * result may be at a different address than s, so always use   *
 * the returned pointer:  s = append_string(s, tail);            *
 *                                                               *
 * A NULL s is treated as the empty string, so a fresh copy of   *
 * append is returned.  A NULL or empty append returns s         *
 * unchanged.  If memory cannot be obtained, s is freed and NULL *
 * is returned, so the caller never keeps a pointer to a leaked  *
 * or released block.                                            *
 *****************************************************************/

char *
append_string(char *s, const char *append)
{
    size_t oldLength;
    size_t addLength;
    char *grown;

    if ((append == NULL) || ((addLength = strlen(append)) == 0))
    {
        return s;
    }

    oldLength = (s != NULL) ? strlen(s) : 0;

    /* realloc(NULL, n) behaves like malloc(n), which covers s == NULL. */
    grown = (char *)realloc(s, oldLength + addLength + 1);
    if (grown == NULL)
    {
        free(s);
        return NULL;
    }

    /* Copy append and its NUL terminator over s's old terminator. */
    memcpy(grown + oldLength, append, addLength + 1);
    return grown;
}

/*****************************************************************
 * backslashify_special_chars() returns a newly allocated copy   *
 * of s in which every backslash, single quote and double quote  *
 * is preceded by a backslash, so that the value can be placed   *
 * inside a quoted string literal of an INSERT statement.  All   *
 * other characters, parentheses included, are copied as is.    *
 *                                                               *
 * Returns NULL if s is NULL or memory cannot be obtained.  The  *
 * caller owns (and must free) the returned string.              *
 *****************************************************************/

#define BACKSLASH_CHAR 92

char *
backslashify_special_chars(const char *s)
{
    static const char special_chars[] = { BACKSLASH_CHAR, '\'', '\"', '\0' };
    const char *src;
    char *result;
    char *dst;
    size_t extra = 0;

    if (s == NULL)
    {
        return NULL;
    }

    /* First pass: count how many backslashes must be inserted, so the
     * result can be allocated once at its final size. */
    for (src = s; *src != '\0'; src++)
    {
        if (strchr(special_chars, *src) != NULL)
        {
            extra++;
        }
    }

    result = (char *)malloc(strlen(s) + extra + 1);
    if (result == NULL)
    {
        return NULL;
    }

    /* Second pass: copy forward, emitting a backslash before each
     * special character.  Building a new string this way avoids
     * shifting characters around inside a growing buffer. */
    for (src = s, dst = result; *src != '\0'; src++)
    {
        if (strchr(special_chars, *src) != NULL)
        {
            *dst++ = BACKSLASH_CHAR;
        }
        *dst++ = *src;
    }
    *dst = '\0';

    return result;
}

/*****************************************************************
 * datatypes() returns an array of the field datatypes.          *
 *                                                               *
 * msql-import uses this to decide whether to enclose a value    *
 * in quotes when building an INSERT.                            *
 *                                                               *
 * Result layout: types[0] holds the number of fields N, and     *
 * types[1..N] hold their mSQL type codes.  The order follows    *
 * the comma-separated list in fields (blanks around each name   *
 * are ignored), or the table definition when fields is NULL.    *
 * At most IMPORT_MAX_FIELDS fields are supported.  The caller   *
 * owns the returned array.  Each field and its type are also    *
 * printed on stdout.                                            *
 *                                                               *
 * Exits through alarm_msql()/alarm_msg() if the table's fields  *
 * cannot be listed, a requested field does not exist, there are *
 * too many fields, or memory runs out.                          *
 *****************************************************************/

/* Printable name of an mSQL type code; unknown codes map to "unknown". */
static const char *
datatype_name(int type)
{
    static const char *const names[] =
        {"unknown", "int", "char", "real", "ident", "null"};

    if (type < 0 || type >= (int)(sizeof(names) / sizeof(names[0])))
    {
        return names[0];
    }
    return names[type];
}

/* Strip leading and trailing blanks (spaces and tabs) from name in place. */
static char *
trim_field_name(char *name)
{
    char *end;

    while (*name == ' ' || *name == '\t')
    {
        name++;
    }

    end = name + strlen(name);
    while (end > name && (end[-1] == ' ' || end[-1] == '\t'))
    {
        *--end = '\0';
    }
    return name;
}

/* Print field and store its type code in the next slot of types[]. */
static void
record_datatype(int *types, int *fieldCount, m_field *field,
                const char *afterName)
{
    if (*fieldCount >= IMPORT_MAX_FIELDS)
    {
        /* types[] has room for IMPORT_MAX_FIELDS entries after the count. */
        alarm_msg("msql-import: too many fields. Exiting.");
    }

    printf(" Field: %s%s", field->name, afterName);
    printf(" Type: %s\n", datatype_name((int)field->type));

    types[++(*fieldCount)] = field->type;
}

int *
datatypes(const char *table, const char *fields)
{
    int fieldCount = 0;
    int *types;
    char *tableCopy;
    m_result *result;
    m_field *field;

    types = (int *)malloc((IMPORT_MAX_FIELDS + 1) * sizeof(int));
    if (types == NULL)
    {
        alarm_msg("msql-import: out of memory. Exiting.");
    }

    /* NOTE(review): the copy is presumably needed because
     * msqlListFields() takes a non-const char *. */
    tableCopy = safe_c_string_copy(table);
    result = msqlListFields(sock, tableCopy);
    free(tableCopy);

    if (result == NULL)
    {
        alarm_msql();
    }

    if (fields != NULL)
    {
        char *fieldsCopy = safe_c_string_copy(fields);
        /* str_parse() advances its cursor, so parse through a separate
         * pointer and keep fieldsCopy for free(). */
        char *cursor = fieldsCopy;
        char *token;

        while ((token = str_parse(&cursor, ",")) != NULL)
        {
            const char *wanted = trim_field_name(token);

            msqlFieldSeek(result, 0);
            while (((field = msqlFetchField(result)) != NULL) &&
                   (strcmp(field->name, wanted) != 0))
            {
                /* keep scanning the table definition */
            }

            if (field == NULL)
            {
                alarm_msg("Field definition not found. Exiting.");
            }
            else
            {
                record_datatype(types, &fieldCount, field, " ");
            }
        }

        free(fieldsCopy);
    }
    else
    {
        msqlFieldSeek(result, 0);

        while ((field = msqlFetchField(result)) != NULL)
        {
            record_datatype(types, &fieldCount, field, "\t");
        }
    }

    msqlFreeResult(result);

    /* Put the number of fields in the first position in the array */
    types[0] = fieldCount;

    printf("\n");
    return types;
}

/*****************************************************************
 * get_record() reads the next non-empty record from fp into     *
 * record and returns record's buffer.  A record is everything   *
 * up to record_delimiter or end of file.  Returns NULL once end *
 * of file is reached without reading a character.               *
 *                                                               *
 * Empty records (consecutive delimiters, e.g. blank lines) are  *
 * skipped instead of ending the import.  NUL bytes in the input *
 * are dropped, since they cannot be carried in a C string.      *
 *****************************************************************/

#define GET_REC_BLOCKSIZE 8192

char *
get_record(FILE *fp, const char record_delimiter, SafeString *record)
{
    /* int, not char: getc() must be able to return every byte value
     * *and* EOF, and a plain char cannot represent all of them. */
    int c;
    /* getc() yields bytes as unsigned char values; compare like with like. */
    int delimiter = (unsigned char)record_delimiter;
    unsigned int length = 0;

    set_string_capacity(record, GET_REC_BLOCKSIZE);
    record[0].buffer[0] = '\0';

    while ((c = getc(fp)) != EOF)
    {
        if (c == delimiter)
        {
            if (length > 0)
            {
                break;      /* end of a non-empty record */
            }
            continue;       /* empty record: keep reading */
        }

        if (c == '\0')
        {
            continue;
        }

        /* Track the length locally instead of re-scanning the buffer
         * with strlen() for every character read. */
        set_string_capacity(record, length + 1);
        record[0].buffer[length++] = (char)c;
        record[0].buffer[length] = '\0';
    }

    return (length > 0) ? (char *)record[0].buffer : NULL;
}

/*****************************************************************
 * import_file() opens the flat file at path, reads it one       *
 * record at a time and inserts each record into table, then     *
 * prints how many rows were imported and how many were          *
 * rejected.                                                     *
 *                                                               *
 * Ownership: table, path, fieldDel and rowDel must be heap      *
 * strings; this function frees them.  fields is only read.      *
 * Only the first character of rowDel is used as the record      *
 * delimiter.                                                    *
 *****************************************************************/

void
import_file(char *table,
            char *path,
            char *fieldDel,
            char *rowDel,
            const char *fields)
{
    FILE *fp;
    int recordCount = 0; /* Number of records read from the flat file */
    int recs = 0;        /* Number of records successfully inserted */
    int rejected;
    int *data_types = datatypes(table, fields);
    SafeString *record;

    if ((fp = fopen(path, "r")) == NULL)
    {
        perror(path);
        abort_import(-1);
    }

    record = create_string();
    query = create_string();
    set_string_capacity(query, GET_REC_BLOCKSIZE);

    while (get_record(fp, rowDel[0], record) != NULL)
    {
        recs += insert_record(table,
                              record[0].buffer,
                              fields,
                              data_types,
                              recordCount++,
                              fieldDel);
    }

    /* get_record() stops on EOF and on read errors alike; report the latter. */
    if (ferror(fp))
    {
        fprintf(stderr, "msql-import: error while reading %s\n", path);
    }
    fclose(fp);

    /* Release the SafeStrings' buffers as well as the structs themselves. */
    free(data_types);
    free(record[0].buffer);
    free(record);
    free(query[0].buffer);
    free(query);
    query = NULL;   /* global: do not leave it pointing at freed memory */

    free(fieldDel);
    free(rowDel);
    free(table);
    free(path);

    if (recs > 1)
    {
        fprintf(stdout, "%i rows successfully imported.  ", recs);
    }
    else if (recs == 1)
    {
        fprintf(stdout, "1 row was successfully imported.  ");
    }
    else
    {
        fprintf(stdout, "No row was successfully imported.  ");
    }

    rejected = recordCount - recs;

    if (rejected > 1)
    {
        fprintf(stdout, "%i rows were rejected.\n\n", rejected);
    }
    else if (rejected == 1)
    {
        fprintf(stdout, "1 row was rejected.\n\n");
    }
    else
    {
        fprintf(stdout, "No row was rejected.\n\n");
    }
}

/*****************************************************************
 * insert_record() builds an INSERT statement for one record in  *
 * the global query string and sends it to the msql server.      *
 *                                                               *
 *   table       - target table name                             *
 *   record      - the record text; it is split in place         *
 *   fields      - optional column list from -f (may be NULL)    *
 *   types       - types[0] = number of fields N, types[1..N] =  *
 *                 their type codes, as built by datatypes()     *
 *   recordCount - record number shown in the error message      *
 *   fieldDel    - field delimiter; only its first character is  *
 *                 used                                          *
 *                                                               *
 * CHAR_TYPE values are quoted and escaped.  NULL_TYPE columns   *
 * always receive NULL.  Other values are inserted verbatim.     *
 * Missing trailing values are filled with NULL.                 *
 *                                                               *
 * Returns 1 if the record was inserted successfully, 0          *
 * otherwise (including for an empty record).                    *
 *****************************************************************/

/* Append token to the global query, formatted for its column type,
 * followed by a comma. */
static void
append_field_value(const char *token, int type)
{
    char *value;

    switch (type)
    {
      case CHAR_TYPE:
        value = backslashify_special_chars(token);
        append_character(query, '\'');
        append_string_buffer(query, value);
        append_string_buffer(query, "',");
        free(value);
        break;

      case NULL_TYPE:
        append_string_buffer(query, "NULL,");
        break;

      default:
        append_string_buffer(query, token);
        append_character(query, ',');
        break;
    }
}

int
insert_record(const char *table,
              char *record,
              const char *fields,
              int *types,
              int recordCount,
              char *fieldDel)
{
    char delimiter[2];
    char *token;
    int fieldsCount = types[0]; /* The number of values to be inserted */
    int valueCount = 0;         /* The number of values read so far */

    /* snprintf keeps a multi-character -c argument from overflowing
     * the two-byte buffer; only its first character is used. */
    snprintf(delimiter, sizeof(delimiter), "%s", fieldDel);

    copy_string_buffer(query, "INSERT INTO ");
    append_string_buffer(query, table);

    /* Any non-empty column list is used, including a single
     * one-character column name. */
    if ((fields == NULL) || (fields[0] == '\0'))
    {
        append_string_buffer(query, " VALUES(");
    }
    else
    {
        append_string_buffer(query, " (");
        append_string_buffer(query, fields);
        append_string_buffer(query, ") VALUES(");
    }

    if ((record == NULL) || ((token = str_parse(&record, delimiter)) == NULL))
    {
        return 0;
    }

    do
    {
        /* Values beyond the known fields have no type information.
         * Quote them as strings instead of reading past types[]; the
         * server will presumably reject the row for having too many
         * values. */
        int type = (valueCount < fieldsCount) ? types[valueCount + 1]
                                              : CHAR_TYPE;

        append_field_value(token, type);
        valueCount++;
    }
    while ((token = str_parse(&record, delimiter)) != NULL);

    /*
     * Add NULL values if the record has fewer values than the fields
     * given on the command line or, when no field names were given,
     * fewer values than the table has fields.
     */
    while (valueCount < fieldsCount)
    {
        append_string_buffer(query, "NULL,");
        valueCount++;
    }

    /* Remove the comma that follows the last value. */
    query[0].buffer[strlen(query[0].buffer) - 1] = '\0';

    append_character(query, ')');

    if (msqlQuery(sock, query[0].buffer) == -1)
    {
        fprintf(stderr, "msql-import: could not import record %i: %s\n",
                recordCount, msqlErrMsg);
        fprintf(stdout, "Query : %s\n\n", query[0].buffer);
        return 0;
    }

    return 1;
}

/*****************************************************************
 * row_length() checks that table_name's field list can be       *
 * fetched from the server and returns 1 plus, for each field,   *
 * its length plus one.  NOTE(review): the extra bytes are       *
 * presumably per-field control bytes; confirm against mSQL.     *
 *                                                               *
 * Exits through abort_import() if table_name is NULL, its       *
 * fields cannot be listed, or the result is implausibly small.  *
 *****************************************************************/

int
row_length(const char *table_name)
{
    m_result *res;
    m_field *curField;
    int len = 1;
    char *tableCopy;

    /* Verify that the table_name argument is valid before using it */
    if (table_name == NULL)
    {
        fprintf(stderr, "msql-import in row_length(), invalid table_name "
                "(NULL).  Exiting.\n");
        abort_import(1);
    }

    tableCopy = safe_c_string_copy(table_name);
    res = msqlListFields(sock, tableCopy);
    free(tableCopy);

    if (!res)
    {
        fprintf(stderr, "msql-import error : Unable to get the fields in "
                "table %s, exiting.\n", table_name);
        /* abort_import() also closes the server connection. */
        abort_import(1);
    }

    while ((curField = msqlFetchField(res)))
    {
        len += curField->length + 1;
    }

    msqlFreeResult(res);

    if (len < 3)
    {
        fprintf(stderr, "msql-import: error getting table definition. "
                "Exiting.\n");
        abort_import(1);
    }

    return len;
}

/*****************************************************************
 * format_delimiter() turns a delimiter given on the command     *
 * line into the string used for parsing.  The two-character     *
 * sequences "\n" and "\t" (backslash + letter, as typed in a    *
 * shell) become a real newline or tab; anything else is copied  *
 * unchanged.                                                    *
 *                                                               *
 * Returns a newly allocated string, or NULL if str is NULL or   *
 * memory cannot be obtained.  The caller owns the result.       *
 *****************************************************************/

char *
format_delimiter(const char *str)
{
    const char *source;
    char *result;
    size_t size;

    if (str == NULL)
    {
        return (char *)NULL;
    }

    if (strcmp(str, "\\n") == 0)
    {
        source = "\n";
    }
    else if (strcmp(str, "\\t") == 0)
    {
        source = "\t";
    }
    else
    {
        source = str;
    }

    size = strlen(source) + 1;
    result = (char *)malloc(size);
    if (result != NULL)
    {
        memcpy(result, source, size);
    }
    return result;
}

/*
 * printHelp() writes the long help text, preceded by a version
 * banner, to stdout.
 */
void
printHelp(void)
{
    static const char text[] =
        "  msql-import loads the contents of an ASCII delimited flat file "
        "into an\n  existing MiniSQL table.  It automatically performs the "
        "type conversions, and\n  validates the data.\n\n  msql-import is "
        "invoked as follows:\n\n    msql-import -h host -d database -t table "
        "-c column_delimiter \\\n"
        "    -r record_delimiter -i input_datafile "
        "-f \"[field [,field...]]\"\n\n"
        "               host: hostname of the msql server\n"
        "           database: the Mini SQL database name\n"
        "              table: the table in which to load the data\n"
        "   column_delimiter: the character used to delimit fields within a"
        " record\n"
        "   record_delimiter: the character used to delimit records\n"
        "     input_datafile: contains the data to be imported\n"
        "             fields: import the data in those only (optional)\n\n"
        "  Example:\n\n    msql-import -h zeus -d db -t table -c \\t "
        "-r \\n -i /tmp/file \\\n "
        "    -f \"client_id,name,address\"\n\n"
        "  If the fields are not specified, then all fields in the table "
        "will be filled\n  with the data contained in the flat file, in "
        "order of appearance.\n\n"
        "  msql-import was written by Pascal Forget <pascal@wsc.com>.\n\n";

    fprintf(stdout, "\n  msql-import %s help:\n\n", MSQL_IMPORT_VERSION);
    /* Write the text verbatim: never pass non-literal data as a format string. */
    fputs(text, stdout);
}

/*
 * printUsage() prints a short usage summary to stderr and exits with
 * a failure status.  It is only reached on invalid command lines
 * (unknown option or missing mandatory argument), so exiting with 0
 * would tell calling scripts that the import succeeded.
 */
void
printUsage(void)
{
    fprintf (stderr,
             "\nUsage: msql-import [-h host] [-c col_delimiter] "
             "[-r rec_delimiter]\n"
             "       -d database -t table -i input_file -f [field [,"
             " field...]] \n\n"
             "       where host is host address or the hostname,"
             " default = localhost\n"
             "       col_delimiter is column delimiter, default = tab \n"
             "       rec_delimiter is record delimiter, default = newline\n"
             "       database is the mSQL database (mandatory)\n"
             "       table is the mSQL table (mandatory)\n"
             "       input_file is the input ascii text file (mandatory)\n"
             "       fields are optional input fields, must be in double"
             " quotes \n\n"
             "       Type 'msql-import --help' for more information\n\n");
    exit (1);
}

/*
 * printVersion() writes the program version, followed by a newline,
 * to stdout.
 */
void
printVersion(void)
{
    puts(MSQL_IMPORT_VERSION);
}

/*****************************************************************
 * main() parses the command line, connects to the msql server,  *
 * then imports the data.                                        *
 *                                                               *
 * Options are "-flag value" pairs scanned from the end of argv  *
 * towards the front, so if an option is given twice, its first  *
 * occurrence is the one that takes effect.  Unrecognised flags  *
 * are ignored.  Returns 0 after a completed import; errors exit *
 * through printUsage() or abort_import().                       *
 *****************************************************************/

int
main(int argc, char **argv)
{
    const char *buf[] ={"--help","--info","help","info","-help","-info",0};
    const char **ptr;
    char *hostname = "localhost";
    char *rec_delimiter = "\n";
    char *col_delimiter = "\t";
    char *database = NULL;
    char *table = NULL;
    char *input_file = NULL;
    char *fields = NULL;
    int i = argc - 1;

    if (argc == 2)
    {
        if ((strcmp(argv[1], "--version") == 0) ||
            (strcmp(argv[1], "--ver") == 0))
        {
            printVersion();
            exit(0);
        }

        for (ptr = buf; *ptr != NULL; ptr++)
        {
            if (strcmp(*ptr, argv[1]) == 0)
            {
                printHelp();
                exit(0);
            }
        }

        fprintf(stderr, "msql-import error: unknown option %s\n", argv[1]);
        printUsage();
    }

    /* Parse calling arguments */

    while (i > 1)
    {
        if (strcmp(argv[i-1], "-h") == 0)
        {
            hostname = argv[i];
        }
        else if (strcmp(argv[i-1], "-c") == 0)
        {
            col_delimiter = safe_c_string_copy(argv[i]);
        }
        else if (strcmp(argv[i-1], "-r") == 0)
        {
            rec_delimiter = safe_c_string_copy(argv[i]);
        }
        else if (strcmp(argv[i-1], "-d") == 0)
        {
            database = argv[i];
        }
        else if (strcmp(argv[i-1], "-t") == 0)
        {
            table = safe_c_string_copy(argv[i]);
        }
        else if (strcmp(argv[i-1], "-i") == 0)
        {
            input_file = safe_c_string_copy(argv[i]);
        }
        else if (strcmp(argv[i-1], "-f") == 0)
        {
            fields = safe_c_string_copy(argv[i]);
        }
        i -= 2;
    }

    /* check mandatory arguments. If one is missing, print usage then exit. */

    if ((input_file == NULL) || (table == NULL) || (database == NULL))
    {
        printUsage();
    }

    /* Print the arguments that were interpreted */

    printf ("\n hostname         =  %s\n", hostname);

    if (strcmp(col_delimiter, "\t") != 0)
    {
        printf (" column delimiter =  %s\n", col_delimiter);
    }
    else
    {
        printf (" column delimiter =  TAB\n");
    }

    if (strcmp(rec_delimiter, "\n") != 0)
    {
        printf (" record delimiter =  %s\n", rec_delimiter);
    }
    else
    {
        printf(" record delimiter =  NEWLINE\n");
    }

    printf (" database         =  %s\n", database);
    printf (" table            =  %s\n", table);
    printf (" input_file       =  %s\n", input_file);

    if (fields != NULL)
    {
        printf (" fields     =  %s\n", fields);
    }

    printf("\n");

    /* Connect to the msql server ("localhost" means the default, NULL host) */

    if (strcmp(hostname, "localhost") == 0)
    {
        sock = msqlConnect(NULL);
    }
    else
    {
        sock = msqlConnect(hostname);
    }

    /* Verify that the connection has been successfully established */

    if (sock == -1)
    {
        fprintf(stderr, "msql-import: error connecting to host. "
                "Exiting.\n");
        abort_import(1);
    }

    /* Set the current database */

    if (msqlSelectDB(sock, database) == -1)
    {
        fprintf(stderr, "msql-import: error opening database. "
                "Exiting.\n");
        abort_import(1);
    }

    /* row_length() exits if the table's field list cannot be fetched;
     * it is called only for that check, so its result is discarded. */
    (void)row_length(table);

    /* import_file() takes ownership of table, input_file and the two
     * formatted delimiters and frees them. */
    import_file(table,
                input_file,
                format_delimiter(col_delimiter),
                format_delimiter(rec_delimiter),
                fields);

    msqlClose(sock);
    return 0;
}
