diff --git a/src/main/java/jnr/unixsocket/UnixServerSocketChannel.java b/src/main/java/jnr/unixsocket/UnixServerSocketChannel.java index afd07af..3e8f8a9 100644 --- a/src/main/java/jnr/unixsocket/UnixServerSocketChannel.java +++ b/src/main/java/jnr/unixsocket/UnixServerSocketChannel.java @@ -54,7 +54,14 @@ public UnixSocketChannel accept() throws IOException { int maxLength = addr.getMaximumLength(); IntByReference len = new IntByReference(maxLength); - int clientfd = Native.accept(getFD(), addr, len); + int clientfd=-1; + try { + begin(); + clientfd = Native.accept(getFD(), addr, len); + } + finally { + end(clientfd>=0); + } if (clientfd < 0) { if (isBlocking()) { diff --git a/src/test/java/jnr/unixsocket/AcceptInterruptTest.java b/src/test/java/jnr/unixsocket/AcceptInterruptTest.java new file mode 100644 index 0000000..4c80f37 --- /dev/null +++ b/src/test/java/jnr/unixsocket/AcceptInterruptTest.java @@ -0,0 +1,106 @@ +package jnr.unixsocket; + +import java.io.File; +import java.io.IOException; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; + +import jnr.unixsocket.UnixServerSocketChannel; +import jnr.unixsocket.UnixSocketAddress; +import junit.framework.Assert; + +import org.junit.Test; + +public class AcceptInterruptTest { + @Test + public void testAcceptCloseInterrupt() throws Exception { + File file = File.createTempFile("test", ".sock"); + file.delete(); + file.deleteOnExit(); + + final UnixServerSocketChannel channel = UnixServerSocketChannel.open(); + channel.socket().bind(new UnixSocketAddress(file)); + + final AtomicBoolean run = new AtomicBoolean(true); + final CountDownLatch start = new CountDownLatch(1); + final CountDownLatch complete = new CountDownLatch(1); + Thread accept = new Acceptor(complete, start, channel, run); + + // Start accepting thread + accept.setDaemon(true); + accept.start(); + Assert.assertTrue(start.await(5,TimeUnit.SECONDS)); + + // Mark as no longer running + run.set(false); + + // Close and Interrupt + channel.close(); + accept.interrupt(); + Assert.assertTrue(complete.await(5,TimeUnit.SECONDS)); + } + + @Test + public void testAcceptInterrupt() throws Exception { + File file = File.createTempFile("test", ".sock"); + file.delete(); + file.deleteOnExit(); + + final UnixServerSocketChannel channel = UnixServerSocketChannel.open(); + channel.socket().bind(new UnixSocketAddress(file)); + + final AtomicBoolean run = new AtomicBoolean(true); + final CountDownLatch start = new CountDownLatch(1); + final CountDownLatch complete = new CountDownLatch(1); + Thread accept = new Acceptor(complete, start, channel, run); + + // Start accepting thread + accept.setDaemon(true); + accept.start(); + Assert.assertTrue(start.await(5,TimeUnit.SECONDS)); + + // Mark as no longer running + run.set(false); + + accept.interrupt(); + Assert.assertTrue(complete.await(5,TimeUnit.SECONDS)); + } + + private final class Acceptor extends Thread { + private final CountDownLatch complete; + private final CountDownLatch start; + private final UnixServerSocketChannel channel; + private final AtomicBoolean run; + + private Acceptor(CountDownLatch complete, CountDownLatch start, UnixServerSocketChannel channel, + AtomicBoolean run) { + this.complete = complete; + this.start = start; + this.channel = channel; + this.run = run; + } + + @Override public void run() { + try { + while(run.get()) { + if (start.getCount()>0) + start.countDown(); + try { + channel.accept(); + System.err.println("accepted"); + } + catch (IOException e) { + e.printStackTrace(); + } + finally { + System.err.println("finally"); + } + } + } + finally { + complete.countDown(); + } + } + } +}