#!/usr/bin/env lua

local _G             = _G
local _VERSION       = _VERSION
local assert         = assert
local error          = error
local getmetatable   = getmetatable
local ipairs         = ipairs
local next           = next
local pairs          = pairs
local print          = print
local rawequal       = rawequal
local rawget         = rawget
local rawlen         = rawlen
local rawset         = rawset
local select         = select
local setmetatable   = setmetatable
local tonumber       = tonumber
local tostring       = tostring
local type           = type

local math      = math
local string    = string
local table     = table

local mondelefant = require("mondelefant")
local atom        = require("atom")
local json        = require("json")

-- Module table.
-- On Lua 5.2+ (where _ENV exists) it becomes the environment of the rest of
-- this chunk. Every bare global assignment below (e.g. input_converters,
-- output_converters) therefore becomes a field of _M, and standard globals
-- are no longer reachable by name. That is why every standard function used
-- here is cached in a local above.
-- On Lua 5.1 (_ENV is just an undefined global there) the module is
-- registered as a global under the name passed by require (...). It is then
-- installed as the chunk's environment with setfenv.
local _M = {}
if _ENV then
  _ENV = _M
else
  _G[...] = _M
  setfenv(1, _M)
end


-- Registry of input converters (Lua value -> SQL text). Entries are keyed by
-- a Lua type name (as returned by type()), by the metatable of an atom type,
-- or by the special key "rawtable". The key mode is weak, so entries keyed by
-- collectable objects do not keep those objects alive.
input_converters = setmetatable({}, {
  __mode = "k",
})

-- Booleans: bare t/f in raw text mode (array element text), TRUE/FALSE as
-- SQL literals otherwise. Only the truthiness of value is examined.
input_converters["boolean"] = function(conn, value, rawtext_mode)
  local when_true, when_false
  if rawtext_mode then
    when_true, when_false = "t", "f"
  else
    when_true, when_false = "TRUE", "FALSE"
  end
  if value then
    return when_true
  end
  return when_false
end

--- Input converter for Lua numbers.
-- Integral values are rendered as plain integer digits, so that e.g. 1e14 is
-- not sent as "1e+14". Other finite values use tostring(). Values without a
-- numeric text form (NaN, infinities) are sent as NaN: quoted, unless
-- rawtext_mode is set (array element text).
-- @param conn          connection (unused)
-- @param value         Lua number
-- @param rawtext_mode  true when producing array element text
-- @return SQL text
input_converters["number"] = function(conn, value, rawtext_mode)
  local str
  if math.tointeger then
    -- Lua 5.3+: floats with an exact integer value become integers
    local integer = math.tointeger(value)
    if integer then
      str = tostring(integer)
    end
  elseif value == math.floor(value) and value >= -2^63 and value < 2^63 then
    -- Lua 5.1/5.2 have no integer subtype. string.format("%i") is avoided:
    -- in Lua 5.2 it raises an error for NaN, infinities and values outside
    -- the C integer range, and depending on the build it may be limited to
    -- a 32-bit long. NaN fails the equality test and infinities fail the
    -- range test above, so only integral in-range values get here.
    -- "%.0f" prints them without an exponent.
    str = string.format("%.0f", value)
  end
  if str == nil then
    str = tostring(value)
    -- rejects "inf", "-inf", "nan", "-nan" and similar spellings
    if not string.find(str, "^[0-9.e+-]+$") then
      if rawtext_mode then return "NaN" else return "'NaN'" end
    end
  end
  return str
end

-- atom.fraction values are sent as a numeric division.
-- In raw text mode they become "n/d". Invalid fractions, and fractions whose
-- numerator or denominator does not print as a plain integer, become NaN.
input_converters[atom.fraction] = function(conn, value, rawtext_mode)
  local nan = rawtext_mode and "NaN" or "'NaN'"
  if value.invalid then
    return nan
  end
  local num = tostring(value.numerator)
  local den = tostring(value.denominator)
  local integer_pattern = "^%-?[0-9]+$"
  if not (string.find(num, integer_pattern) and string.find(den, integer_pattern)) then
    return nan
  end
  if rawtext_mode then
    return num .. "/" .. den
  end
  return "(" .. num .. "::numeric / " .. den .. "::numeric)"
end

-- atom.date values: their tostring() text in raw text (array element) mode,
-- otherwise that text as a quoted string literal cast to date.
input_converters[atom.date] = function(conn, value, rawtext_mode)
  local text = tostring(value)
  if not rawtext_mode then
    text = conn:quote_string(text) .. "::date"
  end
  return text
end

-- atom.timestamp values: their tostring() text in raw text (array element)
-- mode, otherwise that text as a quoted string literal without any cast.
input_converters[atom.timestamp] = function(conn, value, rawtext_mode)
  local text = tostring(value)
  if rawtext_mode then
    return text
  end
  -- No "::timestamp" cast is appended, presumably so that PostgreSQL can
  -- infer timestamp vs. timestamptz from context.
  return conn:quote_string(text)
