#include "stdafx.h"
#include "OAuth2.h"

#include <iostream>
#include <memory>
#include <random>
#include <string>
#include <afxinet.h>

/*
CString url_encode(CString s) {

DWORD options = ICU_DECODE | ICU_ENCODE_PERCENT;
DWORD bytes = 3 * s.GetLength() + 1;

CString res;
LPTSTR escapedString = res.GetBuffer(bytes + 1);
escapedString[0] = 0;

bool result = InternetCanonicalizeUrl(s.GetBuffer(), escapedString, &bytes, options);

if (result) {
res.ReleaseBufferSetLength(bytes);
}
else {
res.ReleaseBufferSetLength(0);
}

return res;
}*/

// Sends an HTTPS POST request with a form-urlencoded body and returns the
// response body.
//
// server  Host name to connect to.
// port    TCP port on the server.
// url     Object (path) to request.
// data    Request body. Every character is narrowed to a single byte, so the
//         caller is expected to pass ASCII (already URL-encoded) text.
//
// Returns the response body if the server answered with HTTP 200; otherwise an
// empty string. A CInternetException is swallowed and whatever has been read
// up to that point (possibly nothing) is returned.
static CString HTTPPost(CString server, WORD port, CString url, CString data) {

	CString res;

	try {

		CInternetSession session;

		// The owners below are declared after the session and in this order so
		// that on every exit path (including exceptions) the request file is
		// closed first, then the connection, then the session.
		std::unique_ptr<CHttpConnection> connection(session.GetHttpConnection(server, port));
		if (!connection) { return res; }

		// NOTE(review): INTERNET_FLAG_IGNORE_CERT_CN_INVALID and
		// INTERNET_FLAG_IGNORE_CERT_DATE_INVALID relax server certificate
		// validation. The flags are kept to preserve existing behavior, but
		// they should be reconsidered because this request carries OAuth
		// secrets.
		std::unique_ptr<CHttpFile> file(connection->OpenRequest(CHttpConnection::HTTP_VERB_POST, url,
			NULL, 1, NULL, NULL,
			INTERNET_FLAG_SECURE | INTERNET_FLAG_IGNORE_CERT_CN_INVALID | INTERNET_FLAG_IGNORE_CERT_DATE_INVALID));
		if (!file) { return res; }

		// Narrow the body to bytes; std::string owns the buffer so it is
		// released even if SendRequest throws.
		const int len = data.GetLength();
		std::string body;
		body.reserve(len);
		for (int i = 0; i < len; ++i) {
			body.push_back(static_cast<char>(data[i]));
		}

		CString strHeaders(_T("Content-Type: application/x-www-form-urlencoded"));
		file->SendRequest(strHeaders, const_cast<char*>(body.data()), len);

		DWORD status = 0;
		if (file->QueryInfoStatusCode(status) && status == HTTP_STATUS_OK) {

			char buffer[4096];
			UINT nRead = 0;

			while ((nRead = file->Read(buffer, sizeof(buffer))) > 0) {

				CString chunk;
				LPTSTR dst = chunk.GetBuffer(nRead);
				for (UINT i = 0; i < nRead; ++i) {
					// Go through unsigned char so bytes >= 0x80 are not
					// sign-extended when widened to TCHAR.
					dst[i] = static_cast<TCHAR>(static_cast<unsigned char>(buffer[i]));
				}
				chunk.ReleaseBufferSetLength(nRead);

				res += chunk;
			}
		}

		return res;
	}
	catch (CInternetException *e) {
		// Stack unwinding has already released the file, the connection and
		// the session; only the exception object is left to clean up.
		e->Delete();
		return res;
	}
}

