001    /**************************************************************** 
002     * This work is derived from 'jnamed.java' distributed in       *
003     * 'dnsjava-2.0.5'. This original is licensed as follows:       *
004     * Copyright (c) 1999-2005, Brian Wellington                    *
005     * All rights reserved.                                         *
006     *                                                              *
007     * Redistribution and use in source and binary forms, with or   * 
008     * without modification, are permitted provided that the        *  
009     * following conditions are met:                                * 
010     *                                                              * 
011     *  * Redistributions of source code must retain the above      *
012     *    copyright notice, this list of conditions and the         *
013     *    following disclaimer.                                     *
014     *  * Redistributions in binary form must reproduce the above   *
015     *    copyright notice, this list of conditions and the         *
016     *    following disclaimer in the documentation and/or other    *
017     *    materials provided with the distribution.                 *
018     *  * Neither the name of the dnsjava project nor the names     *
019     *    of its contributors may be used to endorse or promote     *
020     *    products derived from this software without specific      *
021     *    prior written permission.                                 *
022     *                                                              *
023     * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND       *
024     * CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES,  *
025     * INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF     *
026     * MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE     *
027     * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR         *
028     * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, *
029     * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,     *
030     * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR       *
031     * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS         *
032     * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF            *
033     * LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT    *
034     * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT   *
035     * OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE          *
036     * POSSIBILITY OF SUCH DAMAGE.                                  *
037     *                                                              *
038     * Modifications are                                            * 
039     * Licensed to the Apache Software Foundation (ASF) under one   *
040     * or more contributor license agreements.  See the NOTICE file *
041     * distributed with this work for additional information        *
042     * regarding copyright ownership.  The ASF licenses this file   *
043     * to you under the Apache License, Version 2.0 (the            *
044     * "License"); you may not use this file except in compliance   *
045     * with the License.  You may obtain a copy of the License at   *
046     *                                                              *
047     *   http://www.apache.org/licenses/LICENSE-2.0                 *
048     *                                                              *
049     * Unless required by applicable law or agreed to in writing,   *
050     * software distributed under the License is distributed on an  *
051     * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY       *
052     * KIND, either express or implied.  See the License for the    *
053     * specific language governing permissions and limitations      *
054     * under the License.                                           *
055     ****************************************************************/
056    
057    package org.apache.james.jspf.tester;
058    
059    import org.xbill.DNS.AAAARecord;
060    import org.xbill.DNS.ARecord;
061    import org.xbill.DNS.Address;
062    import org.xbill.DNS.CNAMERecord;
063    import org.xbill.DNS.DClass;
064    import org.xbill.DNS.DNAMERecord;
065    import org.xbill.DNS.ExtendedFlags;
066    import org.xbill.DNS.Flags;
067    import org.xbill.DNS.Header;
068    import org.xbill.DNS.MXRecord;
069    import org.xbill.DNS.Message;
070    import org.xbill.DNS.NSRecord;
071    import org.xbill.DNS.Name;
072    import org.xbill.DNS.NameTooLongException;
073    import org.xbill.DNS.OPTRecord;
074    import org.xbill.DNS.Opcode;
075    import org.xbill.DNS.PTRRecord;
076    import org.xbill.DNS.RRset;
077    import org.xbill.DNS.Rcode;
078    import org.xbill.DNS.Record;
079    import org.xbill.DNS.SOARecord;
080    import org.xbill.DNS.SPFRecord;
081    import org.xbill.DNS.Section;
082    import org.xbill.DNS.SetResponse;
083    import org.xbill.DNS.TXTRecord;
084    import org.xbill.DNS.TextParseException;
085    import org.xbill.DNS.Type;
086    import org.xbill.DNS.Zone;
087    
088    import java.io.IOException;
089    import java.net.InetAddress;
090    import java.net.Socket;
091    import java.net.UnknownHostException;
092    import java.util.HashSet;
093    import java.util.Iterator;
094    import java.util.LinkedList;
095    import java.util.List;
096    import java.util.Map;
097    import java.util.Random;
098    import java.util.Set;
099    
100    public class DNSTestingServer implements ResponseGenerator {
101    
102        static final int FLAG_DNSSECOK = 1;
103    
104        static final int FLAG_SIGONLY = 2;
105    
106        protected Zone zone;
107        
108        private Set<Name> timeoutServers;
109        
110        Random random = new Random();
111    
112        public DNSTestingServer(String address, String porta)
113                throws TextParseException, IOException {
114    
115            Integer port = new Integer(porta != null ? porta : "53");
116            InetAddress addr = Address.getByAddress(address != null ? address
117                    : "0.0.0.0");
118    
119            Thread t;
120            t = new Thread(new TCPListener(addr, port.intValue(), this));
121            t.setDaemon(true);
122            t.start();
123    
124            t = new Thread(new UDPListener(addr, port.intValue(), this));
125            t.setDaemon(true);
126            t.start();
127    
128            zone = null;
129        }
130    
131        @SuppressWarnings("unchecked")
132        public synchronized void setData(Map<String, List<?>> map) {
133            try {
134                this.timeoutServers = new HashSet<Name>();
135                List<Record> records = new LinkedList<Record>();
136    
137                records.add(new SOARecord(Name.root, DClass.IN, 3600, Name.root,
138                        Name.root, 857623948, 0, 0, 0, 0));
139                records.add(new NSRecord(Name.root, DClass.IN, 3600, Name.root));
140    
141                Iterator<String> hosts = map.keySet().iterator();
142                while (hosts.hasNext()) {
143                    String host = (String) hosts.next();
144                    Name hostname;
145                    if (!host.endsWith(".")) {
146                        hostname = Name.fromString(host + ".");
147                    } else {
148                        hostname = Name.fromString(host);
149                    }
150    
151                    List<?> l = map.get(host);
152                    if (l != null)
153                        for (Iterator<?> i = l.iterator(); i.hasNext();) {
154                            Object o = i.next();
155                            if (o instanceof Map) {
156                                Map<String, ?> hm = (Map) o;
157    
158                                Iterator<String> types = hm.keySet().iterator();
159    
160                                while (types.hasNext()) {
161                                    String type = (String) types.next();
162                                    if ("MX".equals(type)) {
163                                        List<?> mxList = (List<?>) hm.get(type);
164                                        Iterator<?> mxs = mxList.iterator();
165                                        while (mxs.hasNext()) {
166                                            Long prio = (Long) mxs.next();
167                                            String cname = (String) mxs.next();
168                                            if (cname != null) {
169                                                if (cname.length() > 0 &&  !cname.endsWith(".")) cname += ".";
170                                                
171                                                records.add(new MXRecord(hostname,
172                                                        DClass.IN, 3600, prio
173                                                                .intValue(), Name
174                                                                .fromString(cname)));
175                                            }
176                                        }
177                                    } else {
178                                        Object value = hm.get(type);
179                                        if ("A".equals(type)) {
180                                            records.add(new ARecord(hostname,
181                                                    DClass.IN, 3600, Address
182                                                            .getByAddress((String) value)));
183                                        } else if ("AAAA".equals(type)) {
184                                            records.add(new AAAARecord(hostname,
185                                                    DClass.IN, 3600, Address
186                                                            .getByAddress((String) value)));
187                                        } else if ("SPF".equals(type)) {
188                                            if (value instanceof List<?>) {
189                                                records.add(new SPFRecord(hostname,
190                                                        DClass.IN, 3600, (List<?>) value));
191                                            } else {
192                                                records.add(new SPFRecord(hostname,
193                                                        DClass.IN, 3600, (String) value));
194                                            }
195                                        } else if ("TXT".equals(type)) {
196                                            if (value instanceof List<?>) {
197                                                records.add(new TXTRecord(hostname,
198                                                        DClass.IN, 3600, (List<?>) value));
199                                            } else {
200                                                records.add(new TXTRecord(hostname,
201                                                        DClass.IN, 3600, (String) value));
202                                            }
203                                        } else {
204                                            if (!((String) value).endsWith(".")) {
205                                                value = ((String) value)+".";
206                                            }
207                                            if ("PTR".equals(type)) {
208                                                records
209                                                        .add(new PTRRecord(
210                                                                hostname,
211                                                                DClass.IN,
212                                                                3600,
213                                                                Name
214                                                                        .fromString((String) value)));
215                                            } else if ("CNAME".equals(type)) {
216                                                records.add(new CNAMERecord(
217                                                        hostname, DClass.IN, 3600,
218                                                        Name.fromString((String) value)));
219                                            } else {
220                                                throw new IllegalStateException(
221                                                        "Unsupported type: " + type);
222                                            }
223                                        }
224                                    }
225                                }
226                            } else if ("TIMEOUT".equals(o)) {
227                                timeoutServers.add(hostname);
228                            } else {
229                                throw new IllegalStateException(
230                                        "getRecord found an unexpected data");
231                            }
232                        }
233                }
234    
235                zone = new Zone(Name.root, (Record[]) records
236                        .toArray(new Record[] {}));
237                
238            } catch (TextParseException e) {
239                // TODO Auto-generated catch block
240                e.printStackTrace();
241            } catch (UnknownHostException e) {
242                // TODO Auto-generated catch block
243                e.printStackTrace();
244            } catch (IOException e) {
245                // TODO Auto-generated catch block
246                e.printStackTrace();
247            }
248        }
249    
250        private SOARecord findSOARecord() {
251            return zone.getSOA();
252        }
253    
254        private RRset findNSRecords() {
255            return zone.getNS();
256        }
257    
258        // TODO verify why enabling this lookup will make some test to fail!
259        private RRset findARecord(Name name) {
260            return null;
261            //return zone.findExactMatch(name, Type.A);
262        }
263    
264        private SetResponse findRecords(Name name, int type) {
265            SetResponse sr = zone.findRecords(name, type);
266            
267            if (sr == null || sr.answers() == null || sr.answers().length == 0) {
268                boolean timeout = timeoutServers.contains(name);
269                if (timeout) {
270                    try {
271                        Thread.sleep(2100);
272                    }
273                    catch (InterruptedException e) {
274                    }
275                    return null;
276                }
277            }
278            
279            try {
280                Thread.sleep(random.nextInt(500));
281            }
282            catch (Exception e) {} 
283            
284            return sr;
285        }
286    
287        @SuppressWarnings("unchecked")
288        void addRRset(Name name, Message response, RRset rrset, int section,
289                int flags) {
290            for (int s = 1; s <= section; s++)
291                if (response.findRRset(name, rrset.getType(), s))
292                    return;
293            if ((flags & FLAG_SIGONLY) == 0) {
294                Iterator<Record> it = rrset.rrs();
295                while (it.hasNext()) {
296                    Record r = (Record) it.next();
297                    if (r.getName().isWild() && !name.isWild())
298                        r = r.withName(name);
299                    response.addRecord(r, section);
300                }
301            }
302            if ((flags & (FLAG_SIGONLY | FLAG_DNSSECOK)) != 0) {
303                Iterator it = rrset.sigs();
304                while (it.hasNext()) {
305                    Record r = (Record) it.next();
306                    if (r.getName().isWild() && !name.isWild())
307                        r = r.withName(name);
308                    response.addRecord(r, section);
309                }
310            }
311        }
312    
313        private void addGlue(Message response, Name name, int flags) {
314            RRset a = findARecord(name);
315            if (a == null)
316                return;
317            addRRset(name, response, a, Section.ADDITIONAL, flags);
318        }
319    
320        private void addAdditional2(Message response, int section, int flags) {
321            Record[] records = response.getSectionArray(section);
322            for (int i = 0; i < records.length; i++) {
323                Record r = records[i];
324                Name glueName = r.getAdditionalName();
325                if (glueName != null)
326                    addGlue(response, glueName, flags);
327            }
328        }
329    
330        private final void addAdditional(Message response, int flags) {
331            addAdditional2(response, Section.ANSWER, flags);
332            addAdditional2(response, Section.AUTHORITY, flags);
333        }
334    
335        byte addAnswer(Message response, Name name, int type, int dclass,
336                int iterations, int flags) {
337            SetResponse sr;
338            byte rcode = Rcode.NOERROR;
339    
340            if (iterations > 6)
341                return Rcode.NOERROR;
342    
343            if (type == Type.SIG || type == Type.RRSIG) {
344                type = Type.ANY;
345                flags |= FLAG_SIGONLY;
346            }
347    
348            sr = findRecords(name, type);
349    
350            // TIMEOUT
351            if (sr == null) {
352                return -1;
353            }
354            
355            if (sr.isNXDOMAIN() || sr.isNXRRSET()) {
356                if (sr.isNXDOMAIN())
357                    response.getHeader().setRcode(Rcode.NXDOMAIN);
358    
359                response.addRecord(findSOARecord(), Section.AUTHORITY);
360    
361                if (iterations == 0)
362                    response.getHeader().setFlag(Flags.AA);
363    
364                rcode = Rcode.NXDOMAIN;
365    
366            } else if (sr.isDelegation()) {
367                RRset nsRecords = sr.getNS();
368                addRRset(nsRecords.getName(), response, nsRecords,
369                        Section.AUTHORITY, flags);
370            } else if (sr.isCNAME()) {
371                CNAMERecord cname = sr.getCNAME();
372                RRset rrset = new RRset(cname);
373                addRRset(name, response, rrset, Section.ANSWER, flags);
374                if (iterations == 0)
375                    response.getHeader().setFlag(Flags.AA);
376                rcode = addAnswer(response, cname.getTarget(), type, dclass,
377                        iterations + 1, flags);
378            } else if (sr.isDNAME()) {
379                DNAMERecord dname = sr.getDNAME();
380                RRset rrset = new RRset(dname);
381                addRRset(name, response, rrset, Section.ANSWER, flags);
382                Name newname;
383                try {
384                    newname = name.fromDNAME(dname);
385                } catch (NameTooLongException e) {
386                    return Rcode.YXDOMAIN;
387                }
388                rrset = new RRset(new CNAMERecord(name, dclass, 0, newname));
389                addRRset(name, response, rrset, Section.ANSWER, flags);
390                if (iterations == 0)
391                    response.getHeader().setFlag(Flags.AA);
392                rcode = addAnswer(response, newname, type, dclass, iterations + 1,
393                        flags);
394            } else if (sr.isSuccessful()) {
395                RRset[] rrsets = sr.answers();
396                for (int i = 0; i < rrsets.length; i++)
397                    addRRset(name, response, rrsets[i], Section.ANSWER, flags);
398    
399                RRset findNSRecords = findNSRecords();
400                addRRset(findNSRecords.getName(), response, findNSRecords,
401                        Section.AUTHORITY, flags);
402    
403                if (iterations == 0)
404                    response.getHeader().setFlag(Flags.AA);
405            }
406            return rcode;
407        }
408    
409        public byte[] generateReply(Message query, int length, Socket s)
410                throws IOException {
411            Header header;
412            int maxLength;
413            int flags = 0;
414    
415            header = query.getHeader();
416            if (header.getFlag(Flags.QR))
417                return null;
418            if (header.getRcode() != Rcode.NOERROR)
419                return errorMessage(query, Rcode.FORMERR);
420            if (header.getOpcode() != Opcode.QUERY)
421                return errorMessage(query, Rcode.NOTIMP);
422    
423            Record queryRecord = query.getQuestion();
424    
425            OPTRecord queryOPT = query.getOPT();
426            if (queryOPT != null && queryOPT.getVersion() > 0) {
427            }
428    
429            if (s != null)
430                maxLength = 65535;
431            else if (queryOPT != null)
432                maxLength = Math.max(queryOPT.getPayloadSize(), 512);
433            else
434                maxLength = 512;
435    
436            if (queryOPT != null && (queryOPT.getFlags() & ExtendedFlags.DO) != 0)
437                flags = FLAG_DNSSECOK;
438    
439            Message response = new Message(query.getHeader().getID());
440            response.getHeader().setFlag(Flags.QR);
441            if (query.getHeader().getFlag(Flags.RD))
442                response.getHeader().setFlag(Flags.RD);
443            response.addRecord(queryRecord, Section.QUESTION);
444    
445            Name name = queryRecord.getName();
446            int type = queryRecord.getType();
447            int dclass = queryRecord.getDClass();
448            if (!Type.isRR(type) && type != Type.ANY)
449                return errorMessage(query, Rcode.NOTIMP);
450    
451            byte rcode = addAnswer(response, name, type, dclass, 0, flags);
452            
453            // TIMEOUT
454            if (rcode == -1) {
455                return null;
456            }
457            
458            if (rcode != Rcode.NOERROR && rcode != Rcode.NXDOMAIN)
459                return errorMessage(query, rcode);
460    
461            addAdditional(response, flags);
462    
463            if (queryOPT != null) {
464                int optflags = (flags == FLAG_DNSSECOK) ? ExtendedFlags.DO : 0;
465                OPTRecord opt = new OPTRecord((short) 4096, rcode, (byte) 0,
466                        optflags);
467                response.addRecord(opt, Section.ADDITIONAL);
468            }
469    
470            return response.toWire(maxLength);
471        }
472    
473        byte[] buildErrorMessage(Header header, int rcode, Record question) {
474            Message response = new Message();
475            response.setHeader(header);
476            for (int i = 0; i < 4; i++)
477                response.removeAllRecords(i);
478            if (rcode == Rcode.SERVFAIL)
479                response.addRecord(question, Section.QUESTION);
480            header.setRcode(rcode);
481            return response.toWire();
482        }
483    
484        public byte[] formerrMessage(byte[] in) {
485            Header header;
486            try {
487                header = new Header(in);
488            } catch (IOException e) {
489                return null;
490            }
491            return buildErrorMessage(header, Rcode.FORMERR, null);
492        }
493    
494        public byte[] errorMessage(Message query, int rcode) {
495            return buildErrorMessage(query.getHeader(), rcode, query.getQuestion());
496        }
497    
498        public byte[] generateReply(byte[] in, int length) {
499            Message query;
500            byte[] response = null;
501            try {
502                query = new Message(in);
503                response = generateReply(query, length, null);
504            } catch (IOException e) {
505                response = formerrMessage(in);
506            }
507            return response;
508        }
509    
510    }