end

-- atom.time values: their tostring() text in raw text (array element) mode,
-- otherwise that text as a quoted string literal cast to time.
input_converters[atom.time] = function(conn, value, rawtext_mode)
  local text = tostring(value)
  if rawtext_mode then
    return text
  end
  return conn:quote_string(text) .. "::time"
end

-- Plain Lua tables are sent as a one-dimensional array literal.
-- Elements 1..n (up to the first nil) are each converted in raw text mode,
-- falling back to tostring(). Each element is wrapped in double quotes with
-- '"' and '\' backslash-escaped. The whole literal is then passed through
-- conn:quote_string. rawtext_mode is not used.
input_converters["rawtable"] = function(conn, value, rawtext_mode)
  local elements = {}
  for index, element in ipairs(value) do
    local convert = input_converters[getmetatable(element)]
      or input_converters[type(element)]
    local text
    if convert then
      text = convert(conn, element, true)
    else
      text = tostring(element)
    end
    local escaped = string.gsub(text, '[\\"]', '\\%0')
    elements[index] = '"' .. escaped .. '"'
  end
  return conn:quote_string("{" .. table.concat(elements, ",") .. "}")
end


-- Registry of output converters (column text -> Lua value), keyed by
-- PostgreSQL type name.
output_converters = setmetatable({}, { __mode = "k" })

do
  -- Column types whose text is parsed by the :load method of an atom type.
  -- The atom type is looked up on every call, not once here.
  local atom_type_for_column = {
    int8 = "integer", int4 = "integer", int2 = "integer",
    numeric = "number", float4 = "number", float8 = "number",
    bool = "boolean",
    date = "date",
  }
  for column_type, atom_type in pairs(atom_type_for_column) do
    output_converters[column_type] = function(str)
      return atom[atom_type]:load(str)
    end
  end
end

do
  -- Leading "YYYY-MM-DD H:MM:SS" / "YYYY-MM-DD HH:MM:SS" of the column text.
  -- The pattern is not anchored at the end, so anything after the seconds
  -- (e.g. fractional seconds or a zone offset) is ignored.
  local timestamp_pattern =
    "^([0-9][0-9][0-9][0-9])%-([0-9][0-9])%-([0-9][0-9]) ([0-9]?[0-9]):([0-9][0-9]):([0-9][0-9])"

  -- Builds an atom.timestamp from the text, or returns atom.timestamp.invalid
  -- when it does not start with the expected layout.
  local function load_timestamp(str)
    local y, mo, d, h, mi, s = string.match(str, timestamp_pattern)
    if not y then
      return atom.timestamp.invalid
    end
    return atom.timestamp{
      year = tonumber(y), month = tonumber(mo), day = tonumber(d),
      hour = tonumber(h), minute = tonumber(mi), second = tonumber(s),
    }
  end

  output_converters.timestamp = load_timestamp
  output_converters.timestamptz = load_timestamp
end

do
  -- Builds an atom.time from the leading "H:MM:SS" / "HH:MM:SS" of the column
  -- text (anything after the seconds is ignored). Returns atom.time.invalid
  -- when the text does not start that way.
  local function load_time(str)
    local h, m, s = string.match(str, "^([0-9]?[0-9]):([0-9][0-9]):([0-9][0-9])")
    if h == nil then
      return atom.time.invalid
    end
    return atom.time{ hour = tonumber(h), minute = tonumber(m), second = tonumber(s) }
  end

  output_converters.time = load_time
  output_converters.timetz = load_time
end

do
  -- Decodes json/jsonb column text with json.import. All of its results are
  -- passed through assert(), so a nil/false first result raises an error
  -- (with the second result as message, if any).
  local function decode_json(text)
    return assert(json.import(text))
  end

  output_converters.json = decode_json
  output_converters.jsonb = decode_json
end

-- Table from PostgreSQL type names to the corresponding atom types (the json
-- library for json/jsonb). How mondelefant consumes it is not visible here.
-- NOTE(review): unlike output_converters this has no entries for numeric,
-- float4, float8, timestamptz or timetz; confirm whether that is intended.
mondelefant.postgresql_connection_prototype.type_mappings = {
  -- integers
  ["int2"] = atom.integer,
  ["int4"] = atom.integer,
  ["int8"] = atom.integer,
  -- booleans
  ["bool"] = atom.boolean,
  -- date and time
  ["date"] = atom.date,
  ["time"] = atom.time,
  ["timestamp"] = atom.timestamp,
  -- text
  ["text"] = atom.string,
  ["varchar"] = atom.string,
  -- JSON
  ["json"] = json,
  ["jsonb"] = json,
}