// Minimal JSON lookup: locates the first occurrence of "prop" (including the
// quotes) in json and returns the text between the next pair of double quotes
// that follows it. Returns an empty string if the key or either quote is
// missing. Escaped quotes inside the value are not interpreted.
static CString ExtractJsonStringProperty(CString json, CString prop) {
	const TCHAR quote = _T('"');

	// Build the quoted key to search for.
	CString key;
	key += quote;
	key += prop;
	key += quote;

	const int keyPos = json.Find(key);
	if (keyPos < 0) {
		return _T("");
	}

	// Opening quote of the value, searched for after the key.
	const int openQuote = json.Find(quote, keyPos + key.GetLength());
	if (openQuote < 0) {
		return _T("");
	}

	// Closing quote of the value.
	const int closeQuote = json.Find(quote, openQuote + 1);
	if (closeQuote < 0) {
		return _T("");
	}

	const int valueStart = openQuote + 1;
	return json.Mid(valueStart, closeQuote - valueStart);
}

// Minimal JSON lookup for an integer value: finds "prop" (including the
// quotes) in json, skips to the ':' after it, and parses the token that
// follows. Returns defaultValue if the key or separator is missing, if the
// token runs to the very end of the text, or if it cannot be parsed as int.
static int ExtractJsonNumberProperty(CString json, CString prop, int defaultValue) {
	CString key(_T("\""));
	key += prop;
	key += _T('"');

	const int keyPos = json.Find(key);
	if (keyPos < 0) {
		return defaultValue;
	}

	const int colon = json.Find(_T(':'), keyPos + key.GetLength());
	if (colon < 0) {
		return defaultValue;
	}

	const int length = json.GetLength();

	// Skip blanks and line breaks between the separator and the value.
	int first = colon + 1;
	for (; first < length; ++first) {
		const TCHAR ch = json[first];
		if (ch != ' ' && ch != '\r' && ch != '\n') {
			break;
		}
	}
	if (first >= length) {
		return defaultValue;
	}

	// The value token ends at the next blank, comma or closing bracket/brace.
	int last = first;
	for (; last < length; ++last) {
		const TCHAR ch = json[last];
		if (ch == ' ' || ch == ',' || ch == ']' || ch == '}') {
			break;
		}
	}
	if (last >= length) {
		return defaultValue;
	}

	const CString token = json.Mid(first, last - first);

	try {
		return std::stoi(std::wstring(token.GetString()));
	}
	catch (...) {
		// Not a number or out of int range.
		return defaultValue;
	}
}

// Exchanges a refresh token for an access token at Google's token endpoint.
//
// refresh_token  The refresh token to send.
// expires_in     Receives the "expires_in" value of the reply, or 3000 if the
//                reply does not contain a usable one.
//
// Returns the "access_token" value of the reply; empty if it is absent.
CString Google_GetOAuth2Token(CString refresh_token, int& expires_in) {

	// Assemble the form body of the refresh request.
	CString body;
	body += _T("client_id=");
	body += Google_Client_ID;
	body += _T("&client_secret=");
	body += Google_Client_Secret;
	body += _T("&refresh_token=");
	body += refresh_token;
	body += _T("&grant_type=refresh_token");

	const CString reply = HTTPPost(_T("oauth2.googleapis.com"), 443, _T("/token"), body);

	expires_in = ExtractJsonNumberProperty(reply, _T("expires_in"), 3000);
	return ExtractJsonStringProperty(reply, _T("access_token"));
}


// Outcome of waiting for a single HTTP request on the local listener.
// Instances are produced by MakeOKResult / MakeErrorResult.
struct HttpReqResult {
	bool OK;            // true if a request was received and accepted
	CString Query;      // query string of the received request (set on success)
	CString ErroMsg;    // description of the failure (set on error)
	DWORD ErrorCode;    // error code accompanying ErroMsg; 0 on success
};

// Link against the HTTP Server API import library.
#pragma comment(lib, "httpapi.lib")

// Zeroes the HTTP_RESPONSE pointed to by resp and fills in its status code and
// reason phrase. reason must be a NUL-terminated string; its length is taken
// with strlen.
#define INITIALIZE_HTTP_RESPONSE( resp, status, reason )    \
    do                                                      \
    {                                                       \
        RtlZeroMemory( (resp), sizeof(*(resp)) );           \
        (resp)->StatusCode = (status);                      \
        (resp)->pReason = (reason);                         \
        (resp)->ReasonLength = (USHORT) strlen(reason);     \
    } while (FALSE)

