/*
  OpenMW - The completely unofficial reimplementation of Morrowind
  Copyright (C) 2008-2010  Nicolay Korslund
  Email: < korslund@gmail.com >
  WWW: http://openmw.sourceforge.net/

  This file (tes4bsa_file.cpp) is part of the OpenMW package.

  OpenMW is distributed as free software: you can redistribute it
  and/or modify it under the terms of the GNU General Public License
  version 3, as published by the Free Software Foundation.

  This program is distributed in the hope that it will be useful, but
  WITHOUT ANY WARRANTY; without even the implied warranty of
  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
  General Public License for more details.

  You should have received a copy of the GNU General Public License
  version 3 along with this program. If not, see
  http://www.gnu.org/licenses/ .

  TES4 stuff added by cc9cii 2018

 */
#include "tes4bsa_file.hpp"

#include <stdexcept>
#include <boost/scoped_array.hpp>
#include <boost/algorithm/string.hpp>
#include <boost/filesystem/path.hpp>
#include <boost/filesystem/fstream.hpp>
#include <boost/archive/basic_binary_iarchive.hpp>

#include <extern/bsaopthash/hash.hpp> // see: http://en.uesp.net/wiki/Tes4Mod:Hash_Calculation
#include <boost/iostreams/filtering_streambuf.hpp>
#include <boost/iostreams/copy.hpp>
#include <boost/iostreams/filter/zlib.hpp>
#include <boost/iostreams/stream.hpp>
#include <boost/iostreams/device/array.hpp>
#include <components/bsa/memorystream.hpp>
#include <iostream>

