00001
00002 #include <Parse.h>
00003 #include <assert.h>
00004
00005 #include "SlaveSocket.h"
00006 #include "InternalSocket.h"
00007 #include "SlaveHandler.h"
00008 #include "ProxyOutSocket.h"
00009 #include "ProxyBinSocket.h"
00010 #include "MasterSocket.h"
00011
00012
00013 SlaveSocket::SlaveSocket(ISocketHandler& h)
00014 :BaseSocket(h)
00015 ,m_state(0)
00016 {
00017 }
00018
00019
00020 SlaveSocket::~SlaveSocket()
00021 {
00022 }
00023
00024
// Parse framed master-to-slave messages out of the input buffer.
//
// After TcpSocket::OnRead() has pulled received data into ibuf, this runs a
// small state machine until more bytes are needed or the socket has been
// flagged for close-and-delete. Each frame is laid out as:
//   2 bytes         id       (stored in m_ip)
//   2 bytes         command  (stored in m_command, later cast to mastercmd_t)
//   2 bytes         length   (stored in m_length)
//   m_length bytes  payload  (copied into m_packet)
// The three 16-bit header fields are converted with ntohs(), so they are
// expected in network byte order. States run 0 -> 2 -> 3 -> 4 -> 0 (there is
// no state 1). A payload may arrive spread over several calls; m_packet_ptr
// records how many payload bytes have been copied so far.
// NOTE(review): m_packet is filled with up to m_length bytes without a size
// check in this function — confirm m_packet is large enough for the largest
// m_length a peer can send.
00025 void SlaveSocket::OnRead()
00026 {
00027 TcpSocket::OnRead();
00028 bool need_more = false;
// Keep consuming frames while enough data is buffered and the socket is not
// being torn down.
00029 while (!need_more && !CloseAndDelete())
00030 {
00031 size_t l = ibuf.GetLength();
00032 switch (m_state)
00033 {
// State 0: read the 2-byte id field into m_ip.
00034 case 0:
00035 if (l >= 2)
00036 {
00037 ibuf.Read( (char *)&m_ip, 2);
00038 m_ip = ntohs(m_ip);
00039 m_state = 2;
00040 }
00041 else
00042 need_more = true;
00043 break;
// State 2: read the 2-byte command field into m_command.
00044 case 2:
00045 if (l >= 2)
00046 {
00047 ibuf.Read( (char *)&m_command, 2);
00048 m_command = ntohs(m_command);
00049 m_state = 3;
00050 }
00051 else
00052 need_more = true;
00053 break;
// State 3: read the 2-byte payload length and reset the payload cursor.
00054 case 3:
00055 if (l >= 2)
00056 {
00057 ibuf.Read( (char *)&m_length, 2);
00058 m_length = ntohs(m_length);
00059 m_packet_ptr = 0;
00060 m_state = 4;
00061 }
00062 else
00063 need_more = true;
00064 break;
// State 4: accumulate the payload. If fewer bytes are buffered than are still
// missing, take everything available and wait for more; otherwise take exactly
// the remainder. A zero-length payload skips copying and dispatches at once.
00065 case 4:
00066 if (m_length)
00067 {
00068 if (l < m_length - m_packet_ptr)
00069 {
00070 ibuf.Read(m_packet + m_packet_ptr, l);
00071 m_packet_ptr += l;
00072 need_more = true;
00073 }
00074 else
00075 {
00076 ibuf.Read(m_packet + m_packet_ptr, m_length - m_packet_ptr);
00077 m_packet_ptr += m_length - m_packet_ptr;
00078 }
00079 }
// Whole payload present: report the frame via printmsg() and dispatch on the
// command, then return to state 0 for the next frame.
00080 if (m_packet_ptr == (size_t)m_length)
00081 {
00082 mastercmd_t cmd = static_cast<mastercmd_t>(m_command);
00083 printmsg(m_ip, cmd, m_length);
// NOTE(review): only M2S_OPEN / M2S_CLOSE / M2S_DATA are handled below and
// there is no default case, so other command values appear to be silently
// ignored — confirm this is intended.
00084 switch (cmd)
00085 {
00086 case M2S_OPEN:
00087 OpenConnection();
00088 break;
00089 case M2S_CLOSE:
00090 CloseConnection();
00091 break;
00092 case M2S_DATA:
00093 SendPacket();
00094 break;
00095
00096
00097
00098
00099
00100
00101
00102
00103
00104
00105
00106
00107
00108
00109
00110
00111
00112
00113
00114
00115
00116
00117
00118
00119
00120
00121
00122
00123
00124
00125
00126
00127
00128
00129
00130
00131
00132
00133
00134
00135
00136
00137
00138
00139
00140
00141
00142
00143 }
00144 m_state = 0;
00145 }
00146 break;
// Any other state value is a programming error: print it, trip assert() in
// builds without NDEBUG, and flag the socket for close-and-delete (presumably
// also ending the loop via CloseAndDelete() — confirm).
00147 default:
00148 printf("SlaveSocket: Bad state (%d)\n", m_state);
00149 assert(0);
00150 SetCloseAndDelete();
00151 break;
00152 }
00153 }
00154 }
00155
00156
00157 void SlaveSocket::OpenConnection()
00158 {
00159 unsigned short id;
00160 memcpy(&id, &m_ip, 2);
00161 m_packet[m_length] = 0;
00162 printf(" payload: %s\n", m_packet);
00163 Parse pa(m_packet, ":");
00164 std::string host = pa.getword();
00165 port_t port = pa.getvalue();
00166 printf(" host: %s port: %u\n", host.c_str(), port);
00167 InternalSocket *p = new InternalSocket(Handler(), this, id);
00168 p -> SetDeleteByHandler();
00169 p -> Open(host, port);
00170 Handler().Add(p);
00171 }
00172
00173
00174 void SlaveSocket::CloseConnection()
00175 {
00176 unsigned short id;
00177 memcpy(&id, &m_ip, 2);
00178 InternalSocket *p = GetSock(id);
00179 if (p)
00180 p -> SetCloseAndDelete();
00181 }
00182
00183
00184 void SlaveSocket::SendPacket()
00185 {
00186 unsigned short id;
00187 memcpy(&id, &m_ip, 2);
00188 InternalSocket *p = GetSock(id);
00189 if (p)
00190 p -> SendBuf(m_packet, m_length);
00191 }
00192
00193
00194 InternalSocket *SlaveSocket::GetSock(unsigned short id)
00195 {
00196 return static_cast<SlaveHandler&>(Handler()).GetSock(id);
00197 }
00198
00199
00200 port_t SlaveSocket::GetHostPort(const std::string& header,std::string& host)
00201 {
00202 size_t i = 0;
00203 port_t port = 0;
00204 while (i < header.size())
00205 {
00206 size_t x = i;
00207 while (header[i] != 13 && header[i] != 10 && i < header.size())
00208 {
00209 i++;
00210 }
00211 std::string line = header.substr(x, i - x);
00212 while (i < header.size() && (header[i] == 13 || header[i] == 10))
00213 {
00214 i++;
00215 }
00216 Parse pa(line, ":");
00217 std::string key = pa.getword();
00218 if (!strcasecmp(key.c_str(), "host"))
00219 {
00220 Parse pa2(pa.getrest(), ":");
00221 host = pa2.getword();
00222 port = pa2.getvalue();
00223 if (!port)
00224 port = 80;
00225 return port;
00226 }
00227 }
00228 return 0;
00229 }
00230
00231