#!/usr/bin/env python
#
# PyAuthenNTLM2: A mod-python module for Apache that carries out NTLM authentication
#
# pyntlm.py
#
# Copyright 2011 Legrandin <helderijs@gmail.com>
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import sys, os
import base64
import time
import urllib
import argparse
import socket
import subprocess
import syslog
from struct import unpack
from threading import Lock
from binascii import hexlify
from urlparse import urlparse

from PyAuthenNTLM2.ntlm_dc_proxy import NTLM_DC_Proxy
from PyAuthenNTLM2.ntlm_ad_proxy import NTLM_AD_Proxy

class CacheConnections:
	"""Lock-protected registry of open proxy connections.

	Each entry maps a connection id to a ``(proxy, timestamp)`` tuple, where
	``timestamp`` is the integer epoch second at which the proxy was added.
	Every stored proxy must provide a ``close()`` method.

	All mutating operations hold the internal mutex through a ``with`` block,
	so the lock is released even when ``close()`` or a lookup raises.
	"""

	def __init__(self):
		self._mutex = Lock()
		self._cache = {}

	def __len__(self):
		"""Return the number of cached connections."""
		return len(self._cache)

	def _evict(self, ids):
		"""Drop and close the entries named in ``ids``.

		The caller must already hold ``self._mutex``. Each entry is removed
		from the cache before its proxy is closed, so a failing ``close()``
		cannot leave a dead connection behind in the cache.
		"""
		for key in ids:
			proxy, _ts = self._cache.pop(key)
			proxy.close()

	def remove(self, id):
		"""Close and forget the proxy stored under ``id``; no-op if absent."""
		with self._mutex:
			proxy, _ts = self._cache.get(id, (None, None))
			if proxy:
				self._evict([id])

	def add(self, id, proxy):
		"""Store ``proxy`` under ``id``, stamped with the current time."""
		with self._mutex:
			self._cache[id] = (proxy, int(time.time()))

	def cleanTimeout(self, timeout=60):
		"""Close and drop every entry older than ``timeout`` seconds."""
		now = int(time.time())
		with self._mutex:
			# Collect the keys first: the dict must not change size while
			# it is being iterated.
			expired = [key for key, (_proxy, ts) in self._cache.items()
					if ts + timeout < now]
			self._evict(expired)

	def cleanAll(self):
		"""Close and drop every cached connection."""
		with self._mutex:
			self._evict(list(self._cache))

	def has_key(self, id):
		"""Return True if a connection is cached under ``id``."""
		return id in self._cache

	def get_proxy(self, id):
		"""Return the proxy stored under ``id``.

		Raises:
			KeyError: no connection is cached under ``id``.
		"""
		with self._mutex:
			return self._cache[id][0]

# Module-wide cache of proxy connections, keyed by the connection id that
# process() receives as ``cid``.
cache = CacheConnections()
def dbg(m):
	"""Write ``m`` to syslog at LOG_ERR priority, prefixed with "TEST:"."""
	line = "TEST:" + m
	syslog.syslog(syslog.LOG_ERR, line)

def ntlm_message_type(msg):
	"""Return the message type (1, 2 or 3) of the raw NTLM message ``msg``.

	``msg`` is a byte string. The signature is compared against a bytes
	literal, which is the plain ``str`` type on Python 2, so the function
	behaves the same there and also accepts ``bytes`` on Python 3.

	Raises:
		RuntimeError: ``msg`` is shorter than 12 bytes, does not start with
			the ``NTLMSSP\\0`` signature, or carries a type outside 1..3.
	"""
	signature = b'NTLMSSP\x00'
	if len(msg) < 12 or not msg.startswith(signature):
		# hexlify() returns a byte string; convert it to the native str type
		# so the error text looks the same on every interpreter.
		dump = str(hexlify(msg).decode('ascii'))
		raise RuntimeError("Not a valid NTLM message: '%s'" % dump)
	# The type is a little-endian 32-bit integer right after the signature.
	msg_type = unpack('<I', msg[8:12])[0]
	if msg_type not in (1, 2, 3):
		raise RuntimeError("Incorrect NTLM message Type: %d" % msg_type)
	return msg_type