namespace Bsa
{

// Sentinel stored in FileRecord::offset for records that were never filled in:
// all 32 bits set, i.e. the largest value a uint32_t can hold.
const uint32_t TES4BSAFile::sInvalidOffset = static_cast<uint32_t>(-1);

// Bit 30 of a file record's size field; FileRecord::isCompressed() tests it
// and FileRecord::getSizeWithoutCompressionFlag() masks it out.
const uint32_t TES4BSAFile::sCompressedFlag = 0x40000000u;


/// Creates an empty record: zero size and the invalid-offset sentinel, so that
/// isValid() reports false until real data is stored in it.
TES4BSAFile::FileRecord::FileRecord()
    : size(0)
    , offset(sInvalidOffset)
{
}

/// @return true when the record's offset is anything other than the
///         sInvalidOffset sentinel set by the default constructor.
bool TES4BSAFile::FileRecord::isValid() const
{
    const bool isSentinel = (offset == sInvalidOffset);
    return !isSentinel;
}

/// Tells whether this file's data is compressed.
///
/// The per-file flag bit does not mean "compressed"; it means "the opposite of
/// the archive-wide default", so the answer is the flag XOR the default.
///
/// @param bsaCompressedByDefault whether the archive compresses files by default
/// @return true when the file data is compressed
bool TES4BSAFile::FileRecord::isCompressed(bool bsaCompressedByDefault) const
{
    const bool flagSet = (size & sCompressedFlag) != 0;
    return flagSet != bsaCompressedByDefault;
}

/// @return the size field with the compression flag bit cleared.
std::uint32_t TES4BSAFile::FileRecord::getSizeWithoutCompressionFlag() const {
    const std::uint32_t sizeBitsMask = ~sCompressedFlag;
    return size & sizeBitsMask;
}

/// Creates an archive object with both header-derived flags cleared; they are
/// assigned their real values by readHeader().
TES4BSAFile::TES4BSAFile()
    : mCompressedByDefault(false)
    , mEmbeddedFileNames(false)
{
}

/// Nothing to release by hand; members clean up after themselves.
TES4BSAFile::~TES4BSAFile() = default;

/// Reads a length-prefixed string from a stream.
///
/// The on-disk layout is one length byte followed by that many characters.
/// When the last of those characters is a null terminator it is dropped from
/// the result; otherwise all characters are kept.
///
/// @param str         receives the string (cleared first; empty for length 0)
/// @param filestream  stream positioned at the length byte
///
/// Calls fail() when the length byte or the announced number of characters
/// cannot be read.
void TES4BSAFile::getBZString(std::string& str, std::istream& filestream)
{
    str.clear();

    // The length prefix is an unsigned byte. Keeping it in a plain char would
    // make lengths >= 128 negative on platforms where char is signed.
    char rawSize = 0;
    filestream.read(&rawSize, 1);
    if (!filestream)
        fail("getBZString failed to read string length");

    const std::size_t size = static_cast<unsigned char>(rawSize);

    // Nothing follows a zero length; there is no last character to inspect.
    if (size == 0)
        return;

    std::string buf(size, '\0');
    filestream.read(&buf[0], static_cast<std::streamsize>(size));
    if (static_cast<std::size_t>(filestream.gcount()) != size)
        fail("getBZString string size mismatch");

    if (buf[size - 1] == '\0')
        buf.resize(size - 1); // don't copy null terminator

    str.swap(buf);
}

/// Read header information from the input source
///
/// Opens mFilename and parses, in this order: the 36-byte header, the folder
/// records, the per-folder file record blocks and the file name block. From
/// these it fills mFolders, mFiles, mStringBuf and mLookup, then converts the
/// sizes of compressed entries in mFiles to their uncompressed values and sets
/// mIsLoaded. Every structural problem is reported through fail().
void TES4BSAFile::readHeader()
{
    assert(!mIsLoaded);

    namespace bfs = boost::filesystem;
    bfs::ifstream input(bfs::path(mFilename), std::ios_base::binary);

    // Total archive size
    std::streamoff fsize = 0;
    if (input.seekg(0, std::ios_base::end))
    {
        fsize = input.tellg();
        input.seekg(0);
    }

    if (fsize < 36) // header is 36 bytes
        fail("File too small to be a valid BSA archive");

    // Get essential header numbers
    std::uint32_t archiveFlags, folderCount, totalFileNameLength;
    {
        // First 36 bytes
        std::uint32_t header[9];

        input.read(reinterpret_cast<char*>(header), 36);

        if (header[0] != 0x00415342 /*"BSA\x00"*/ || (header[1] != 0x67 /*TES4*/ && header[1] != 0x68 /*TES5*/))
            fail("Unrecognized TES4 BSA header");

        // header[2] is offset, should be 36 = 0x24 which is the size of the header

        // Oblivion - Meshes.bsa
        //
        // 0111 1000 0111 = 0x0787
        //  ^^^ ^     ^^^
        //  ||| |     ||+-- has names for dirs  (mandatory?)
        //  ||| |     |+--- has names for files (mandatory?)
        //  ||| |     +---- files are compressed by default
        //  ||| |
        //  ||| +---------- unknown (TES5: retain strings during startup)
        //  ||+------------ unknown (TES5: embedded file names)
        //  |+------------- unknown
        //  +-------------- unknown
        //
        archiveFlags = header[3];
        folderCount = header[4];
        // header[5] is the file count; the number of files is derived from the
        // per-folder counts instead, so it is not kept.
        totalFileNameLength = header[7];

        mCompressedByDefault = (archiveFlags & 0x4) != 0;
        mEmbeddedFileNames = header[1] == 0x68 /*TES5*/ && (archiveFlags & 0x100) != 0;
    }

    // folder records: 8-byte name hash, 4-byte file count, 4-byte offset
    for (std::uint32_t i = 0; i < folderCount; ++i)
    {
        std::uint64_t hash = 0;
        FolderRecord fr;
        input.read(reinterpret_cast<char*>(&hash), 8);
        input.read(reinterpret_cast<char*>(&fr.count), 4);
        input.read(reinterpret_cast<char*>(&fr.offset), 4);
        if (!input)
            fail("Unexpected end of file while reading BSA folder records");

        // lower_bound doubles as the duplicate check and as the insertion hint
        std::map<std::uint64_t, FolderRecord>::const_iterator lb = mFolders.lower_bound(hash);
        if (lb != mFolders.end() && !(mFolders.key_comp()(hash, lb->first)))
            fail("Archive found duplicate folder name hash");
        else
            mFolders.insert(lb, std::pair<std::uint64_t, FolderRecord>(hash, fr));
    }

    // file record blocks
    std::string folder("");
    if ((archiveFlags & 0x1) == 0)
    {
        // without folder names a single block is read, keyed by the hash of ""
        folderCount = 1;
    }

    mFiles.clear();
    // fullPaths[k] belongs to mFiles[k]: first the folder, later folder + "\" + file name
    std::vector<std::string> fullPaths;

    for (std::uint32_t i = 0; i < folderCount; ++i)
    {
        if ((archiveFlags & 0x1) != 0)
            getBZString(folder, input);

        std::string emptyString;
        const std::uint64_t folderHash = GenOBHash(folder, emptyString);

        std::map<std::uint64_t, FolderRecord>::iterator iter = mFolders.find(folderHash);
        if (iter == mFolders.end())
            fail("Archive folder name hash not found");

        // each file record: 8-byte name hash, 4-byte size (incl. flag bit), 4-byte offset
        for (std::uint32_t j = 0; j < iter->second.count; ++j)
        {
            std::uint64_t fileHash = 0;
            FileRecord file;
            input.read(reinterpret_cast<char*>(&fileHash), 8);
            input.read(reinterpret_cast<char*>(&file.size), 4);
            input.read(reinterpret_cast<char*>(&file.offset), 4);
            if (!input)
                fail("Unexpected end of file while reading BSA file records");

            std::map<std::uint64_t, FileRecord>::const_iterator lb = iter->second.files.lower_bound(fileHash);
            if (lb != iter->second.files.end() && !(iter->second.files.key_comp()(fileHash, lb->first)))
            {
                fail("Archive found duplicate file name hash");
            }

            iter->second.files.insert(lb, std::pair<std::uint64_t, FileRecord>(fileHash, file));

            FileStruct fileStruct;
            fileStruct.fileSize = file.getSizeWithoutCompressionFlag();
            fileStruct.offset = file.offset;
            fileStruct.name = nullptr; // assigned once the full paths are known
            mFiles.push_back(fileStruct);

            fullPaths.push_back(folder);
        }
    }

    // file name block: the null-terminated file names, in file record order
    if ((archiveFlags & 0x2) != 0)
    {
        mStringBuf.resize(totalFileNameLength);
        // data() rather than &mStringBuf[0]: the block may legitimately be empty
        input.read(mStringBuf.data(), mStringBuf.size());
        if (!input)
            fail("Unexpected end of file while reading BSA file names");
    }
    else
    {
        // no name block was read, so there must be nothing to scan below
        mStringBuf.clear();
    }

    // Everything below is bounded by what is actually in the buffer, never by
    // the length the header merely claims.
    const size_t namesSize = mStringBuf.size();
    size_t mStringBuffOffset = 0;
    size_t totalStringsSize = 0;
    for (std::uint32_t fileIndex = 0; fileIndex < mFiles.size(); ++fileIndex) {

        if (mStringBuffOffset >= namesSize) {
            fail("Corrupted names record in BSA file");
        }

        // find the end of this name without running past the buffer when the
        // terminator is missing
        size_t nameEnd = mStringBuffOffset;
        while (nameEnd < namesSize && mStringBuf[nameEnd] != '\0') {
            ++nameEnd;
        }

        std::string& fullPath = fullPaths.at(fileIndex);
        fullPath += '\\';
        fullPath.append(mStringBuf.data() + mStringBuffOffset, nameEnd - mStringBuffOffset);

        // step over the terminator when there is one
        mStringBuffOffset = (nameEnd < namesSize) ? nameEnd + 1u : nameEnd;

        //we want to keep one more 0 character at the end of each string
        totalStringsSize += fullPath.length() + 1u;
    }

    // Reuse mStringBuf to hold the full paths; mFiles[].name and the keys of
    // mLookup point into it, so it must not be resized after this point.
    mStringBuf.resize(totalStringsSize);

    mStringBuffOffset = 0;
    for (std::uint32_t fileIndex = 0u; fileIndex < mFiles.size(); fileIndex++) {
        const std::string& fullPath = fullPaths.at(fileIndex);
        const size_t stringLength = fullPath.length();

        //The vector guarantees that its elements occupy contiguous memory
        char* destination = mStringBuf.data() + mStringBuffOffset;

        std::copy(fullPath.c_str(),
            //plus 1 because we also want to copy 0 at the end of the string
            fullPath.c_str() + stringLength + 1u,
            destination);

        mFiles[fileIndex].name = destination;

        mLookup[destination] = fileIndex;
        mStringBuffOffset += stringLength + 1u;
    }

    if (mStringBuffOffset != mStringBuf.size()) {
        fail("Could not resolve names of files in BSA file");
    }

    convertCompressedSizesToUncompressed();
    mIsLoaded = true;
}

/// Looks up the record of a file by its path inside the archive.
///
/// The path is split into folder, stem and extension; the folder hash selects
/// an entry of mFolders and the stem/extension hash selects the file in it.
///
/// @param filePath  path inside the archive; '\' and '/' are both accepted as
///                  separators and case is ignored
/// @return the matching record, or a default-constructed one (for which
///         isValid() is false) when the folder or the file is unknown
TES4BSAFile::FileRecord TES4BSAFile::getFileRecord(const std::string& filePath) const
{
    // Archive paths use '\' as separator. boost::filesystem only splits on '\'
    // on Windows, so convert to '/' first; otherwise the whole path would be
    // taken for the file name on other platforms.
    std::string normalizedPath = filePath;
    std::replace(normalizedPath.begin(), normalizedPath.end(), '\\', '/');

    const boost::filesystem::path p(normalizedPath);
    std::string stem = p.stem().string();
    std::string ext = p.extension().string();

    // parent_path() never leaves a trailing separator behind
    std::string folder = p.parent_path().string();
    // GenOBHash already converts to lowercase and replaces file separators but not for path
    boost::algorithm::to_lower(folder);
    std::replace(folder.begin(), folder.end(), '/', '\\');

    std::string emptyString;
    std::uint64_t folderHash = GenOBHash(folder, emptyString);

    std::map<std::uint64_t, FolderRecord>::const_iterator it = mFolders.find(folderHash);
    if (it == mFolders.end())
    {
        return FileRecord(); // folder not found, return default which has offset of -1
    }

    boost::algorithm::to_lower(stem);
    boost::algorithm::to_lower(ext);
    std::uint64_t fileHash = GenOBHashPair(stem, ext);
    std::map<std::uint64_t, FileRecord>::const_iterator fileInDirIter = it->second.files.find(fileHash);
    if (fileInDirIter == it->second.files.end())
    {
        return FileRecord(); // file not found, return default which has offset of -1
    }

    return fileInDirIter->second;
}

/// Opens the file described by an entry of the file list.
///
/// @param file  file list entry; its name is used to look up the file record
/// @return stream over the file's data
///
/// Calls fail() when no record exists for the entry's name, the same way the
/// overload taking a name does.
Files::IStreamPtr TES4BSAFile::getFile(const FileStruct* file) {

    FileRecord fileRec = getFileRecord(file->name);
    if (!fileRec.isValid()) {
        // without this an invalid record would be opened at offset sInvalidOffset
        fail("File not found: " + std::string(file->name));
    }

    return getFile(fileRec);
}

/// Opens a file of the archive by name.
///
/// @param file  path of the file inside the archive
/// @return stream over the file's data
///
/// Calls fail() when the archive has no record for that path.
Files::IStreamPtr TES4BSAFile::getFile(const char* file)
{
    const FileRecord record = getFileRecord(file);

    const bool found = record.isValid();
    if (!found)
    {
        fail("File not found: " + std::string(file));
    }

    return getFile(record);
}

/// Opens the data of a file record.
///
/// Uncompressed files are returned as a stream constrained to the record's
/// byte range in the archive. Compressed files are inflated completely into a
/// memory buffer and a stream over that buffer is returned.
///
/// @param fileRecord  record to open
/// @return stream over the (uncompressed) file data
Files::IStreamPtr TES4BSAFile::getFile(const FileRecord& fileRecord) {

    // Bit 30 of the size field is a flag, not part of the length. It is set on
    // uncompressed files of a compressed-by-default archive, so it has to be
    // masked out on both paths.
    const std::uint32_t storedSize = fileRecord.getSizeWithoutCompressionFlag();

    if (!fileRecord.isCompressed(mCompressedByDefault)) {
        // NOTE(review): when mEmbeddedFileNames is set this stream still starts
        // at fileRecord.offset; whether a name prefix precedes uncompressed data
        // as well should be confirmed against the format documentation.
        return Files::openConstrainedFileStream(mFilename.c_str(), fileRecord.offset, storedSize);
    }

    Files::IStreamPtr streamPtr = Files::openConstrainedFileStream(mFilename.c_str(), fileRecord.offset, storedSize);

    std::istream* fileStream = streamPtr.get();

    if (mEmbeddedFileNames) {
        // the name is only skipped, its value is not needed
        std::string embeddedFileName;
        getBZString(embeddedFileName,*fileStream);
    }

    // the compressed data is preceded by the 4-byte size of the inflated data
    uint32_t uncompressedSize = 0u;
    fileStream->read(reinterpret_cast<char*>(&uncompressedSize), sizeof(uncompressedSize));

    // inflate whatever is left in the constrained stream
    boost::iostreams::filtering_streambuf<boost::iostreams::input> inputStreamBuf;
    inputStreamBuf.push(boost::iostreams::zlib_decompressor());
    inputStreamBuf.push(*fileStream);

    std::shared_ptr<Bsa::MemoryInputStream> memoryStreamPtr = std::make_shared<MemoryInputStream>(uncompressedSize);

    boost::iostreams::basic_array_sink<char> sr(memoryStreamPtr->getRawData(), uncompressedSize);
    boost::iostreams::copy(inputStreamBuf, sr);

    // The aliasing constructor keeps the MemoryInputStream alive while handing
    // out a std::istream pointer.
    // NOTE(review): the C-style cast is kept on purpose; a static_cast would not
    // compile if std::istream is an inaccessible base of MemoryInputStream -
    // confirm against memorystream.hpp before changing it.
    return std::shared_ptr<std::istream>(memoryStreamPtr, (std::istream*)memoryStreamPtr.get());
}

/// Determines the archive format from the first 32-bit word of a file.
///
/// @param filePath  path of the archive to inspect
/// @return BSAVER_TES3 or BSAVER_TES4PLUS when the first word equals the value
///         of that enumerator; BSAVER_UNKNOWN when the file is shorter than
///         12 bytes, cannot be read, or starts with any other value
BsaVersion TES4BSAFile::detectVersion(std::string filePath)
{
    namespace bfs = boost::filesystem;
    bfs::ifstream input(bfs::path(filePath), std::ios_base::binary);

    // Total archive size
    std::streamoff fsize = 0;
    if (input.seekg(0, std::ios_base::end))
    {
        fsize = input.tellg();
        input.seekg(0);
    }

    if (fsize < 12) {
        return BSAVER_UNKNOWN;
    }

    // First 12 bytes; only the first word is inspected
    uint32_t head[3] = { 0u, 0u, 0u };
    input.read(reinterpret_cast<char*>(head), 12);
    if (!input) {
        return BSAVER_UNKNOWN;
    }

    if (head[0] == static_cast<uint32_t>(BSAVER_TES3)) {
        return BSAVER_TES3;
    }

    // must be a comparison: an assignment here would classify every file that
    // is not TES3 as TES4+
    if (head[0] == static_cast<uint32_t>(BSAVER_TES4PLUS)) {
        return BSAVER_TES4PLUS;
    }

    return BSAVER_UNKNOWN;
}

//mFiles used by OpenMW expects uncompressed sizes
void TES4BSAFile::convertCompressedSizesToUncompressed()
{
    for (auto iter = mFiles.begin(); iter != mFiles.end(); ++iter)
    {
        const FileRecord& fileRecord = getFileRecord(iter->name);
        if (!fileRecord.isValid())
        {
            fail("Could not find file " + std::string(iter->name) + " in BSA");
        }

        if (!fileRecord.isCompressed(mCompressedByDefault))
        {
            //no need to fix fileSize in mFiles - uncompressed size already set
            continue;
        }

        Files::IStreamPtr dataBegin = Files::openConstrainedFileStream(mFilename.c_str(), fileRecord.offset, fileRecord.getSizeWithoutCompressionFlag());

        if (mEmbeddedFileNames)
        {
            std::string embeddedFileName;
            getBZString(embeddedFileName, *(dataBegin.get()));
        }

        dataBegin->read(reinterpret_cast<char*>(&(iter->fileSize)), sizeof(iter->fileSize));
    }
}

} //namespace Bsa