Commit e812f6c2 authored by Dave Cridland's avatar Dave Cridland

OF-1003 OF-1002 Make XEP-0198 thread safer

parent c5db7ec9
......@@ -1269,22 +1269,7 @@ public class SessionManager extends BasicModule implements ClusterEventListener/
router.route(presence);
}
// Re-deliver unacknowledged stanzas from broken stream (XEP-0198)
if(session.getStreamManager().isEnabled()) {
session.getStreamManager().setEnabled(false); // Avoid concurrent usage.
Deque<StreamManager.UnackedPacket> unacknowledgedStanzas = session.getStreamManager().getUnacknowledgedServerStanzas();
if(!unacknowledgedStanzas.isEmpty()) {
for(StreamManager.UnackedPacket unacked : unacknowledgedStanzas) {
if (unacked.packet instanceof Message) {
Message m = (Message)unacked.packet;
Element delayInformation = m.addChildElement("delay", "urn:xmpp:delay");
delayInformation.addAttribute("stamp", XMPPDateTimeFormat.format(unacked.timestamp));
delayInformation.addAttribute("from", serverAddress.toBareJID());
}
router.route(unacked.packet);
}
}
}
session.getStreamManager().onClose(router, serverAddress);
}
finally {
// Remove the session
......
......@@ -814,15 +814,7 @@ public class LocalClientSession extends LocalSession implements ClientSession {
public void deliver(Packet packet) throws UnauthorizedException {
conn.deliver(packet);
if(streamManager.isEnabled()) {
streamManager.incrementServerSentStanzas();
// Temporarily store packet until delivery confirmed
streamManager.getUnacknowledgedServerStanzas().addLast(new StreamManager.UnackedPacket(new Date(), packet.createCopy()));
if(getNumServerPackets() % JiveGlobals.getLongProperty("stream.management.requestFrequency", 5) == 0) {
streamManager.sendServerRequest();
}
}
streamManager.sentStanza(packet);
}
@Override
......
......@@ -6,6 +6,13 @@ import java.util.LinkedList;
import org.dom4j.Element;
import org.jivesoftware.openfire.Connection;
import org.jivesoftware.openfire.PacketRouter;
import org.jivesoftware.util.JiveGlobals;
import org.jivesoftware.util.XMPPDateTimeFormat;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.xmpp.packet.JID;
import org.xmpp.packet.Message;
import org.xmpp.packet.Packet;
import org.xmpp.packet.PacketError;
......@@ -16,6 +23,7 @@ import org.xmpp.packet.PacketError;
* @author jonnyheavey
*/
public class StreamManager {
private static final Logger Log = LoggerFactory.getLogger(StreamManager.class);
public static class UnackedPacket {
public final Date timestamp;
public final Packet packet;
......@@ -90,7 +98,7 @@ public class StreamManager {
/**
* Sends XEP-0198 request <r /> to client from server
*/
public void sendServerRequest() {
private void sendServerRequest() {
if(isEnabled()) {
String request = String.format("<r xmlns='%s' />", getNamespace());
getConnection().deliverRawText(request);
......@@ -115,27 +123,71 @@ public class StreamManager {
*/
public void processClientAcknowledgement(Element ack) {
if(isEnabled()) {
if(ack.attribute("h") != null) {
long count = Long.valueOf(ack.attributeValue("h"));
// Remove stanzas from temporary storage as now acknowledged
Deque<UnackedPacket> unacknowledgedStanzas = getUnacknowledgedServerStanzas();
long i = getClientProcessedStanzas();
if (count < i) {
synchronized (this) {
if (ack.attribute("h") != null) {
long count = Long.valueOf(ack.attributeValue("h"));
// Remove stanzas from temporary storage as now acknowledged
Deque<UnackedPacket> unacknowledgedStanzas = getUnacknowledgedServerStanzas();
long i = getClientProcessedStanzas();
Log.debug("Ack: h={} mine={} length={}", count, i, unacknowledgedStanzas.size());
if (count < i) {
/* Consider rollover? */
if (i > mask) {
while (count < i) {
count += mask + 1;
}
}
}
while(i < count) {
unacknowledgedStanzas.removeFirst();
i++;
Log.debug("Maybe rollover");
if (i > mask) {
while (count < i) {
Log.debug("Rolling...");
count += mask + 1;
}
}
}
while (i < count) {
unacknowledgedStanzas.removeFirst();
i++;
Log.debug("In Ack: h={} mine={} length={}", count, i, unacknowledgedStanzas.size());
}
setClientProcessedStanzas(count);
}
}
}
}
setClientProcessedStanzas(count);
public void sentStanza(Packet packet) {
if(isEnabled()) {
synchronized (this) {
incrementServerSentStanzas();
// Temporarily store packet until delivery confirmed
getUnacknowledgedServerStanzas().addLast(new StreamManager.UnackedPacket(new Date(), packet.createCopy()));
Log.debug("Added stanza of type {}, now {} / {}", packet.getClass().getName(), getServerSentStanzas(), getUnacknowledgedServerStanzas().size());
}
if(getServerSentStanzas() % JiveGlobals.getLongProperty("stream.management.requestFrequency", 5) == 0) {
sendServerRequest();
}
}
}
public void onClose(PacketRouter router, JID serverAddress) {
// Re-deliver unacknowledged stanzas from broken stream (XEP-0198)
if(isEnabled()) {
setEnabled(false); // Avoid concurrent usage.
synchronized (this) {
Deque<StreamManager.UnackedPacket> unacknowledgedStanzas = getUnacknowledgedServerStanzas();
if (!unacknowledgedStanzas.isEmpty()) {
for (StreamManager.UnackedPacket unacked : unacknowledgedStanzas) {
if (unacked.packet instanceof Message) {
Message m = (Message) unacked.packet;
Element delayInformation = m.addChildElement("delay", "urn:xmpp:delay");
delayInformation.addAttribute("stamp", XMPPDateTimeFormat.format(unacked.timestamp));
delayInformation.addAttribute("from", serverAddress.toBareJID());
}
router.route(unacked.packet);
}
}
}
}
}
/**
......@@ -160,7 +212,7 @@ public class StreamManager {
* manager belongs to.
* @param enabled
*/
public void setEnabled(boolean enabled) {
synchronized public void setEnabled(boolean enabled) {
this.enabled = enabled;
if(enabled) {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment