From eaea7bbfc2ae6a31ab9444a8be293deb912a9c1d Mon Sep 17 00:00:00 2001
From: Pompolic <pompolic@special-circumstanc.es>
Date: Thu, 7 Oct 2021 20:16:29 +0200
Subject: [PATCH] (WIP) Connect the LZW module to pdf parser, fix up some bugs

---
 lzw.c | 65 ++++++++++++++++++++++++++++++++++++++++-------------------
 lzw.h |  6 +++---
 pdf.c | 54 ++++++++++++++++++++++++++++++++++++-------------
 3 files changed, 87 insertions(+), 38 deletions(-)

diff --git a/lzw.c b/lzw.c
index 60de6f1..7d15c93 100644
--- a/lzw.c
+++ b/lzw.c
@@ -33,7 +33,7 @@ typedef struct LZW_context_S
 void LZW_clear_table(LZW_context_T *ctx)
 {
 	/*
-	 *  Optimizations: since we leave the entries 0-257 empty, we don't need to free() them explicitly.
+	 *  Optimizations: since we leave the entries 0-257 fixed or empty, we don't need to free() them explicitly.
 	 *  And since codes are added to the table sequentially, we don't need to look past ctx->next;
 	 */
 	for(int i = 258; i < ctx->next; ++i)
@@ -110,7 +110,7 @@ validate_LZW_9bitlitspec(HParseResult *p, void *u)
 {
 	LZW_context_T * ctx = (LZW_context_T *) u;
 	uint64_t code = H_CAST_UINT(p->ast);
-	return (ctx->next < 512 && code > 258);
+	return (ctx->next < 512 && code < 258);
 }
 
 bool
@@ -162,11 +162,13 @@ HParsedToken*
 act_LZW_literal(const HParseResult *p, void *u)
 {
 	uint64_t code = H_CAST_UINT(p->ast);
+	LZW_context_T * ctx = (LZW_context_T *) u;
 	/*
 	 * Literals go from 0-255, so they are guaranteed to fit into 1 byte. See also: validate_LZW_literal
 	 */
 	uint8_t *output = malloc(sizeof(uint8_t));
 	*output = (uint8_t) code;
+	ctx->old = code;
 
 	return H_MAKE_BYTES(output, 1);
 }
@@ -198,7 +200,7 @@ act_LZW_codeword(const HParseResult *p, void *u)
 		output_token = malloc(sizeof(uint8_t) * code_token_length); //XXX: maybe encapsulate working with the code table/HBytes
 		memcpy(output_token, code_str->token, code_token_length);
 
-		prev_string = ctx->lzw_code_table[ctx->old];
+		prev_string = ctx->lzw_code_table[ctx->old]; //XXX: segfault when reading first codeword, set old if literal, init table with literals, special-case first code
 		prev_string_length = prev_string->len;
 
 		/*
@@ -286,7 +288,7 @@ act_LZW_data(const HParseResult *p, void *u)
 		total_buffer_size += H_FIELD_BYTES(i).len;
 	}
 
-	buffer = malloc(sizeof(uint8_t) * total_buffer_size); // XXX arena alloc
+	buffer = malloc(sizeof(uint8_t) * total_buffer_size); // XXX arena alloc, calloc
 
 	/* go through parse result, merge bytes */
 	for(int i = 0; i < num_fragments; i++)
@@ -300,42 +302,63 @@ act_LZW_data(const HParseResult *p, void *u)
 }
 
 
-void init_lzw_parser()
+void init_LZW_parser()
 {
-	H_VRULE(LZW_9bitcodeword, h_bits(9, false));
-	H_VRULE(LZW_10bitcodeword, h_bits(10, false));
-	H_VRULE(LZW_11bitcodeword, h_bits(11, false));
-	H_VRULE(LZW_12bitcodeword, h_bits(12, false));
+	context = calloc(1, sizeof(LZW_context_T)); /* zeroed: unused table slots must be NULL because clear_LZW_context() free()s whatever pointers it finds in 258..4095 */
+	context->next = 258;
+	/* pre-populate codes 0-255 with their one-byte literal strings */
+	for(int i = 0; i < 256; ++i)
+	{
+		uint8_t *token = malloc(sizeof(uint8_t));
+		*token = i;
+		HBytes *lit = malloc(sizeof(HBytes)); // XXX: instead of HBytes*, use HBytes
+		lit->token = token;
+		lit->len = 1;
+		context->lzw_code_table[i] = lit;
+	}
 
-	H_VRULE(LZW_9bitlitspec, h_bits(9, false));
-	H_VRULE(LZW_10bitlitspec, h_bits(10, false));
-	H_VRULE(LZW_11bitlitspec, h_bits(11, false));
-	H_VRULE(LZW_12bitlitspec, h_bits(12, false));
+	H_VDRULE(LZW_9bitcodeword, h_bits(9, false), context);
+	H_VDRULE(LZW_10bitcodeword, h_bits(10, false), context);
+	H_VDRULE(LZW_11bitcodeword, h_bits(11, false), context);
+	H_VDRULE(LZW_12bitcodeword, h_bits(12, false), context);
+
+	H_VDRULE(LZW_9bitlitspec, h_bits(9, false), context);
+	H_VDRULE(LZW_10bitlitspec, h_bits(10, false), context);
+	H_VDRULE(LZW_11bitlitspec, h_bits(11, false), context);
+	H_VDRULE(LZW_12bitlitspec, h_bits(12, false), context);
 
 	H_RULE(LZW_remainingbits, h_many(h_bits(1, false))); //XXX: could validate that these bits are 0?
 	// XXX: p__take_n function to dynamically generate the rule needed to consume remaining bits?
 	// XXX: user data pointers (VDRULE, VADRULE, etc.)
 
-	H_VARULE(LZW_clear, h_choice(LZW_9bitlitspec, LZW_10bitlitspec, LZW_11bitlitspec, LZW_12bitlitspec, NULL)); //XXX: VARULE or AVRULE?
-	H_VRULE(LZW_eod, h_choice(LZW_9bitlitspec, LZW_10bitlitspec, LZW_11bitlitspec, LZW_12bitlitspec, NULL));
-	H_VRULE(LZW_literal, h_choice(LZW_9bitlitspec, LZW_10bitlitspec, LZW_11bitlitspec, LZW_12bitlitspec, NULL));
-	H_ARULE(LZW_codeword, h_choice(LZW_9bitcodeword, LZW_10bitcodeword, LZW_11bitcodeword, LZW_12bitcodeword, NULL));
+	H_AVDRULE(LZW_clear, h_choice(LZW_9bitlitspec, LZW_10bitlitspec, LZW_11bitlitspec, LZW_12bitlitspec, NULL), context);
+	H_VDRULE(LZW_eod, h_choice(LZW_9bitlitspec, LZW_10bitlitspec, LZW_11bitlitspec, LZW_12bitlitspec, NULL), context);
+	H_VDRULE(LZW_literal, h_choice(LZW_9bitlitspec, LZW_10bitlitspec, LZW_11bitlitspec, LZW_12bitlitspec, NULL), context);
+	H_ADRULE(LZW_codeword, h_choice(LZW_9bitcodeword, LZW_10bitcodeword, LZW_11bitcodeword, LZW_12bitcodeword, NULL), context);
 
-	H_RULE(LZW_data, h_sequence(LZW_clear, h_many1(h_butnot(h_choice(LZW_literal, LZW_clear, LZW_codeword, NULL), LZW_eod)), LZW_eod, LZW_remainingbits, NULL));
+	H_ADRULE(LZW_data, h_sequence(LZW_clear, h_many1(h_butnot(h_choice(LZW_literal, LZW_clear, LZW_codeword, NULL), LZW_eod)), LZW_eod, LZW_remainingbits, NULL), context);
 	p_lzwdata = LZW_data;
 }
 
-HParseResult* parse_lzw_data(const uint8_t* input, size_t length)
+HParseResult* parse_LZW_data(const uint8_t* input, size_t length) /* runs the p_lzwdata parser over `length` bytes of `input`; returns h_parse()'s result unchanged */
 {
 	return h_parse(p_lzwdata, input, length);
 }
 
 void set_LZW_context(LZW_context_T *ctx)
 {
-	context = ctx;
+	*context = *ctx; // XXX unnecessary, just clear context before each parse. NOTE(review): struct copy is shallow -- code table entry pointers end up shared with *ctx, so only one side may free them
 }
 
 void clear_LZW_context()
 {
-	context = NULL;
+	for(int i = 258; i < 4096; ++i)
+	{
+		/* free(NULL) is a no-op, so no guard is needed; the slot is reset so
+		 * that a later clear_LZW_context() call cannot free it a second time */
+		free(context->lzw_code_table[i]);
+		context->lzw_code_table[i] = NULL;
+	}
+	context->next = 258;
+	context->old = 257; //XXX: guaranteed to segfault if old isn't set before
 }
diff --git a/lzw.h b/lzw.h
index 1a3b46b..192bc7a 100644
--- a/lzw.h
+++ b/lzw.h
@@ -23,10 +23,10 @@ typedef struct LZW_context_S
 	uint64_t old;
 } LZW_context_T;
 
