#!/usr/bin/env python3

import argparse
import xml.etree.ElementTree as ET

# Command-line interface.  The summary is passed as ``description``: the first
# positional parameter of ArgumentParser is ``prog``, so a positional string
# would replace the program name in the usage line instead of describing it.
parser = argparse.ArgumentParser(description='filter an xml file for language variety specifiers')
parser.add_argument('var', help='language variety to retain (comma-separated for multiple)')
parser.add_argument('infile', help='input file')
parser.add_argument('outfile', help='output file')
# -k may be repeated; each use appends one tag name to args.keep.
parser.add_argument('-k', '--keep', help='node types to retain unchanged', action='append', default=[])
parser.add_argument('-a', '--attr', help='attribute which specifies language variety (default: v)', default='v')
args = parser.parse_args()

# Set of variety names requested on the command line.  Whitespace around the
# commas is ignored so that "en, fr" selects the same varieties as "en,fr".
keep = {name.strip() for name in args.var.split(',')}

def should_keep(node):
	"""Decide whether *node* survives filtering, stripping its variety attribute.

	An element that lacks the variety attribute (named by ``args.attr``) is
	always kept.  Otherwise the attribute value is read as a comma-separated
	list of variety names, and the element is kept only if at least one of
	them is in the module-level ``keep`` set.

	Side effect: when the variety attribute is present it is removed from
	``node.attrib``, whether or not the element ends up being kept.

	Args:
		node: element whose attributes are inspected (and possibly modified).

	Returns:
		True if the element should be retained, False if it should be removed.
	"""
	# pop() reads and strips the attribute in one step; the None default
	# identifies elements that carry no variety restriction at all.
	spec = node.attrib.pop(args.attr, None)
	if spec is None:
		return True
	# Ignore whitespace around the separators so v="en, fr" matches "fr".
	varieties = {name.strip() for name in spec.split(',')}
	return not varieties.isdisjoint(keep)

def filter_node(node):
	"""Recursively prune the children of *node* in place.

	For each direct child:

	* if its tag is listed in ``args.keep`` it is left completely untouched
	  (not inspected by ``should_keep`` and not descended into);
	* otherwise, if ``should_keep`` accepts it, its own children are filtered
	  recursively;
	* otherwise it is removed from *node*.

	ElementTree stores the text that follows an element in that element's
	``tail``, so removing an element would also discard the text after it.
	To avoid losing document content, a tail containing non-whitespace text
	is moved onto the preceding surviving sibling (or onto ``node.text`` when
	there is none).  Whitespace-only tails are dropped together with the
	element so the indentation of pretty-printed files is not duplicated.

	Args:
		node: element whose subtree is filtered; it is modified in place.
	"""
	rejected = []
	previous = None  # last child seen so far that stays in the tree
	for child in node:
		if child.tag in args.keep:
			previous = child
			continue
		if should_keep(child):
			filter_node(child)
			previous = child
			continue
		rejected.append(child)
		# Rescue real text that follows the element being dropped.
		if child.tail and child.tail.strip():
			if previous is None:
				node.text = (node.text or '') + child.tail
			else:
				previous.tail = (previous.tail or '') + child.tail
	# Removal is deferred so the child list is not mutated while iterating it.
	for child in rejected:
		node.remove(child)

# Parse the input file, prune the tree in place starting at the root element,
# and write the result to the output file.
tree = ET.parse(args.infile)
filter_node(tree.getroot())
# ElementTree.write() defaults to US-ASCII without an XML declaration, which
# turns every non-ASCII character into a numeric character reference.  Write
# UTF-8 with a declaration so the text stays readable in the output file.
tree.write(args.outfile, encoding='utf-8', xml_declaration=True)