def parse_ntlm_authenticate(msg):
	"""Extract ``(username, domain)`` from a raw NTLM type 3 message.

	Two 8-byte field descriptors are read, at byte 28 (domain) and byte 36
	(user name). Each holds a 16-bit length, two skipped bytes and a 32-bit
	offset into ``msg``. The 32-bit flags word is read at byte 60; when its
	lowest bit is set both values are decoded from UTF-16-LE and converted
	with ``str()``, otherwise the raw slices are returned unchanged.

	NOTE(review): on Python 2, ``str()`` of a decoded name containing
	non-ASCII characters raises UnicodeEncodeError - confirm whether such
	names need to be supported.
	"""
	unicode_bit = 0x00000001

	def payload_at(descriptor_pos):
		# Descriptor layout: <length:H> <2 pad bytes> <offset:I>
		size, start = unpack('<HxxI', msg[descriptor_pos:descriptor_pos + 8])
		return msg[start:start + size]

	domain = payload_at(28)
	username = payload_at(36)
	(flags,) = unpack('<I', msg[60:64])
	if not flags & unicode_bit:
		return username, domain
	domain = str(domain.decode('utf-16-le'))
	username = str(username.decode('utf-16-le'))
	return username, domain

def decode_http_authorization_header(auth):
	"""Split an ``Authorization`` header value into ``('NTLM', payload)``.

	``auth`` must consist of exactly two space-separated tokens. The second
	token is base64-decoded before the scheme is inspected, so a malformed
	payload raises even when the scheme is not NTLM.

	Returns ``('NTLM', decoded_bytes)`` for the NTLM scheme and ``False`` for
	any other token count or scheme.
	"""
	tokens = auth.split(' ')
	if len(tokens) != 2:
		return False
	scheme, encoded = tokens
	decoded = base64.b64decode(encoded)
	if scheme != 'NTLM':
		return False
	return ('NTLM', decoded)

def connect_to_proxy(server):
	"""Build the proxy object used to talk to ``server``.

	A ``server`` beginning with ``ldap:`` is parsed as a URL and handed to
	NTLM_AD_Proxy (host from the netloc, base from the unquoted path without
	its leading character); anything else goes to NTLM_DC_Proxy.

	Raises:
		RuntimeError: any exception raised while building the proxy is
			re-raised as RuntimeError carrying the server name and the
			original error text.
	"""
	try:
		if not server.startswith('ldap:'):
			return NTLM_DC_Proxy(server, "", verbose=True)
		# The LDAP/AD backend is not supported yet.
		parsed = urlparse(server)
		decoded_base = urllib.unquote(parsed.path)[1:]
		return NTLM_AD_Proxy(parsed.netloc, "", base=decoded_base)
	except Exception as e:
		raise RuntimeError("PYNTLM: Error when connect to server " + server + ": "+str(e))

def get_connected_dc():
	"""Return the quoted name that ``/bin/wbinfo -P`` prints on stdout.

	The command is run without a shell. Both stdout and stderr are captured
	so that a failure can actually be detected and reported; the caller uses
	the returned name as the server to connect to.

	Raises:
		RuntimeError: wbinfo could not be started, exited with a non-zero
			status, or its output does not contain exactly one
			double-quoted string.
	"""
	try:
		p = subprocess.Popen(["/bin/wbinfo", "-P"],
				stdout=subprocess.PIPE, stderr=subprocess.PIPE,
				universal_newlines=True)
		# communicate() waits for the process, so returncode is set afterwards.
		output, err = p.communicate()
	except OSError as e:
		# Keep the failure type callers already expect from this function.
		raise RuntimeError("wbinfo error: " + str(e))
	if p.returncode != 0:
		raise RuntimeError("wbinfo error: " + (err or output))
	arr = output.split("\"")
	if 3 != len(arr):
		raise RuntimeError("wbinfo format wrong: " + output)
	return arr[1]