-HParser * p_lzwdata;
+HParser * p_lzwdata; // XXX can be internal. NOTE(review): defined (not extern) in a header that pdf.c now includes -- duplicate-symbol risk when built with -fno-common; confirm
 
-void init_lzw_parser();
-HParseResult * parse_lzw_data(const uint8_t* input, size_t length);
+void init_LZW_parser(); /* allocates the shared LZW context and builds p_lzwdata */
+HParseResult * parse_LZW_data(const uint8_t* input, size_t length); /* h_parse() of p_lzwdata over input */
 void set_LZW_context(LZW_context_T *ctx);
 void clear_LZW_context();
 
diff --git a/pdf.c b/pdf.c
index 26a4aac..8e38bbb 100644
--- a/pdf.c
+++ b/pdf.c
@@ -3325,6 +3325,7 @@ int read_lzw_buffer(void)
     return ret_value;
 }
 
+#include "lzw.h"
 
 HParseResult *
 LZWDecode(const Dict *parms, HBytes b, HParser *p)
@@ -3332,8 +3333,9 @@ LZWDecode(const Dict *parms, HBytes b, HParser *p)
 	struct predictor pred = {1, 1, 8, 1};
 	int (*depredict)(struct predictor *, uint8_t *, size_t);
 	HParseResult *res;
+	HParseResult *tmp_res;
 	int done;
-	int ret;
+	//int ret;
 	const HParsedToken *v;
 
 	/* set up the predictor (if any) */
@@ -3383,23 +3385,46 @@ LZWDecode(const Dict *parms, HBytes b, HParser *p)
 			err(1, "LZWDecode");
 	}
 
-	lzwspec *lzw_spec = new_lzw_spec(&b);
-	bind_lzw_spec(lzw_spec);
-
-	ret = lzw_decompress(write_lzw_buffer, read_lzw_buffer);
-	if (ret) {
-		fprintf(stderr, "lzw_decompress: error (%d)\n", ret);
-		assert(!"LZWDecode: failed to decompress\n");
+	//lzwspec *lzw_spec = new_lzw_spec(&b);
+	//bind_lzw_spec(lzw_spec);
+
+	//ret = lzw_decompress(write_lzw_buffer, read_lzw_buffer);
+	//if (ret) {
+	//	fprintf(stderr, "lzw_decompress: error (%d)\n", ret);
+	//	assert(!"LZWDecode: failed to decompress\n");
+	//}
+	//done = depredict(&pred, cur_lzw_spec->lzw_buf, cur_lzw_spec->write_head-1);
+	//assert(!done);	// XXX ITERATIVE
+	LZW_context_T * ctx = malloc(sizeof(LZW_context_T));
+	ctx->next = 258;
+	clear_LZW_context();
+	tmp_res = parse_LZW_data(b.token, b.len);
+	//clear_LZW_context();
+	free(ctx);
+
+	if(!tmp_res)
+	{
+		fprintf(stderr, "parse error in LZWDecode filter");
+		return NULL;
 	}
-	done = depredict(&pred, cur_lzw_spec->lzw_buf, cur_lzw_spec->write_head-1);
-	assert(!done);	// XXX ITERATIVE
+
+	assert(tmp_res->ast->token_type == TT_BYTES);
+
+	uint8_t * tmp_buf = malloc(sizeof(uint8_t) * tmp_res->ast->bytes.len);
+	memcpy(tmp_buf, tmp_res->ast->bytes.token, tmp_res->ast->bytes.len);
+	done = depredict(&pred, tmp_buf, tmp_res->ast->bytes.len);
+	assert(!done);
+
+	//done = depredict(&pred, res->ast->bytes.token, res->ast->bytes.len);
+	//assert(!done);
 
 	// SR::TODO:: Do a H_MAKE rather than a parse and let the caller do the parse
-	res = h_parse(p, pred.out, pred.nout);
-	free(pred.out);
+	//res = h_parse(p, pred.out, pred.nout);
+	res = h_parse(p, tmp_res->ast->bytes.token, tmp_res->ast->bytes.len); // XXX depred buffer
+	//free(pred.out);
 
-	bind_lzw_spec(NULL);
-	delete_lzw_spec(lzw_spec);
+	//bind_lzw_spec(NULL);
+	//delete_lzw_spec(lzw_spec);
 
 	return res;
 }
@@ -5594,6 +5619,7 @@ main(int argc, char *argv[])
 	/* build parsers */
 	aux = (struct Env){infile, input, sz};
 	init_parser(&aux);
+	init_LZW_parser();
 
 
 	/* parse all cross-reference sections and trailer dictionaries */
-- 
GitLab