// Sets the known header HeaderId of Response (an HTTP_RESPONSE object, not a
// pointer) to the NUL-terminated string RawValue. Only the pointer is stored,
// so RawValue must stay valid for as long as the response is used.
#define ADD_KNOWN_HEADER(Response, HeaderId, RawValue)               \
    do                                                               \
    {                                                                \
        (Response).Headers.KnownHeaders[(HeaderId)].pRawValue =      \
                                                          (RawValue);\
        (Response).Headers.KnownHeaders[(HeaderId)].RawValueLength = \
            (USHORT) strlen(RawValue);                               \
    } while(FALSE)

// Allocates cb bytes from the process heap (memory is not zero-initialized).
// NOTE(review): not referenced in the part of the file visible here.
#define ALLOC_MEM(cb) HeapAlloc(GetProcessHeap(), 0, (cb))

// Releases a block obtained with ALLOC_MEM.
#define FREE_MEM(ptr) HeapFree(GetProcessHeap(), 0, (ptr))

// Forward declarations; both functions are defined further down in this file.

// Receives one request from hReqQueue and answers it; pResponse is the body
// sent back for a GET request.
HttpReqResult DoReceiveRequests(IN HANDLE hReqQueue, IN PCSTR pResponse);

// Sends a response with the given status code, reason phrase and optional
// entity body for pRequest. Returns the result of HttpSendHttpResponse.
DWORD
SendHttpResponse(
	IN HANDLE        hReqQueue,
	IN PHTTP_REQUEST pRequest,
	IN USHORT        StatusCode,
	IN PCSTR          pReason,
	IN PCSTR          pEntity
);

// Builds a failed HttpReqResult carrying the message err and the code errCode.
// Query is left empty.
static HttpReqResult MakeErrorResult(PCWSTR err, DWORD errCode) {
	HttpReqResult failure;
	failure.ErrorCode = errCode;
	failure.ErroMsg = err;
	failure.OK = false;
	return failure;
}

// Builds a successful HttpReqResult carrying the query string of the received
// request. ErroMsg is left empty and ErrorCode is 0.
static HttpReqResult MakeOKResult(PCWSTR query) {
	HttpReqResult success;
	success.ErrorCode = 0;
	success.Query = query;
	success.OK = true;
	return success;
}

// Starts a temporary HTTP listener on pURL, handles exactly one incoming
// request via DoReceiveRequests and shuts the listener down again.
//
// pURL       URL prefix to listen on (passed to HttpAddUrl).
// pResponse  Body sent back to the client for a GET request.
//
// Returns the result of DoReceiveRequests, or an error result naming the
// HTTP Server API call that failed together with its return code.
// The receive call is made without an OVERLAPPED structure, so this function
// does not return until a request arrives or receiving fails.
HttpReqResult WaitForHttpRequest(PCWSTR pURL, PCSTR pResponse) {

	const HTTPAPI_VERSION httpApiVersion = HTTPAPI_VERSION_1;

	ULONG retCode = HttpInitialize(httpApiVersion, HTTP_INITIALIZE_SERVER, NULL);
	if (retCode != NO_ERROR) {
		return MakeErrorResult(L"HttpInitialize failed", retCode);
	}

	// From here on every path runs through the cleanup at the bottom, so each
	// resource is released in exactly one place.
	HttpReqResult res;
	HANDLE hReqQueue = NULL;

	retCode = HttpCreateHttpHandle(&hReqQueue, 0);

	if (retCode != NO_ERROR) {
		res = MakeErrorResult(L"HttpCreateHttpHandle failed", retCode);
	}
	else {
		retCode = HttpAddUrl(hReqQueue, pURL, NULL);

		if (retCode != NO_ERROR) {
			res = MakeErrorResult(L"HttpAddUrl failed (launch PopMan with Admin privileges and try again)", retCode);
		}
		else {
			res = DoReceiveRequests(hReqQueue, pResponse);

			// Explicitly drop the URL registration before the queue goes away.
			HttpRemoveUrl(hReqQueue, pURL);
		}

		CloseHandle(hReqQueue);
	}

	HttpTerminate(HTTP_INITIALIZE_SERVER, NULL);

	return res;
}

// Receives a single request from hReqQueue and answers it.
//
// hReqQueue  Request queue handle to read from.
// pResponse  Body sent with the "200 OK" answer to a GET request.
//
// Returns an OK result carrying the request's query string for a GET request.
// Any other verb is answered with "501 Not Implemented" and reported as
// "Wrong request". Failures of the receive call are reported with their
// Win32 error code.
HttpReqResult DoReceiveRequests(IN HANDLE hReqQueue, IN PCSTR pResponse) {

	// Room for the fixed HTTP_REQUEST structure plus the variable-length data
	// (URL, headers) that the API stores behind it.
	const ULONG RequestBufferLength = sizeof(HTTP_REQUEST) + 8192;

	// The buffer is accessed as an HTTP_REQUEST, so it must be aligned like
	// one; a plain CHAR array carries no such guarantee. The initializer
	// already zeroes the whole buffer.
	alignas(HTTP_REQUEST) CHAR requestBuffer[RequestBufferLength] = { 0 };
	PHTTP_REQUEST pRequest = reinterpret_cast<PHTTP_REQUEST>(requestBuffer);

	// A null request id asks for the next request that arrives on the queue.
	HTTP_REQUEST_ID requestId;
	HTTP_SET_NULL_ID(&requestId);

	DWORD bytesRead = 0;

	const ULONG result = HttpReceiveHttpRequest(
		hReqQueue,          // Req Queue
		requestId,          // Req ID
		0,                  // Flags
		pRequest,           // HTTP request buffer
		RequestBufferLength,// req buffer length
		&bytesRead,         // bytes received
		NULL                // LPOVERLAPPED
	);

	if (result == NO_ERROR) {

		if (pRequest->Verb == HttpVerbGET) {
			SendHttpResponse(hReqQueue, pRequest, 200, "OK", pResponse);
			return MakeOKResult(pRequest->CookedUrl.pQueryString);
		}

		// 501 is the status code that belongs to the "Not Implemented" reason
		// phrase (503 means "Service Unavailable").
		SendHttpResponse(hReqQueue, pRequest, 501, "Not Implemented", NULL);
		return MakeErrorResult(L"Wrong request", 0);
	}

	if (result == ERROR_MORE_DATA) {
		// The request did not fit into requestBuffer.
		return MakeErrorResult(L"Request too large", result);
	}

	if (result == ERROR_CONNECTION_INVALID) {
		return MakeErrorResult(L"Connection broken", result);
	}

	return MakeErrorResult(L"Unknown error", result);
}

// Sends an HTTP response for pRequest on hReqQueue.
//
// StatusCode     Numeric HTTP status.
// pReason        NUL-terminated reason phrase.
// pEntityString  Optional NUL-terminated response body; NULL sends no body.
//
// The Content-Type header is always "text/html".
// Returns the result code of HttpSendHttpResponse.
DWORD SendHttpResponse(
	IN HANDLE        hReqQueue,
	IN PHTTP_REQUEST pRequest,
	IN USHORT        StatusCode,
	IN PCSTR          pReason,
	IN PCSTR          pEntityString
)
{
	// Start from an all-zero response and fill in the status line.
	HTTP_RESPONSE response;
	RtlZeroMemory(&response, sizeof(response));
	response.StatusCode = StatusCode;
	response.pReason = pReason;
	response.ReasonLength = (USHORT)strlen(pReason);

	// Content-Type header.
	response.Headers.KnownHeaders[HttpHeaderContentType].pRawValue = "text/html";
	response.Headers.KnownHeaders[HttpHeaderContentType].RawValueLength = (USHORT)strlen("text/html");

	// Optional entity body, sent as a single in-memory chunk. Because the body
	// goes out in one call, no Content-Length needs to be specified.
	HTTP_DATA_CHUNK bodyChunk;
	if (pEntityString != NULL) {
		bodyChunk.DataChunkType = HttpDataChunkFromMemory;
		bodyChunk.FromMemory.pBuffer = (PVOID)(PSTR)pEntityString;
		bodyChunk.FromMemory.BufferLength = (ULONG)strlen(pEntityString);

		response.pEntityChunks = &bodyChunk;
		response.EntityChunkCount = 1;
	}

	DWORD bytesSent = 0;

	return HttpSendHttpResponse(
		hReqQueue,           // ReqQueueHandle
		pRequest->RequestId, // Request ID
		0,                   // Flags
		&response,           // HTTP response
		NULL,                // pReserved1
		&bytesSent,          // bytes sent  (OPTIONAL)
		NULL,                // pReserved2  (must be NULL)
		0,                   // Reserved3   (must be 0)
		NULL,                // LPOVERLAPPED(OPTIONAL)
		NULL                 // pReserved4  (must be NULL)
	);
}