--- Convert a Lua value into SQL text for use as a query parameter.
-- Converters are looked up in this order:
--   1. the converter registered for the value's metatable;
--   2. "rawtable" for tables without a metatable (sent as array literals);
--   3. the converter registered for the value's Lua type name.
-- Values without any converter are sent as a quoted tostring(value).
-- @param conn   connection (provides quote_string)
-- @param value  Lua value; nil becomes NULL
-- @param info   not used here
-- @return SQL text
function mondelefant.postgresql_connection_prototype.input_converter(conn, value, info)
  if value == nil then
    return "NULL"
  end
  local mt = getmetatable(value)
  -- reading a table with a nil key is valid and simply yields nil
  local converter = input_converters[mt]
  if not converter then
    local t = type(value)
    if t == "table" and mt == nil then
      converter = input_converters.rawtable
    else
      converter = input_converters[t]
    end
  end
  if converter then
    return converter(conn, value)
  end
  return conn:quote_string(tostring(value))
end

--- Split the text form of a one-dimensional PostgreSQL array (e.g.
-- {1,2,"a \"b\""}) into its elements, and convert each element.
-- Inside a quoted element a backslash makes the following character literal.
-- This also handles elements whose content ends in a backslash, which PostgreSQL
-- prints as "...\\".
-- @param value    array text as returned by the database
-- @param convert  function applied to every element string
-- @return Lua sequence of converted elements
-- @raise "Could not parse database array" on malformed input
local function parse_database_array(value, convert)
  local body = string.match(value, "^{(.*)}$")
  if not body then
    error("Could not parse database array")
  end
  local result, count = {}, 0
  local pos, len = 1, #body
  while pos <= len do
    local entry
    if string.sub(body, pos, pos) == '"' then
      local pieces = {}
      local i = pos + 1
      while true do
        local special = string.find(body, '[\\"]', i)
        if not special then
          -- unterminated quoted element
          error("Could not parse database array")
        end
        pieces[#pieces + 1] = string.sub(body, i, special - 1)
        if string.sub(body, special, special) == '"' then
          pos = special + 1
          break
        end
        -- backslash: take the next character literally
        if special == len then
          error("Could not parse database array")
        end
        pieces[#pieces + 1] = string.sub(body, special + 1, special + 1)
        i = special + 2
      end
      entry = table.concat(pieces)
      -- the closing quote must be followed by a delimiter or the end
      if pos <= len then
        if string.sub(body, pos, pos) ~= "," then
          error("Could not parse database array")
        end
        pos = pos + 1
      end
    else
      -- NOTE(review): PostgreSQL prints SQL NULL elements as an unquoted
      -- NULL. They are passed to convert as the string "NULL"; confirm
      -- whether they should become nil instead.
      local comma = string.find(body, ",", pos, true)
      if comma then
        entry = string.sub(body, pos, comma - 1)
        pos = comma + 1
      else
        entry = string.sub(body, pos)
        pos = len + 1
      end
    end
    count = count + 1
    result[count] = convert(entry)
  end
  return result
end

--- Convert a column value received from the database into a Lua value.
-- Types named "<type>_array" are parsed as one-dimensional arrays, and each
-- element goes through the converter for <type> (or is kept as a string).
-- Other types use output_converters[info.type] when present, and are
-- returned unchanged otherwise.
-- @param conn   connection (unused)
-- @param value  column text, or nil for SQL NULL
-- @param info   table whose .type field names the column type
-- @return converted value (nil for NULL)
function mondelefant.postgresql_connection_prototype.output_converter(conn, value, info)
  if value == nil then
    return nil
  end
  local array_type = info.type and string.match(info.type, "^(.*)_array$")
  if array_type then
    local inner_converter = output_converters[array_type]
      or function(x) return x end
    return parse_database_array(value, inner_converter)
  end
  local converter = output_converters[info.type]
  if converter then
    return converter(value)
  end
  return value
end


--- Capture a state for a value so that later modification can be detected.
-- Only values that json.type classifies as "object" or "array" get a state:
-- their tostring() text. Anything else yields no value.
function mondelefant.save_mutability_state(value)
  local kind = json.type(value)
  if kind ~= "object" and kind ~= "array" then
    return
  end
  return tostring(value)
end

--- Compare a value against a state captured by save_mutability_state.
-- Returns true when tostring(value) differs from the saved state (the value
-- appears to have changed), false otherwise.
function mondelefant.verify_mutability_state(value, state)
  local unchanged = (tostring(value) == state)
  return not unchanged
end


-- Export the module table. Because _M was made the chunk's environment, its
-- fields are the globals assigned above (e.g. input_converters and
-- output_converters).
return _M


--[[

Example usage:

db = assert(mondelefant.connect{engine='postgresql', dbname='test'})
result = db:query{'SELECT ? + 1', atom.date{ year=1999, month=12, day=31}}
print(result[1][1].year)  -- expected to print 2000 (1999-12-31 + 1 day)

--]]
