socket.cc

Go to the documentation of this file.
00001 /*
00002 ** socket.cc
00003 ** Login : Julien Lemoine <speedblue@happycoders.org>
00004 ** Started on  Sat Mar  1 23:01:09 2003 Julien Lemoine
00005 ** $Id: socket.cc,v 1.17 2006/12/28 17:49:58 speedblue Exp $
00006 **
00007 ** Copyright (C) 2003,2004 Julien Lemoine
00008 ** This program is free software; you can redistribute it and/or modify
00009 ** it under the terms of the GNU Lesser General Public License as published by
00010 ** the Free Software Foundation; either version 2 of the License, or
00011 ** (at your option) any later version.
00012 **
00013 ** This program is distributed in the hope that it will be useful,
00014 ** but WITHOUT ANY WARRANTY; without even the implied warranty of
00015 ** MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
00016 ** GNU Lesser General Public License for more details.
00017 **
00018 ** You should have received a copy of the GNU Lesser General Public License
00019 ** along with this program; if not, write to the Free Software
00020 ** Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
00021 */
00022 
00023 #include <iostream>
00024 #include <fstream>
00025 #include <sys/types.h>
00026 #include "socket.hh"
00027 
00028 namespace Network
00029 {
00030 
00031   Socket::Socket(SOCKET_KIND kind, SOCKET_VERSION version) :
00032     _kind(kind), _version(version), _state_timeout(0),
00033     _socket(0), _recv_flags(kind), _proto_kind(text), _empty_lines(false),
00034     _buffer(""), _tls(false)
00035   {
00036     _delim.push_back("\0");
00037 #ifdef LIBSOCKET_WIN
00038     WSADATA wsadata;
00039     if (WSAStartup(MAKEWORD(1, 1), &wsadata) != 0)
00040       throw WSAStartupError("WSAStartup failed", HERE);
00041 #endif
00042 #ifndef IPV6_ENABLED
00043     if (version == V6)
00044       throw Ipv6SupportError("lib was not compiled with ipv6 support", HERE);
00045 #endif
00046   }
00047 
00048   Socket::Socket(SOCKET_KIND kind, PROTO_KIND pkind, SOCKET_VERSION version) :
00049     _kind(kind), _version(version), _state_timeout(0),
00050     _socket(0), _recv_flags(kind), _proto_kind(pkind), _empty_lines(false),
00051     _buffer(""), _tls(false)
00052   {
00053     _delim.push_back("\0");
00054 #ifdef LIBSOCKET_WIN
00055     WSADATA wsadata;
00056     if (WSAStartup(MAKEWORD(1, 1), &wsadata) != 0)
00057       throw WSAStartupError("WSAStartup failed", HERE);
00058 #endif
00059 #ifndef IPV6_ENABLED
00060     if (version == V6)
00061       throw Ipv6SupportError("lib was not compiled with ipv6 support", HERE);
00062 #endif
00063   }
00064 
00065   Socket::~Socket()
00066   {
00067   }
00068 
00069   void  Socket::enable_tls()
00070   {
00071 #ifdef TLS
00072     int         ret;
00073 
00074     if (_kind != TCP)
00075       throw TLSError("You need to have a TCP connection", HERE);
00076     if (!connected())
00077       throw NoConnection("You need to have a connection", HERE);
00078     
00079     gnutls_transport_set_ptr(_session, (gnutls_transport_ptr)_socket);
00080     ret = gnutls_handshake(_session);
00081     if (ret < 0)
00082       {
00083         close(_socket);
00084         gnutls_deinit(_session);
00085         throw TLSError(gnutls_strerror(ret), HERE);
00086       }
00087 #else
00088     throw TLSSupportError("lib was not compiled with TLS support", HERE);
00089 #endif
00090   }
00091 
00092   void  Socket::init_tls(GnuTLSKind kind,
00093                          unsigned size, const std::string &certfile,
00094                          const std::string &keyfile,
00095                          const std::string &trustfile,
00096                          const std::string &crlfile)
00097   {
00098 #ifdef TLS
00099     static bool                                 init = false;
00100     static gnutls_dh_params                     dh_params;
00101     const int protocol_tls[] = { GNUTLS_TLS1, 0 };
00102     const int protocol_ssl[] = { GNUTLS_SSL3, 0 };
00103     const int cert_type_priority[] = { GNUTLS_CRT_X509, 
00104                                        GNUTLS_CRT_OPENPGP, 0 };
00105 
00106     if (!init)
00107       {
00108         gnutls_global_init();
00109         init = true;
00110       }
00111     _tls = true;
00112     _tls_main = true;
00113     gnutls_certificate_allocate_credentials(&_x509_cred);
00114     if (keyfile.size() > 0 && certfile.size() > 0)
00115       {
00116         std::ifstream key(keyfile.c_str()), cert(certfile.c_str());
00117         if (!key.is_open() || !cert.is_open())
00118           throw InvalidFile("key or cert invalid", HERE);
00119         key.close();
00120         cert.close();
00121         // Only for server...
00122         _nbbits = size;
00123         if (trustfile.size() > 0)
00124           gnutls_certificate_set_x509_trust_file(_x509_cred, trustfile.c_str(), 
00125                                                  GNUTLS_X509_FMT_PEM);
00126         if (crlfile.size() > 0)
00127           gnutls_certificate_set_x509_crl_file(_x509_cred, crlfile.c_str(), 
00128                                                GNUTLS_X509_FMT_PEM);
00129         gnutls_certificate_set_x509_key_file(_x509_cred, certfile.c_str(), 
00130                                              keyfile.c_str(), 
00131                                              GNUTLS_X509_FMT_PEM);
00132         gnutls_dh_params_init(&dh_params);
00133         gnutls_dh_params_generate2(dh_params, _nbbits);
00134         gnutls_certificate_set_dh_params(_x509_cred, dh_params);
00135 
00136         if (gnutls_init(&_session, GNUTLS_SERVER))
00137           throw TLSError("gnutls_init failed", HERE);
00138       }
00139     else
00140       {
00141         if (gnutls_init(&_session, GNUTLS_CLIENT))
00142           throw TLSError("gnutls_init failed", HERE);
00143       }
00144     
00145     gnutls_set_default_priority(_session);
00146     if (kind == TLS)
00147       gnutls_protocol_set_priority(_session, protocol_tls);
00148     else
00149       gnutls_protocol_set_priority(_session, protocol_ssl);
00150 
00151     if (keyfile.size() > 0 && certfile.size() > 0)
00152       {
00153         gnutls_credentials_set(_session, GNUTLS_CRD_CERTIFICATE, _x509_cred);
00154         gnutls_certificate_server_set_request(_session, GNUTLS_CERT_REQUEST);
00155         gnutls_dh_set_prime_bits(_session, _nbbits);
00156       }
00157     else
00158       {
00159         gnutls_certificate_type_set_priority(_session, cert_type_priority);
00160         gnutls_credentials_set(_session, GNUTLS_CRD_CERTIFICATE, _x509_cred);
00161       }
00162 #else
00163     throw TLSSupportError("lib was not compiled with TLS support", HERE);
00164 #endif
00165   }
00166 
00167   void  Socket::_close(int socket) const
00168   {
00169 #ifndef LIBSOCKET_WIN
00170     if (socket < 0 || close(socket) < 0)
00171       throw CloseError("Close Error", HERE);
00172     socket = 0;
00173 #else
00174     if (socket < 0 || closesocket(socket) < 0)
00175       throw CloseError("Close Error", HERE);
00176     socket = 0;
00177 #endif
00178 #ifdef TLS
00179     if (_tls)
00180       {
00181         std::cout << "Deletion..." << std::endl;
00182         gnutls_deinit(_session);
00183         if (_tls_main)
00184           {
00185             gnutls_certificate_free_credentials(_x509_cred);
00186             gnutls_global_deinit();
00187           }
00188       }
00189 #endif
00190   }
00191 
00192   void  Socket::_listen(int socket) const
00193   {
00194     if (socket < 0 || listen(socket, 5) < 0)
00195       throw ListenError("Listen Error", HERE);
00196   }
00197 
00198   void  Socket::_write_str(int socket, const std::string& str) const
00199   {
00200     int                         res = 1;
00201     unsigned int                count = 0;
00202     const char                  *buf;
00203 
00204     buf = str.c_str();
00205     if (socket < 0)
00206       throw NoConnection("No Socket", HERE);
00207     while (res && count < str.size())
00208       {
00209 #ifdef IPV6_ENABLED
00210         if (V4 == _version)
00211 #endif
00212 #ifdef TLS
00213           if (_tls)
00214             res = gnutls_record_send(_session, buf + count, str.size() - count);
00215           else
00216 #endif
00217             res = sendto(socket, buf + count, str.size() - count, SENDTO_FLAGS,
00218                          (const struct sockaddr*)&_addr, sizeof(_addr));
00219 #ifdef IPV6_ENABLED
00220         else
00221           res = sendto(socket, buf + count, str.size() - count, SENDTO_FLAGS,
00222                        (const struct sockaddr*)&_addr6, sizeof(_addr6));
00223 #endif
00224         if (res <= 0)
00225           throw ConnectionClosed("Connection Closed", HERE);
00226         count += res;
00227       }
00228   }
00229 
00230   void  Socket::_write_str_bin(int socket, const std::string& str) const
00231   {
00232     int                         res = 1;
00233     unsigned int                count = 0;
00234     char*                       buf = new char[str.size() + 2];
00235 
00236     buf[0] = str.size() / 256;
00237     buf[1] = str.size() % 256;
00238     memcpy(buf + 2, str.c_str(), str.size());
00239     if (socket < 0)
00240       {
00241         delete[] buf;
00242         throw NoConnection("No Socket", HERE);
00243       }
00244     while (res && count < str.size() + 2)
00245       {
00246 #ifdef IPV6_ENABLED
00247         if (V4 == _version)
00248 #endif
00249 #ifdef TLS
00250           if (_tls)
00251             res = gnutls_record_send(_session, buf + count, str.size() + 2 - count);
00252           else
00253 #endif
00254             res = sendto(socket, buf + count, str.size() + 2 - count, SENDTO_FLAGS,
00255                          (const struct sockaddr*)&_addr, sizeof(_addr));
00256 #ifdef IPV6_ENABLED
00257         else
00258           res = sendto(socket, buf + count, str.size() + 2 - count, SENDTO_FLAGS,
00259                        (const struct sockaddr*)&_addr6, sizeof(_addr6));
00260 #endif
00261         if (res <= 0)
00262           {
00263             delete[] buf;
00264             throw ConnectionClosed("Connection Closed", HERE);
00265           }
00266         count += res;
00267       }
00268     delete[] buf;
00269   }
00270 
00271   void  Socket::_set_timeout(bool enable, int socket, int timeout)
00272   {
00273     fd_set              fdset;
00274     struct timeval      timetowait;
00275     int         res;
00276 
00277     if (enable)
00278       timetowait.tv_sec = timeout;
00279     else
00280       timetowait.tv_sec = 65535;
00281     timetowait.tv_usec = 0;
00282     FD_ZERO(&fdset);
00283     FD_SET(socket, &fdset);
00284     if (enable)
00285       res = select(socket + 1, &fdset, NULL, NULL, &timetowait);
00286     else
00287       res = select(socket + 1, &fdset, NULL, NULL, NULL);
00288     if (res < 0)
00289       throw SelectError("Select error", HERE);
00290     if (res == 0)
00291       throw Timeout("Timeout on socket", HERE);
00292   }
00293 
00294   void  Socket::write(const std::string& str)
00295   {
00296     if (_proto_kind == binary)
00297       _write_str_bin(_socket, str);
00298     else
00299       _write_str(_socket, str);
00300   }
00301 
00302   bool  Socket::connected() const
00303   {
00304     return _socket != 0;
00305   }
00306 
00307   void  Socket::allow_empty_lines()
00308   {
00309     _empty_lines = true;
00310   }
00311 
00312   int   Socket::get_socket()
00313   {
00314     return _socket;
00315   }
00316 
00317   void  Socket::add_delim(const std::string& delim)
00318   {
00319     _delim.push_back(delim);
00320   }
00321 
00322   void  Socket::del_delim(const std::string& delim)
00323   {
00324     std::list<std::string>::iterator    it, it2;
00325 
00326     for (it = _delim.begin(); it != _delim.end(); )
00327       {
00328         if (*it == delim)
00329           {
00330             it2 = it++;
00331             _delim.erase(it2);
00332           }
00333         else
00334           it++;
00335       }
00336   }
00337 
00338   std::pair<int, int>   Socket::_find_delim(const std::string& str, int start) const
00339   {
00340     int                                         i = -1;
00341     int                                         pos = -1, size = 0;
00342     std::list<std::string>::const_iterator      it;
00343 
00344     // Looking for the first delimiter.
00345     if (_delim.size() > 0)
00346       {
00347         it = _delim.begin();
00348         while (it != _delim.end())
00349           {
00350             if (*it == "")
00351               i = str.find('\0', start);
00352             else
00353               i = str.find(*it, start);
00354             if ((i >= 0) && ((unsigned int)i < str.size()) &&
00355                 (pos < 0 || i < pos))
00356               {
00357                 pos = i;
00358                 size = it->size() ? it->size() : 1;
00359               }
00360             it++;
00361           }
00362       }
00363     return std::pair<int, int>(pos, size);
00364   }
00365 
00366   Socket&       operator<<(Socket& s, const std::string& str)
00367   {
00368     s.write(str);
00369     return s;
00370   }
00371 
00372   Socket&       operator>>(Socket& s, std::string& str)
00373   {
00374     str = s.read();
00375     return s;
00376   }
00377 }

Generated on Thu Dec 28 19:14:02 2006 for libsocket by  doxygen 1.4.7