#include "StrFunctions.h"

// Encodes len bytes at pData with Base64Encode and converts the result to the
// URL-safe alphabet: '+' becomes '-', '/' becomes '_', and '=' padding is
// dropped.
static CString Base64UrlEncodeNoPadding(const uint8_t * pData, size_t len) {
	CString encoded = Base64Encode(pData, len);
	encoded.Replace(_T('/'), _T('_'));
	encoded.Replace(_T('+'), _T('-'));
	encoded.Remove(_T('='));
	return encoded;
}

#include <ctime>
// Returns 32 random bytes encoded as base64url without padding.
//
// The value is used as the OAuth "state" parameter and as the PKCE code
// verifier, so it must be unpredictable and differ between calls. The bytes
// therefore come from std::random_device rather than from the clock or from
// uninitialized memory (reading the latter is undefined behavior and gave
// identical, guessable values for calls made within the same second).
static CString GenerateRandomDataBase64url() {
	uint8_t randomBytes[32] = { 0 };

	std::random_device entropy;
	for (size_t i = 0; i < sizeof(randomBytes); ++i) {
		// One draw per byte; only the low 8 bits of each draw are kept.
		randomBytes[i] = static_cast<uint8_t>(entropy() & 0xFF);
	}

	return Base64UrlEncodeNoPadding(randomBytes, sizeof(randomBytes));
}

#include "sha-256.h"

// Converts code to a narrow string, computes its SHA-256 digest and returns
// the digest encoded as base64url without padding.
static CString Sha256AndBase64NoPadding(CString code) {
	USES_CONVERSION;

	// Narrow copy of the input; the digest is computed over these bytes.
	const LPSTR narrow = W2A(code);

	uint8_t digest[32] = { 0 };
	calc_sha_256(digest, narrow, strlen(narrow));

	return Base64UrlEncodeNoPadding(digest, 32);
}