def process(cid, auth):
	"""Handle one step of the NTLM handshake for connection ``cid``.

	``auth`` is the value of the HTTP Authorization header.

	* TYPE 1 message: expired cache entries are dropped, a proxy to the
	  current domain controller is opened and asked for a challenge, the
	  proxy is cached under ``cid`` and ``(cid, "NTLM <base64 challenge>")``
	  is returned.
	* TYPE 3 message: the proxy cached under ``cid`` authenticates the
	  message; on success the proxy is released and ``(0, "DOMAIN\\user")``
	  is returned.

	Every failure is reported as ``(-1, <error text>)`` and releases any
	proxy cached under ``cid``; no exception escapes to the caller.
	"""
	try:
		ah_data = decode_http_authorization_header(auth)
	except Exception:
		# A header that cannot be decoded is treated like a missing one.
		ah_data = False

	if not ah_data:
		return -1, "No NTLM authenticate header"

	ntlm_msg = ah_data[1]
	try:
		ntlm_version = ntlm_message_type(ntlm_msg)
		if ntlm_version not in (1, 3):
			raise RuntimeError('TYPE ' + str(ntlm_version) + ' message in client request.(183)')

		if ntlm_version == 1:
			cache.cleanTimeout()
			server = get_connected_dc()
			# Start unbound so the handler below knows whether there is
			# anything to close when the connection attempt itself fails.
			proxy = None
			try:
				proxy = connect_to_proxy(server)
				ntlm_challenge = proxy.negotiate(ntlm_msg)
				if not ntlm_challenge:
					raise RuntimeError('failed. to get chanllenge')
			except Exception as e:
				if proxy is not None:
					proxy.close()
				raise RuntimeError('TYPE 1 error: ' + str(e))
			# A repeated TYPE 1 for the same id must not leave the previous
			# connection open once its cache slot is overwritten.
			if cache.has_key(cid):
				cache.remove(cid)
			cache.add(cid, proxy)
			return cid, "NTLM " + base64.b64encode(ntlm_challenge)

		# Only TYPE 3 can reach this point.
		if not cache.has_key(cid):
			raise RuntimeError("TYPE 3 error: Proxy timeout.(id="+str(cid)+") >60s")
		proxy = cache.get_proxy(cid)
		user, domain = parse_ntlm_authenticate(ntlm_msg)
		if not domain:
			raise RuntimeError("user "+ domain + "\\" + user + " invalid.")
		result = proxy.authenticate(ntlm_msg)

		if not result:
			raise RuntimeError('TYPE 3 error: authenticate failed.')
		cache.remove(cid)
		return 0, domain+"\\"+user

	except Exception as e:
		if cache.has_key(cid):
			cache.remove(cid)
		return -1, str(e)

def _winbindd_alive():
	"""Return True when ``/bin/wbinfo -p`` exits with status 0."""
	try:
		with open(os.devnull, 'w') as devnull:
			status = subprocess.call(["/bin/wbinfo", "-p"],
					stdout=devnull, stderr=devnull)
	except OSError:
		# wbinfo missing or not executable counts as a failed ping.
		return False
	return status == 0

def _serve_one(conn):
	"""Read one ``<cid>,<authorization header>`` request and answer it.

	The reply written to ``conn`` is ``<result>,<message>``. Nothing is sent
	when the peer closed the connection without sending data.
	"""
	data = conn.recv(4096)
	if not data:
		return
	arr = data.split(',')
	result = -1
	message = "Format error:" + data
	if not _winbindd_alive():
		message = "Ping to winbindd failed"
	elif 2 == len(arr):
		result, message = process(arr[0], arr[1])
	# sendall() keeps writing until the whole reply has been handed over.
	conn.sendall(str(result) + "," + message)

def run():
	"""Serve NTLM requests on a UNIX stream socket until interrupted.

	The socket lives at /var/run/synocgid/synopyntlm.htua.ntlm.socket. Each
	accepted connection carries a single request and is closed after the
	reply. When the loop ends through an exception, the listening socket is
	closed, the socket file is removed and all cached proxies are closed
	before the exception propagates.
	"""
	socket_dir = "/var/run/synocgid"
	socket_file = socket_dir + "/synopyntlm.htua.ntlm.socket"

	# The two steps are attempted independently: an already existing
	# directory must not prevent removal of a stale socket file, otherwise
	# bind() fails on every restart.
	try:
		os.mkdir(socket_dir, 0o755)
	except OSError:
		pass
	try:
		os.remove(socket_file)
	except OSError:
		pass

	sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
	try:
		sock.bind(socket_file)
		sock.listen(5)
		while True:
			conn, _addr = sock.accept()
			try:
				_serve_one(conn)
			except socket.error as e:
				# One misbehaving client must not take the server down.
				syslog.syslog(syslog.LOG_ERR, "PYNTLM: client connection error: " + str(e))
			finally:
				conn.close()
	finally:
		sock.close()
		try:
			os.remove(socket_file)
		except OSError:
			pass
		cache.cleanAll()

def debug(cid, auth):
	"""Run process() once and print its outcome as ``<result>,<message>``."""
	code, text = process(cid, auth)
	line = str(code) + "," + text
	print(line)

def synontlm(args):
	"""Dispatch on the parsed command line.

	When ``args.auth`` is set, that single header value is processed under
	the fixed connection id 5566 and the result is printed; otherwise the
	socket server is started.
	"""
	if not args.auth:
		run()
		return
	debug(5566, args.auth)

# Script entry point. The optional -a/--auth argument defaults to False; the
# parsed arguments are handed to synontlm(), which chooses between one-shot
# processing of that value and running the socket server.
parser = argparse.ArgumentParser(description='Process NTLM negotiation.')
parser.add_argument('-a', '--auth', default=False)
args = parser.parse_args()
synontlm(args)