// Percent-encodes s for use as a URL component.
//
// ASCII letters, digits and the characters - . _ ~ are copied unchanged.
// Every other character is converted to UTF-8 and each resulting byte is
// written as %XX with upper-case hex digits. A UTF-16 surrogate pair is
// converted as one unit so that characters outside the Basic Multilingual
// Plane produce their proper 4-byte UTF-8 sequence.
static CString EscapeDataString(const CString s) {

	static const char hexUpperChars[] = "0123456789ABCDEF";

	CString res;
	const int length = s.GetLength();

	for (int i = 0; i < length; ++i) {
		const wchar_t c = s[i];

		const bool unreserved =
			(c >= 'A' && c <= 'Z') ||
			(c >= 'a' && c <= 'z') ||
			(c >= '0' && c <= '9') ||
			c == '-' || c == '.' || c == '_' || c == '~';

		if (unreserved) {
			res += c;
			continue;
		}

		// Collect the UTF-16 units that make up this character: one unit, or
		// two if a high surrogate is directly followed by a low surrogate.
		wchar_t units[2] = { c, 0 };
		int unitCount = 1;
		if (c >= 0xD800 && c <= 0xDBFF && i + 1 < length) {
			const wchar_t low = s[i + 1];
			if (low >= 0xDC00 && low <= 0xDFFF) {
				units[1] = low;
				unitCount = 2;
				++i; // the low surrogate is consumed together with the high one
			}
		}

		char utf8[8] = { 0 };
		const int byteCount = WideCharToMultiByte(CP_UTF8, 0, units, unitCount,
			utf8, static_cast<int>(sizeof(utf8)), NULL, NULL);

		// Iterate by count, not up to a NUL, so a zero byte is emitted as %00.
		for (int b = 0; b < byteCount; ++b) {
			const unsigned char ch = static_cast<unsigned char>(utf8[b]);
			res += '%';
			res += hexUpperChars[(ch & 0xf0) >> 4];
			res += hexUpperChars[ch & 0xf];
		}
	}

	return res;
}

// Returns the system message text for the calling thread's last-error code
// (GetLastError). The result is empty if no message could be formatted.
static CString GetLastErrorText() {

	// Read the error code before calling anything else.
	const DWORD lastError = GetLastError();

	const DWORD formatFlags =
		FORMAT_MESSAGE_ALLOCATE_BUFFER |
		FORMAT_MESSAGE_FROM_SYSTEM |
		FORMAT_MESSAGE_IGNORE_INSERTS;

	// With FORMAT_MESSAGE_ALLOCATE_BUFFER the function stores a pointer to a
	// buffer it allocated in 'message'; it must be released with LocalFree.
	LPTSTR message = NULL;
	FormatMessage(formatFlags, NULL, lastError,
		MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT),
		(LPTSTR)&message, 0, NULL);

	CString text(message);
	LocalFree(message);
	return text;
}

// Opens url with the shell's "open" verb.
// Returns an empty string on success, otherwise an error message that includes
// the text of the last system error.
static CString LaunchURL(CString url) {

	CoInitializeEx(NULL, COINIT_APARTMENTTHREADED | COINIT_DISABLE_OLE1DDE);

	// ShellExecute reports success with a value greater than 32.
	const INT_PTR shellResult = (INT_PTR)ShellExecute(NULL, _T("open"), url, NULL, NULL, SW_SHOWNORMAL);
	const bool launched = shellResult > 32;

	if (launched) {
		return _T("");
	}
	return _T("Failed to launch web browser for sign-in: ") + GetLastErrorText();
}

// Decodes a URL-encoded (application/x-www-form-urlencoded) string.
//
// '+' becomes a space and %XX becomes the byte with hex value XX. The decoded
// bytes are interpreted as UTF-8, so a multi-byte sequence such as %C3%A9
// yields a single character. A '%' that is not followed by two hexadecimal
// digits is copied through unchanged.
//
// Like the other helpers in this file that cast CString to LPCWSTR, this
// assumes a Unicode build.
static CString UrlDecode(const CString data) {

	// Value of a hexadecimal digit, or -1 if ch is not one.
	const auto hexValue = [](wchar_t ch) -> int {
		if (ch >= L'0' && ch <= L'9') { return ch - L'0'; }
		if (ch >= L'a' && ch <= L'f') { return ch - L'a' + 10; }
		if (ch >= L'A' && ch <= L'F') { return ch - L'A' + 10; }
		return -1;
	};

	const int N = data.GetLength();

	// The decoded text is first collected as UTF-8 bytes ...
	std::string utf8;
	utf8.reserve(N);

	for (int i = 0; i < N; ++i) {
		const wchar_t c = data[i];

		if (c == L'+') {
			utf8.push_back(' ');
			continue;
		}

		if (c == L'%' && i + 2 < N) {
			const int hi = hexValue(data[i + 1]);
			const int lo = hexValue(data[i + 2]);
			if (hi >= 0 && lo >= 0) {
				utf8.push_back(static_cast<char>((hi << 4) | lo));
				i += 2;
				continue;
			}
			// Not a valid escape: fall through and keep the '%' literally.
		}

		if (c < 0x80) {
			utf8.push_back(static_cast<char>(c));
			continue;
		}

		// A literal non-ASCII character: store its UTF-8 form, taking a
		// surrogate pair as one character.
		wchar_t units[2] = { c, 0 };
		int unitCount = 1;
		if (c >= 0xD800 && c <= 0xDBFF && i + 1 < N) {
			const wchar_t low = data[i + 1];
			if (low >= 0xDC00 && low <= 0xDFFF) {
				units[1] = low;
				unitCount = 2;
				++i;
			}
		}
		char encoded[8] = { 0 };
		const int encodedLen = WideCharToMultiByte(CP_UTF8, 0, units, unitCount,
			encoded, static_cast<int>(sizeof(encoded)), NULL, NULL);
		if (encodedLen > 0) {
			utf8.append(encoded, static_cast<size_t>(encodedLen));
		}
	}

	if (utf8.empty()) {
		return CString();
	}

	// ... and then converted to UTF-16 in one step.
	const int byteCount = static_cast<int>(utf8.size());
	const int wideCount = MultiByteToWideChar(CP_UTF8, 0, utf8.data(), byteCount, NULL, 0);
	if (wideCount <= 0) {
		return CString();
	}

	CString res;
	LPWSTR out = res.GetBuffer(wideCount);
	MultiByteToWideChar(CP_UTF8, 0, utf8.data(), byteCount, out, wideCount);
	res.ReleaseBufferSetLength(wideCount);
	return res;
}

// Returns the URL-decoded value of the parameter 'name' in the query string
// 'query', or an empty string if the parameter is not present.
//
// A match only counts if it is a whole parameter name: it must start at the
// beginning of the string or directly after '?' or '&', and must be directly
// followed by '='. The value extends to the next '&' or to the end of the
// string.
static CString GetQueryValue(CString query, CString name) {

	// An empty name can never identify a parameter (and would match at every
	// position).
	if (name.IsEmpty()) {
		return _T("");
	}

	const int length = query.GetLength();
	int pos = 0;

	while (true) {
		const int found = query.Find(name, pos);
		if (found < 0) {
			return _T("");
		}

		// Continue any further search behind this occurrence.
		pos = found + name.GetLength();
		if (pos >= length) {
			// Nothing follows the name, so there is no "=value" part.
			return _T("");
		}

		// A query string passed without its leading '?' starts with a
		// parameter name, so index 0 is a valid boundary as well.
		const bool atBoundary =
			found == 0 || query[found - 1] == '?' || query[found - 1] == '&';

		if (!atBoundary || query[pos] != '=') {
			continue;
		}

		const int valueStart = pos + 1;
		const int nextSeparator = query.Find(_T('&'), valueStart);
		const int valueEnd = (nextSeparator < 0) ? length : nextSeparator;

		return UrlDecode(query.Mid(valueStart, valueEnd - valueStart));
	}
}

#include <strsafe.h>

// Returns the system message text for the error code dw.
// The result is empty if no message could be formatted.
static CString GetErrorText(DWORD dw) {

	// FormatMessage allocates the buffer itself (ALLOCATE_BUFFER) and stores
	// its address in 'allocated'; the buffer is released with LocalFree below.
	LPVOID allocated = NULL;

	FormatMessage(
		FORMAT_MESSAGE_IGNORE_INSERTS |
		FORMAT_MESSAGE_FROM_SYSTEM |
		FORMAT_MESSAGE_ALLOCATE_BUFFER,
		NULL,
		dw,
		MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT),
		(LPTSTR)&allocated,
		0,
		NULL);

	const LPCTSTR messageText = (LPCTSTR)allocated;
	CString description(messageText);

	LocalFree(allocated);
	return description;
}

// Runs the interactive Google OAuth2 authorization-code flow with PKCE.
//
// Steps: open the authorization URL in the user's browser, wait on a local
// HTTP listener (http://127.0.0.1:<freePort>/) for the redirect, verify the
// returned state, and exchange the authorization code for tokens.
//
// user      Account name passed to Google as login_hint.
// freePort  Local TCP port for the redirect listener.
//
// Returns "<token>REFRESH_TOKEN</token>" on success or "<error>MESSAGE</error>"
// on failure.
CString DoOAuth(const CString user, const UINT freePort) {

	// Wraps a message in the <error> element returned to the caller.
	const auto makeError = [](const CString& message) -> CString {
		CString wrapped = _T("<error>");
		wrapped += message;
		wrapped += _T("</error>");
		return wrapped;
	};

	const CString state = GenerateRandomDataBase64url();
	const CString codeVerifier = GenerateRandomDataBase64url();
	const CString codeChallenge = Sha256AndBase64NoPadding(codeVerifier);

	const CString AuthorizationEndpoint = _T("https://accounts.google.com/o/oauth2/v2/auth");

	CString redirectUri;
	redirectUri.Format(_T("http://127.0.0.1:%u/"), freePort);

	// Every %s argument is passed as a plain string pointer; handing a CString
	// object itself to a variadic function is not well-defined.
	CString authorizationRequest;
	authorizationRequest.Format(_T("%s?response_type=code&scope=%s%s&redirect_uri=%s&client_id=%s&state=%s&code_challenge=%s&code_challenge_method=S256&login_hint=%s"),
		(LPCWSTR)AuthorizationEndpoint,
		(LPCWSTR)EscapeDataString(_T("https://mail.google.com/")),
		_T("%20profile"),
		(LPCWSTR)EscapeDataString(redirectUri),
		(LPCWSTR)Google_Client_ID,
		(LPCWSTR)state,
		(LPCWSTR)codeChallenge,
		(LPCWSTR)EscapeDataString(user));

	const CString errLaunch = LaunchURL(authorizationRequest);
	if (errLaunch.GetLength() > 0) {
		return makeError(errLaunch);
	}

	PCSTR response = "<html><head><meta http-equiv='refresh' content='10;url=https://www.ch-software.de/popman/download.htm'></head><body>Please return to PopMan.</body></html>";

	const HttpReqResult query = WaitForHttpRequest(redirectUri, response);

	if (!query.OK) {
		return makeError(query.ErroMsg + L": " + GetErrorText(query.ErrorCode));
	}

	const CString error = GetQueryValue(query.Query, _T("error"));
	if (!error.IsEmpty()) {
		return makeError(error);
	}

	const CString code = GetQueryValue(query.Query, _T("code"));
	const CString incomingState = GetQueryValue(query.Query, _T("state"));

	// The state must round-trip unchanged; otherwise the redirect did not
	// originate from the request sent above.
	if (incomingState != state) {
		return makeError(_T("Invalid state"));
	}

	// GetQueryValue returns the code URL-decoded, so it has to be encoded
	// again before it is placed into the form body.
	CString tokenRequestBody;
	tokenRequestBody.Format(_T("code=%s&redirect_uri=%s&client_id=%s&code_verifier=%s&client_secret=%s&scope=&grant_type=authorization_code"),
		(LPCWSTR)EscapeDataString(code),
		(LPCWSTR)EscapeDataString(redirectUri),
		(LPCWSTR)Google_Client_ID,
		(LPCWSTR)codeVerifier,
		(LPCWSTR)Google_Client_Secret);

	const CString json = HTTPPost(_T("www.googleapis.com"), 443, _T("/oauth2/v4/token"), tokenRequestBody);
	const CString refresh_token = ExtractJsonStringProperty(json, _T("refresh_token"));

	// HTTPPost returns an empty body on any failure; do not report that as a
	// successful sign-in with an empty token.
	if (refresh_token.IsEmpty()) {
		return makeError(_T("Token request failed (no refresh token received)"));
	}

	CString res = _T("<token>");
	res += refresh_token;
	res += _T("</token>");
	return res;
}