Commit 3243c29a authored by Vysheng's avatar Vysheng

Big commit

parent f773c0cf
CC=cc CC=cc
CFLAGS=-c -Wall -Wextra -fPIC CFLAGS=-c -Wall -Wextra -Werror -fPIC -ggdb -O2 -fno-omit-frame-pointer -fno-strict-aliasing -rdynamic
LDFLAGS=-lreadline LDFLAGS=-lreadline -lssl -lcrypto -lrt -lz -ggdb -rdynamic
LD=cc LD=cc
SRC=main.c loop.c interface.c SRC=main.c loop.c interface.c net.c mtproto-common.c mtproto-client.c queries.c structures.c
OBJ=$(SRC:.c=.o) OBJ=$(SRC:.c=.o)
EXE=telegram EXE=telegram
......
...@@ -4,10 +4,11 @@ ...@@ -4,10 +4,11 @@
#include <assert.h> #include <assert.h>
#include <stdio.h> #include <stdio.h>
#include <stdarg.h>
#include <stdlib.h> #include <stdlib.h>
#include <string.h> #include <string.h>
#include "include.h" #include "include.h"
#include "queries.h"
char *default_prompt = ">"; char *default_prompt = ">";
char *get_default_prompt (void) { char *get_default_prompt (void) {
...@@ -21,11 +22,13 @@ char *complete_none (const char *text UU, int state UU) { ...@@ -21,11 +22,13 @@ char *complete_none (const char *text UU, int state UU) {
char *commands[] = { char *commands[] = {
"help", "help",
"msg", "msg",
"contact_list",
0 }; 0 };
int commands_flags[] = { int commands_flags[] = {
070, 070,
072, 072,
00,
}; };
char *a = 0; char *a = 0;
...@@ -151,5 +154,27 @@ char **complete_text (char *text, int start UU, int end UU) { ...@@ -151,5 +154,27 @@ char **complete_text (char *text, int start UU, int end UU) {
} }
void interpreter (char *line UU) { void interpreter (char *line UU) {
assert (0); if (!memcmp (line, "contact_list", 12)) {
do_update_contact_list ();
}
}
void rprintf (const char *format, ...) {
int saved_point = rl_point;
char *saved_line = rl_copy_text(0, rl_end);
rl_save_prompt();
rl_replace_line("", 0);
rl_redisplay();
va_list ap;
va_start (ap, format);
vfprintf (stdout, format, ap);
va_end (ap);
rl_restore_prompt();
rl_replace_line(saved_line, 0);
rl_point = saved_point;
rl_redisplay();
free(saved_line);
} }
...@@ -4,4 +4,6 @@ char *get_default_prompt (void); ...@@ -4,4 +4,6 @@ char *get_default_prompt (void);
char *complete_none (const char *text, int state); char *complete_none (const char *text, int state);
char **complete_text (char *text, int start, int end); char **complete_text (char *text, int start, int end);
void interpreter (char *line); void interpreter (char *line);
void rprintf (const char *format, ...) __attribute__ ((format (printf, 1, 2)));
#endif #endif
...@@ -9,62 +9,212 @@ ...@@ -9,62 +9,212 @@
#include <readline/history.h> #include <readline/history.h>
#include <errno.h> #include <errno.h>
#include <poll.h>
#include <unistd.h>
#include <sys/types.h>
#include <sys/stat.h>
#include <fcntl.h>
#include "interface.h" #include "interface.h"
#include "net.h"
#include "mtproto-client.h"
#include "mtproto-common.h"
#include "queries.h"
#include "telegram.h"
extern char *default_username; extern char *default_username;
extern char *auth_token; extern char *auth_token;
void set_default_username (const char *s); void set_default_username (const char *s);
int default_dc_num;
void net_loop (int flags, int (*is_end)(void)) {
while (!is_end ()) {
struct pollfd fds[101];
int cc = 0;
if (flags & 1) {
fds[0].fd = 0;
fds[0].events = POLLIN;
cc ++;
}
int main_loop (void) { int x = connections_make_poll_array (fds + cc, 101 - cc) + cc;
fd_set inp, outp; double timer = next_timer_in ();
struct timeval tv; if (timer > 1000) { timer = 1000; }
while (1) { if (poll (fds, x, timer) < 0) {
FD_ZERO (&inp); /* resuming from interrupt, so not an error situation,
FD_ZERO (&outp); this generally happens when you suspend your
FD_SET (0, &inp); messenger with "C-z" and then "fg". This is allowed "
tv.tv_sec = 1; */
tv.tv_usec = 0; if (flags & 1) {
int lfd = 0;
if (select (lfd + 1, &inp, &outp, NULL, &tv) < 0) {
if (errno == EINTR) {
/* resuming from interrupt, so not an error situation,
this generally happens when you suspend your
messenger with "C-z" and then "fg". This is allowed "
*/
rl_reset_line_state (); rl_reset_line_state ();
rl_forced_update_display (); rl_forced_update_display ();
continue;
} }
perror ("select()"); work_timers ();
break; continue;
} }
work_timers ();
if (FD_ISSET (0, &inp)) { if ((flags & 1) && (fds[0].revents & POLLIN)) {
rl_callback_read_char (); rl_callback_read_char ();
} }
connections_poll_result (fds + cc, x - cc);
} }
}
int ret1 (void) { return 0; }
int main_loop (void) {
net_loop (1, ret1);
return 0; return 0;
} }
struct dc *DC_list[MAX_DC_ID + 1];
struct dc *DC_working;
int dc_working_num;
int auth_state;
char *get_auth_key_filename (void);
int zero[512];
void write_dc (int auth_file_fd, struct dc *DC) {
assert (write (auth_file_fd, &DC->port, 4) == 4);
int l = strlen (DC->ip);
assert (write (auth_file_fd, &l, 4) == 4);
assert (write (auth_file_fd, DC->ip, l) == l);
if (DC->flags & 1) {
assert (write (auth_file_fd, &DC->auth_key_id, 8) == 8);
assert (write (auth_file_fd, DC->auth_key, 256) == 256);
} else {
assert (write (auth_file_fd, zero, 256 + 8) == 256 + 8);
}
assert (write (auth_file_fd, &DC->server_salt, 8) == 8);
}
void write_auth_file (void) {
int auth_file_fd = open (get_auth_key_filename (), O_CREAT | O_RDWR, S_IRWXU);
assert (auth_file_fd >= 0);
int x = DC_SERIALIZED_MAGIC;
assert (write (auth_file_fd, &x, 4) == 4);
x = MAX_DC_ID;
assert (write (auth_file_fd, &x, 4) == 4);
assert (write (auth_file_fd, &dc_working_num, 4) == 4);
assert (write (auth_file_fd, &auth_state, 4) == 4);
int i;
for (i = 0; i <= MAX_DC_ID; i++) {
if (DC_list[i]) {
x = 1;
assert (write (auth_file_fd, &x, 4) == 4);
write_dc (auth_file_fd, DC_list[i]);
} else {
x = 0;
assert (write (auth_file_fd, &x, 4) == 4);
}
}
close (auth_file_fd);
}
void read_dc (int auth_file_fd, int id) {
int port = 0;
assert (read (auth_file_fd, &port, 4) == 4);
int l = 0;
assert (read (auth_file_fd, &l, 4) == 4);
assert (l >= 0);
char *ip = malloc (l + 1);
assert (read (auth_file_fd, ip, l) == l);
ip[l] = 0;
struct dc *DC = alloc_dc (id, ip, port);
assert (read (auth_file_fd, &DC->auth_key_id, 8) == 8);
assert (read (auth_file_fd, &DC->auth_key, 256) == 256);
assert (read (auth_file_fd, &DC->server_salt, 8) == 8);
if (DC->auth_key_id) {
DC->flags |= 1;
}
}
void empty_auth_file (void) {
struct dc *DC = alloc_dc (1, strdup (TG_SERVER), 443);
assert (DC);
dc_working_num = 1;
write_auth_file ();
}
void read_auth_file (void) {
int auth_file_fd = open (get_auth_key_filename (), O_CREAT | O_RDWR, S_IRWXU);
if (auth_file_fd < 0) {
empty_auth_file ();
}
assert (auth_file_fd >= 0);
int x;
if (read (auth_file_fd, &x, 4) < 4 || x != DC_SERIALIZED_MAGIC) {
close (auth_file_fd);
empty_auth_file ();
return;
}
assert (read (auth_file_fd, &x, 4) == 4);
assert (x >= 0 && x <= MAX_DC_ID);
assert (read (auth_file_fd, &dc_working_num, 4) == 4);
assert (read (auth_file_fd, &auth_state, 4) == 4);
int i;
for (i = 0; i <= x; i++) {
int y;
assert (read (auth_file_fd, &y, 4) == 4);
if (y) {
read_dc (auth_file_fd, i);
}
}
close (auth_file_fd);
}
int loop (void) { int loop (void) {
size_t size = 0; on_start ();
char *user = default_username; read_auth_file ();
assert (DC_list[dc_working_num]);
if (!user && !auth_token) { DC_working = DC_list[dc_working_num];
printf ("Telephone number (with '+' sign): "); if (!DC_working->auth_key_id) {
if (getline (&user, &size, stdin) == -1) { dc_authorize (DC_working);
perror ("getline()"); } else {
exit (EXIT_FAILURE); dc_create_session (DC_working);
}
if (!auth_state) {
if (!default_username) {
size_t size = 0;
char *user = 0;
if (!user && !auth_token) {
printf ("Telephone number (with '+' sign): ");
if (getline (&user, &size, stdin) == -1) {
perror ("getline()");
exit (EXIT_FAILURE);
}
user[strlen (user) - 1] = 0;
set_default_username (user);
}
} }
user[strlen (user) - 1] = '\0'; do_send_code (default_username);
set_default_username (user); char *code = 0;
size_t size = 0;
printf ("Code from sms: ");
while (1) {
if (getline (&code, &size, stdin) == -1) {
perror ("getline()");
exit (EXIT_FAILURE);
}
code[strlen (code) - 1] = 0;
if (do_send_code_result (code) >= 0) {
break;
}
printf ("Invalid code. Try again: ");
}
auth_state = 1;
} }
write_auth_file ();
fflush (stdin); fflush (stdin);
fflush (stdout);
fflush (stderr);
rl_callback_handler_install (get_default_prompt (), interpreter); rl_callback_handler_install (get_default_prompt (), interpreter);
rl_attempted_completion_function = (CPPFunction *) complete_text; rl_attempted_completion_function = (CPPFunction *) complete_text;
......
#ifndef __LOOP_H__ #ifndef __LOOP_H__
#define __LOOP_H__ #define __LOOP_H__
int loop (void); int loop (void);
void net_loop (int flags, int (*end)(void));
void write_auth_file (void);
#endif #endif
...@@ -27,14 +27,18 @@ ...@@ -27,14 +27,18 @@
#include <sys/stat.h> #include <sys/stat.h>
#include <time.h> #include <time.h>
#include <fcntl.h> #include <fcntl.h>
#include <execinfo.h>
#include <signal.h>
#include "loop.h" #include "loop.h"
#include "mtproto-client.h"
#define PROGNAME "telegram-client" #define PROGNAME "telegram-client"
#define VERSION "0.01" #define VERSION "0.01"
#define CONFIG_DIRECTORY ".telegram/" #define CONFIG_DIRECTORY ".telegram/"
#define CONFIG_FILE CONFIG_DIRECTORY "config" #define CONFIG_FILE CONFIG_DIRECTORY "config"
#define AUTH_KEY_FILE CONFIG_DIRECTORY "auth"
#define DOWNLOADS_DIRECTORY "downloads/" #define DOWNLOADS_DIRECTORY "downloads/"
#define CONFIG_DIRECTORY_MODE 0700 #define CONFIG_DIRECTORY_MODE 0700
...@@ -72,6 +76,13 @@ void get_terminal_attributes (void) { ...@@ -72,6 +76,13 @@ void get_terminal_attributes (void) {
old_lflag = term.c_lflag; old_lflag = term.c_lflag;
old_vtime = term.c_cc[VTIME]; old_vtime = term.c_cc[VTIME];
} }
void set_terminal_attributes (void) {
if (tcsetattr (STDIN_FILENO, 0, &term) < 0) {
perror ("tcsetattr()");
exit (EXIT_FAILURE);
}
}
/* }}} */ /* }}} */
char *get_home_directory (void) { char *get_home_directory (void) {
...@@ -107,6 +118,15 @@ char *get_config_filename (void) { ...@@ -107,6 +118,15 @@ char *get_config_filename (void) {
return config_filename; return config_filename;
} }
char *get_auth_key_filename (void) {
char *auth_key_filename;
int length = strlen (get_home_directory ()) + strlen (AUTH_KEY_FILE) + 2;
auth_key_filename = (char *) calloc (length, sizeof (char));
sprintf (auth_key_filename, "%s/" AUTH_KEY_FILE, get_home_directory ());
return auth_key_filename;
}
char *get_downloads_directory (void) char *get_downloads_directory (void)
{ {
char *downloads_directory; char *downloads_directory;
...@@ -149,6 +169,11 @@ void running_for_first_time (void) { ...@@ -149,6 +169,11 @@ void running_for_first_time (void) {
exit (EXIT_FAILURE); exit (EXIT_FAILURE);
} }
close (config_file_fd); close (config_file_fd);
int auth_file_fd = open (get_auth_key_filename (), O_CREAT | O_RDWR, S_IRWXU);
int x = -1;
assert (write (auth_file_fd, &x, 4) == 4);
close (auth_file_fd);
printf ("[%s] created\n", config_filename); printf ("[%s] created\n", config_filename);
/* create downloads directory */ /* create downloads directory */
...@@ -170,13 +195,26 @@ void usage (void) { ...@@ -170,13 +195,26 @@ void usage (void) {
exit (1); exit (1);
} }
extern char *rsa_public_key_name;
extern int verbosity;
extern int default_dc_num;
void args_parse (int argc, char **argv) { void args_parse (int argc, char **argv) {
int opt = 0; int opt = 0;
while ((opt = getopt (argc, argv, "u:h")) != -1) { while ((opt = getopt (argc, argv, "u:hk:vn:")) != -1) {
switch (opt) { switch (opt) {
case 'u': case 'u':
set_default_username (optarg); set_default_username (optarg);
break; break;
case 'k':
rsa_public_key_name = strdup (optarg);
break;
case 'v':
verbosity ++;
break;
case 'n':
default_dc_num = atoi (optarg);
break;
case 'h': case 'h':
default: default:
usage (); usage ();
...@@ -185,7 +223,23 @@ void args_parse (int argc, char **argv) { ...@@ -185,7 +223,23 @@ void args_parse (int argc, char **argv) {
} }
} }
void print_backtrace (void) {
void *buffer[255];
const int calls = backtrace (buffer, sizeof (buffer) / sizeof (void *));
backtrace_symbols_fd (buffer, calls, 1);
exit(EXIT_FAILURE);
}
void sig_handler (int signum) {
set_terminal_attributes ();
printf ("signal %d received\n", signum);
print_backtrace ();
}
int main (int argc, char **argv) { int main (int argc, char **argv) {
signal (SIGSEGV, sig_handler);
signal (SIGABRT, sig_handler);
running_for_first_time (); running_for_first_time ();
get_terminal_attributes (); get_terminal_attributes ();
......
#define _FILE_OFFSET_BITS 64
#include <assert.h>
#include <string.h>
#include <stdio.h>
#include <stdlib.h>
#include <signal.h>
#include <unistd.h>
#include <fcntl.h>
#include <sys/types.h>
#include <aio.h>
#include <netdb.h>
#include <openssl/rand.h>
#include <openssl/rsa.h>
#include <openssl/pem.h>
#include <openssl/sha.h>
#include <sys/socket.h>
#include <netinet/tcp.h>
#include <poll.h>
#include "net.h"
#include "include.h"
#include "queries.h"
#include "loop.h"
#define sha1 SHA1
#include "mtproto-common.h"
#define MAX_NET_RES (1L << 16)
int verbosity;
int auth_success;
enum dc_state c_state;
char nonce[256];
char new_nonce[256];
char server_nonce[256];
int rpc_execute (struct connection *c, int op, int len);
int rpc_becomes_ready (struct connection *c);
int rpc_close (struct connection *c);
struct connection_methods auth_methods = {
.execute = rpc_execute,
.ready = rpc_becomes_ready,
.close = rpc_close
};
long long precise_time;
long long precise_time_rdtsc;
double get_utime (int clock_id) {
struct timespec T;
#if _POSIX_TIMERS
assert (clock_gettime (clock_id, &T) >= 0);
double res = T.tv_sec + (double) T.tv_nsec * 1e-9;
#else
#error "No high-precision clock"
double res = time ();
#endif
if (clock_id == CLOCK_REALTIME) {
precise_time = (long long) (res * (1LL << 32));
precise_time_rdtsc = rdtsc ();
}
return res;
}
#define STATS_BUFF_SIZE (64 << 10)
int stats_buff_len;
char stats_buff[STATS_BUFF_SIZE];
#define MAX_RESPONSE_SIZE (1L << 24)
char Response[MAX_RESPONSE_SIZE];
int Response_len;
/*
*
* STATE MACHINE
*
*/
char *rsa_public_key_name = "id_rsa.pub";
RSA *pubKey;
long long pk_fingerprint;
static int rsa_load_public_key (const char *public_key_name) {
pubKey = NULL;
FILE *f = fopen (public_key_name, "r");
if (f == NULL) {
fprintf (stderr, "Couldn't open public key file: %s\n", public_key_name);
return -1;
}
pubKey = PEM_read_RSAPublicKey (f, NULL, NULL, NULL);
fclose (f);
if (pubKey == NULL) {
fprintf (stderr, "PEM_read_RSAPublicKey returns NULL.\n");
return -1;
}
return 0;
}
int auth_work_start (struct connection *c);
/*
*
* UNAUTHORIZED (DH KEY EXCHANGE) PROTOCOL PART
*
*/
BIGNUM dh_prime, dh_g, g_a, dh_power, auth_key_num;
char s_power [256];
struct {
long long auth_key_id;
long long out_msg_id;
int msg_len;
} unenc_msg_header;
#define ENCRYPT_BUFFER_INTS 16384
int encrypt_buffer[ENCRYPT_BUFFER_INTS];
#define DECRYPT_BUFFER_INTS 16384
int decrypt_buffer[ENCRYPT_BUFFER_INTS];
int encrypt_packet_buffer (void) {
return pad_rsa_encrypt ((char *) packet_buffer, (packet_ptr - packet_buffer) * 4, (char *) encrypt_buffer, ENCRYPT_BUFFER_INTS * 4, pubKey->n, pubKey->e);
}
int encrypt_packet_buffer_aes_unauth (const char server_nonce[16], const char hidden_client_nonce[32]) {
init_aes_unauth (server_nonce, hidden_client_nonce, AES_ENCRYPT);
return pad_aes_encrypt ((char *) packet_buffer, (packet_ptr - packet_buffer) * 4, (char *) encrypt_buffer, ENCRYPT_BUFFER_INTS * 4);
}
int rpc_send_packet (struct connection *c) {
int len = (packet_ptr - packet_buffer) * 4;
c->out_packet_num ++;
long long next_msg_id = (long long) ((1LL << 32) * get_utime (CLOCK_REALTIME)) & -4;
if (next_msg_id <= unenc_msg_header.out_msg_id) {
unenc_msg_header.out_msg_id += 4;
} else {
unenc_msg_header.out_msg_id = next_msg_id;
}
unenc_msg_header.msg_len = len;
int total_len = len + 20;
assert (total_len > 0 && !(total_len & 0xfc000003));
total_len >>= 2;
if (total_len < 0x7f) {
assert (write_out (c, &total_len, 1) == 1);
} else {
total_len = (total_len << 8) | 0x7f;
assert (write_out (c, &total_len, 4) == 4);
}
write_out (c, &unenc_msg_header, 20);
write_out (c, packet_buffer, len);
flush_out (c);
return 1;
}
int rpc_send_message (struct connection *c, void *data, int len) {
assert (len > 0 && !(len & 0xfc000003));
int total_len = len >> 2;
if (total_len < 0x7f) {
assert (write_out (c, &total_len, 1) == 1);
} else {
total_len = (total_len << 8) | 0x7f;
assert (write_out (c, &total_len, 4) == 4);
}
c->out_packet_num ++;
write_out (c, data, len);
flush_out (c);
return 1;
}
int send_req_pq_packet (struct connection *c) {
assert (c_state == st_init);
assert (RAND_pseudo_bytes ((unsigned char *) nonce, 16) >= 0);
unenc_msg_header.out_msg_id = 0;
clear_packet ();
out_int (CODE_req_pq);
out_ints ((int *)nonce, 4);
rpc_send_packet (c);
c_state = st_reqpq_sent;
return 1;
}
unsigned long long gcd (unsigned long long a, unsigned long long b) {
return b ? gcd (b, a % b) : a;
}
//typedef unsigned int uint128_t __attribute__ ((mode(TI)));
unsigned long long what;
unsigned p1, p2;
int process_respq_answer (struct connection *c, char *packet, int len) {
int i;
if (verbosity) {
fprintf (stderr, "process_respq_answer(), len=%d\n", len);
}
assert (len >= 76);
assert (!*(long long *) packet);
assert (*(int *) (packet + 16) == len - 20);
assert (!(len & 3));
assert (*(int *) (packet + 20) == CODE_resPQ);
assert (!memcmp (packet + 24, nonce, 16));
memcpy (server_nonce, packet + 40, 16);
char *from = packet + 56;
int clen = *from++;
assert (clen <= 8);
what = 0;
for (i = 0; i < clen; i++) {
what = (what << 8) + (unsigned char)*from++;
}
while (((unsigned long)from) & 3) ++from;
p1 = 0, p2 = 0;
if (verbosity >= 2) {
fprintf (stderr, "%lld received\n", what);
}
int it = 0;
unsigned long long g = 0;
for (i = 0; i < 3 || it < 1000; i++) {
int q = ((lrand48() & 15) + 17) % what;
unsigned long long x = (long long)lrand48 () % (what - 1) + 1, y = x;
int lim = 1 << (i + 18);
int j;
for (j = 1; j < lim; j++) {
++it;
unsigned long long a = x, b = x, c = q;
while (b) {
if (b & 1) {
c += a;
if (c >= what) {
c -= what;
}
}
a += a;
if (a >= what) {
a -= what;
}
b >>= 1;
}
x = c;
unsigned long long z = x < y ? what + x - y : x - y;
g = gcd (z, what);
if (g != 1) {
break;
}
if (!(j & (j - 1))) {
y = x;
}
}
if (g > 1 && g < what) break;
}
assert (g > 1 && g < what);
p1 = g;
p2 = what / g;
if (p1 > p2) {
unsigned t = p1; p1 = p2; p2 = t;
}
if (verbosity) {
fprintf (stderr, "p1 = %d, p2 = %d, %d iterations\n", p1, p2, it);
}
/// ++p1; ///
assert (*(int *) (from) == CODE_vector);
int fingerprints_num = *(int *)(from + 4);
assert (fingerprints_num >= 1 && fingerprints_num <= 64 && len == fingerprints_num * 8 + 8 + (from - packet));
long long *fingerprints = (long long *) (from + 8);
for (i = 0; i < fingerprints_num; i++) {
if (fingerprints[i] == pk_fingerprint) {
//fprintf (stderr, "found our public key at position %d\n", i);
break;
}
}
if (i == fingerprints_num) {
fprintf (stderr, "fatal: don't have any matching keys (%016llx expected)\n", pk_fingerprint);
exit (2);
}
// create inner part (P_Q_inner_data)
clear_packet ();
packet_ptr += 5;
out_int (CODE_p_q_inner_data);
out_cstring (packet + 57, clen);
//out_int (0x0f01); // pq=15
if (p1 < 256) {
clen = 1;
} else if (p1 < 65536) {
clen = 2;
} else if (p1 < 16777216) {
clen = 3;
} else {
clen = 4;
}
p1 = __builtin_bswap32 (p1);
out_cstring ((char *)&p1 + 4 - clen, clen);
p1 = __builtin_bswap32 (p1);
if (p2 < 256) {
clen = 1;
} else if (p2 < 65536) {
clen = 2;
} else if (p2 < 16777216) {
clen = 3;
} else {
clen = 4;
}
p2 = __builtin_bswap32 (p2);
out_cstring ((char *)&p2 + 4 - clen, clen);
p2 = __builtin_bswap32 (p2);
//out_int (0x0301); // p=3
//out_int (0x0501); // q=5
out_ints ((int *) nonce, 4);
out_ints ((int *) server_nonce, 4);
assert (RAND_pseudo_bytes ((unsigned char *) new_nonce, 32) >= 0);
out_ints ((int *) new_nonce, 8);
sha1 ((unsigned char *) (packet_buffer + 5), (packet_ptr - packet_buffer - 5) * 4, (unsigned char *) packet_buffer);
int l = encrypt_packet_buffer ();
clear_packet ();
out_int (CODE_req_DH_params);
out_ints ((int *) nonce, 4);
out_ints ((int *) server_nonce, 4);
//out_int (0x0301); // p=3
//out_int (0x0501); // q=5
if (p1 < 256) {
clen = 1;
} else if (p1 < 65536) {
clen = 2;
} else if (p1 < 16777216) {
clen = 3;
} else {
clen = 4;
}
p1 = __builtin_bswap32 (p1);
out_cstring ((char *)&p1 + 4 - clen, clen);
p1 = __builtin_bswap32 (p1);
if (p2 < 256) {
clen = 1;
} else if (p2 < 65536) {
clen = 2;
} else if (p2 < 16777216) {
clen = 3;
} else {
clen = 4;
}
p2 = __builtin_bswap32 (p2);
out_cstring ((char *)&p2 + 4 - clen, clen);
p2 = __builtin_bswap32 (p2);
out_long (pk_fingerprint);
out_cstring ((char *) encrypt_buffer, l);
c_state = st_reqdh_sent;
return rpc_send_packet (c);
}
int process_dh_answer (struct connection *c, char *packet, int len) {
if (verbosity) {
fprintf (stderr, "process_dh_answer(), len=%d\n", len);
}
if (len < 116) {
fprintf (stderr, "%u * %u = %llu", p1, p2, what);
}
assert (len >= 116);
assert (!*(long long *) packet);
assert (*(int *) (packet + 16) == len - 20);
assert (!(len & 3));
assert (*(int *) (packet + 20) == (int)CODE_server_DH_params_ok);
assert (!memcmp (packet + 24, nonce, 16));
assert (!memcmp (packet + 40, server_nonce, 16));
init_aes_unauth (server_nonce, new_nonce, AES_DECRYPT);
in_ptr = (int *)(packet + 56);
in_end = (int *)(packet + len);
int l = prefetch_strlen ();
assert (l > 0);
l = pad_aes_decrypt (fetch_str (l), l, (char *) decrypt_buffer, DECRYPT_BUFFER_INTS * 4 - 16);
assert (in_ptr == in_end);
assert (l >= 60);
assert (decrypt_buffer[5] == (int)CODE_server_DH_inner_data);
assert (!memcmp (decrypt_buffer + 6, nonce, 16));
assert (!memcmp (decrypt_buffer + 10, server_nonce, 16));
assert (decrypt_buffer[14] == 2);
in_ptr = decrypt_buffer + 15;
in_end = decrypt_buffer + (l >> 2);
BN_init (&dh_prime);
BN_init (&g_a);
assert (fetch_bignum (&dh_prime) > 0);
assert (fetch_bignum (&g_a) > 0);
int server_time = *in_ptr++;
assert (in_ptr <= in_end);
static char sha1_buffer[20];
sha1 ((unsigned char *) decrypt_buffer + 20, (in_ptr - decrypt_buffer - 5) * 4, (unsigned char *) sha1_buffer);
assert (!memcmp (decrypt_buffer, sha1_buffer, 20));
assert ((char *) in_end - (char *) in_ptr < 16);
GET_DC(c)->server_time_delta = server_time - time (0);
GET_DC(c)->server_time_udelta = server_time - get_utime (CLOCK_MONOTONIC);
//fprintf (stderr, "server time is %d, delta = %d\n", server_time, server_time_delta);
// Build set_client_DH_params answer
clear_packet ();
packet_ptr += 5;
out_int (CODE_client_DH_inner_data);
out_ints ((int *) nonce, 4);
out_ints ((int *) server_nonce, 4);
out_long (0LL);
BN_init (&dh_g);
BN_set_word (&dh_g, 2);
assert (RAND_pseudo_bytes ((unsigned char *)s_power, 256) >= 0);
BIGNUM *dh_power = BN_new ();
assert (BN_bin2bn ((unsigned char *)s_power, 256, dh_power) == dh_power);
BIGNUM *y = BN_new ();
assert (BN_mod_exp (y, &dh_g, dh_power, &dh_prime, BN_ctx) == 1);
out_bignum (y);
BN_free (y);
BN_init (&auth_key_num);
assert (BN_mod_exp (&auth_key_num, &g_a, dh_power, &dh_prime, BN_ctx) == 1);
l = BN_num_bytes (&auth_key_num);
assert (l >= 250 && l <= 256);
assert (BN_bn2bin (&auth_key_num, (unsigned char *)GET_DC(c)->auth_key));
memset (GET_DC(c)->auth_key + l, 0, 256 - l);
BN_free (dh_power);
BN_free (&auth_key_num);
BN_free (&dh_g);
BN_free (&g_a);
BN_free (&dh_prime);
//hexdump (auth_key, auth_key + 256);
sha1 ((unsigned char *) (packet_buffer + 5), (packet_ptr - packet_buffer - 5) * 4, (unsigned char *) packet_buffer);
//hexdump ((char *)packet_buffer, (char *)packet_ptr);
l = encrypt_packet_buffer_aes_unauth (server_nonce, new_nonce);
clear_packet ();
out_int (CODE_set_client_DH_params);
out_ints ((int *) nonce, 4);
out_ints ((int *) server_nonce, 4);
out_cstring ((char *) encrypt_buffer, l);
c_state = st_client_dh_sent;
return rpc_send_packet (c);
}
int process_auth_complete (struct connection *c UU, char *packet, int len) {
if (verbosity) {
fprintf (stderr, "process_dh_answer(), len=%d\n", len);
}
assert (len == 72);
assert (!*(long long *) packet);
assert (*(int *) (packet + 16) == len - 20);
assert (!(len & 3));
assert (*(int *) (packet + 20) == CODE_dh_gen_ok);
assert (!memcmp (packet + 24, nonce, 16));
assert (!memcmp (packet + 40, server_nonce, 16));
static unsigned char tmp[44], sha1_buffer[20];
memcpy (tmp, new_nonce, 32);
tmp[32] = 1;
sha1 ((unsigned char *)GET_DC(c)->auth_key, 256, sha1_buffer);
GET_DC(c)->auth_key_id = *(long long *)(sha1_buffer + 12);
memcpy (tmp + 33, sha1_buffer, 8);
sha1 (tmp, 41, sha1_buffer);
assert (!memcmp (packet + 56, sha1_buffer + 4, 16));
GET_DC(c)->server_salt = *(long long *)server_nonce ^ *(long long *)new_nonce;
if (verbosity >= 3) {
fprintf (stderr, "auth_key_id=%016llx\n", GET_DC(c)->auth_key_id);
}
//kprintf ("OK\n");
//c->status = conn_error;
//sleep (1);
c_state = st_authorized;
//return 1;
if (verbosity) {
fprintf (stderr, "Auth success\n");
}
auth_success ++;
GET_DC(c)->flags |= 1;
write_auth_file ();
return 1;
}
/*
*
* AUTHORIZED (MAIN) PROTOCOL PART
*
*/
struct encrypted_message enc_msg;
long long client_last_msg_id, server_last_msg_id;
double get_server_time (struct dc *DC) {
if (!DC->server_time_udelta) {
DC->server_time_udelta = get_utime (CLOCK_REALTIME) - get_utime (CLOCK_MONOTONIC);
}
return get_utime (CLOCK_MONOTONIC) + DC->server_time_udelta;
}
long long generate_next_msg_id (struct dc *DC) {
long long next_id = (long long) (get_server_time (DC) * (1LL << 32)) & -4;
if (next_id <= client_last_msg_id) {
next_id = client_last_msg_id += 4;
} else {
client_last_msg_id = next_id;
}
return next_id;
}
void init_enc_msg (struct session *S, int useful) {
struct dc *DC = S->dc;
assert (DC->auth_key_id);
enc_msg.auth_key_id = DC->auth_key_id;
assert (DC->server_salt);
enc_msg.server_salt = DC->server_salt;
if (!S->session_id) {
assert (RAND_pseudo_bytes ((unsigned char *) &S->session_id, 8) >= 0);
}
enc_msg.session_id = S->session_id;
//enc_msg.auth_key_id2 = auth_key_id;
enc_msg.msg_id = generate_next_msg_id (DC);
//enc_msg.msg_id -= 0x10000000LL * (lrand48 () & 15);
//kprintf ("message id %016llx\n", enc_msg.msg_id);
enc_msg.seq_no = S->seq_no;
if (useful) {
enc_msg.seq_no |= 1;
}
S->seq_no += 2;
};
int aes_encrypt_message (struct dc *DC, struct encrypted_message *enc) {
unsigned char sha1_buffer[20];
const int MINSZ = offsetof (struct encrypted_message, message);
const int UNENCSZ = offsetof (struct encrypted_message, server_salt);
int enc_len = (MINSZ - UNENCSZ) + enc->msg_len;
assert (enc->msg_len >= 0 && enc->msg_len <= MAX_MESSAGE_INTS * 4 - 16 && !(enc->msg_len & 3));
sha1 ((unsigned char *) &enc->server_salt, enc_len, sha1_buffer);
//printf ("enc_len is %d\n", enc_len);
if (verbosity >= 2) {
fprintf (stderr, "sending message with sha1 %08x\n", *(int *)sha1_buffer);
}
memcpy (enc->msg_key, sha1_buffer + 4, 16);
init_aes_auth (DC->auth_key, enc->msg_key, AES_ENCRYPT);
//hexdump ((char *)enc, (char *)enc + enc_len + 24);
return pad_aes_encrypt ((char *) &enc->server_salt, enc_len, (char *) &enc->server_salt, MAX_MESSAGE_INTS * 4 + (MINSZ - UNENCSZ));
}
long long encrypt_send_message (struct connection *c, int *msg, int msg_ints, int useful) {
struct dc *DC = GET_DC(c);
struct session *S = c->session;
assert (S);
const int UNENCSZ = offsetof (struct encrypted_message, server_salt);
if (msg_ints <= 0 || msg_ints > MAX_MESSAGE_INTS - 4) {
return -1;
}
if (msg) {
memcpy (enc_msg.message, msg, msg_ints * 4);
enc_msg.msg_len = msg_ints * 4;
} else {
if ((enc_msg.msg_len & 0x80000003) || enc_msg.msg_len > MAX_MESSAGE_INTS * 4 - 16) {
return -1;
}
}
init_enc_msg (S, useful);
//hexdump ((char *)msg, (char *)msg + (msg_ints * 4));
int l = aes_encrypt_message (DC, &enc_msg);
//hexdump ((char *)&enc_msg, (char *)&enc_msg + l + 24);
assert (l > 0);
rpc_send_message (c, &enc_msg, l + UNENCSZ);
return client_last_msg_id;
}
int longpoll_count, good_messages;
int auth_work_start (struct connection *c UU) {
return 1;
}
void rpc_execute_answer (struct connection *c, long long msg_id UU);
void work_container (struct connection *c, long long msg_id UU) {
if (verbosity) {
fprintf (stderr, "work_container: msg_id = %lld\n", msg_id);
}
assert (fetch_int () == CODE_msg_container);
int n = fetch_int ();
int i;
for (i = 0; i < n; i++) {
long long id = fetch_long ();
int seqno = fetch_int ();
if (seqno & 1) {
insert_seqno (c->session, seqno);
}
int bytes = fetch_int ();
int *t = in_ptr;
rpc_execute_answer (c, id);
assert (in_ptr == t + (bytes / 4));
}
}
void work_new_session_created (struct connection *c, long long msg_id UU) {
if (verbosity) {
fprintf (stderr, "work_new_session_created: msg_id = %lld\n", msg_id);
}
assert (fetch_int () == (int)CODE_new_session_created);
fetch_long (); // first message id
//DC->session_id = fetch_long ();
fetch_long (); // unique_id
GET_DC(c)->server_salt = fetch_long ();
}
void work_msgs_ack (struct connection *c UU, long long msg_id UU) {
if (verbosity) {
fprintf (stderr, "work_msgs_ack: msg_id = %lld\n", msg_id);
}
assert (fetch_int () == CODE_msgs_ack);
assert (fetch_int () == CODE_vector);
int n = fetch_int ();
int i;
for (i = 0; i < n; i++) {
long long id = fetch_long ();
query_ack (id);
}
}
void work_rpc_result (struct connection *c UU, long long msg_id UU) {
if (verbosity) {
fprintf (stderr, "work_rpc_result: msg_id = %lld\n", msg_id);
}
assert (fetch_int () == (int)CODE_rpc_result);
long long id = fetch_long ();
int op = prefetch_int ();
if (op == CODE_rpc_error) {
query_error (id);
} else {
query_result (id);
}
}
void rpc_execute_answer (struct connection *c, long long msg_id UU) {
int op = prefetch_int ();
switch (op) {
case CODE_msg_container:
work_container (c, msg_id);
return;
case CODE_new_session_created:
work_new_session_created (c, msg_id);
return;
case CODE_msgs_ack:
work_msgs_ack (c, msg_id);
return;
case CODE_rpc_result:
work_rpc_result (c, msg_id);
return;
}
fprintf (stderr, "Unknown message: \n");
hexdump_in ();
}
int process_rpc_message (struct connection *c UU, struct encrypted_message *enc, int len) {
const int MINSZ = offsetof (struct encrypted_message, message);
const int UNENCSZ = offsetof (struct encrypted_message, server_salt);
if (verbosity) {
fprintf (stderr, "process_rpc_message(), len=%d\n", len);
}
assert (len >= MINSZ && (len & 15) == (UNENCSZ & 15));
struct dc *DC = GET_DC(c);
assert (enc->auth_key_id == DC->auth_key_id);
assert (DC->auth_key_id);
init_aes_auth (DC->auth_key + 8, enc->msg_key, AES_DECRYPT);
int l = pad_aes_decrypt ((char *)&enc->server_salt, len - UNENCSZ, (char *)&enc->server_salt, len - UNENCSZ);
assert (l == len - UNENCSZ);
//assert (enc->auth_key_id2 == enc->auth_key_id);
assert (!(enc->msg_len & 3) && enc->msg_len > 0 && enc->msg_len <= len - MINSZ && len - MINSZ - enc->msg_len <= 12);
static unsigned char sha1_buffer[20];
sha1 ((void *)&enc->server_salt, enc->msg_len + (MINSZ - UNENCSZ), sha1_buffer);
assert (!memcmp (&enc->msg_key, sha1_buffer + 4, 16));
//assert (enc->server_salt == server_salt); //in fact server salt can change
if (DC->server_salt != enc->server_salt) {
DC->server_salt = enc->server_salt;
write_auth_file ();
}
int this_server_time = enc->msg_id >> 32LL;
double st = get_server_time (DC);
assert (this_server_time >= st - 300 && this_server_time <= st + 30);
//assert (enc->msg_id > server_last_msg_id && (enc->msg_id & 3) == 1);
if (verbosity >= 2) {
fprintf (stderr, "received mesage id %016llx\n", enc->msg_id);
}
server_last_msg_id = enc->msg_id;
//*(long long *)(longpoll_query + 3) = *(long long *)((char *)(&enc->msg_id) + 0x3c);
//*(long long *)(longpoll_query + 5) = *(long long *)((char *)(&enc->msg_id) + 0x3c);
assert (l >= (MINSZ - UNENCSZ) + 8);
//assert (enc->message[0] == CODE_rpc_result && *(long long *)(enc->message + 1) == client_last_msg_id);
if (verbosity >= 2) {
fprintf (stderr, "OK, message is good!\n");
}
++good_messages;
in_ptr = enc->message;
in_end = in_ptr + (enc->msg_len / 4);
if (enc->seq_no & 1) {
insert_seqno (c->session, enc->seq_no);
}
assert (c->session->session_id == enc->session_id);
rpc_execute_answer (c, enc->msg_id);
return 0;
}
int rpc_execute (struct connection *c, int op, int len) {
if (verbosity) {
fprintf (stderr, "outbound rpc connection #%d : received rpc answer %d with %d content bytes\n", c->fd, op, len);
}
if (len >= MAX_RESPONSE_SIZE/* - 12*/ || len < 0/*12*/) {
fprintf (stderr, "answer too long (%d bytes), skipping\n", len);
return 0;
}
int Response_len = len;
assert (read_in (c, Response, Response_len) == Response_len);
Response[Response_len] = 0;
if (verbosity >= 2) {
fprintf (stderr, "have %d Response bytes\n", Response_len);
}
setsockopt (c->fd, IPPROTO_TCP, TCP_QUICKACK, (int[]){0}, 4);
int o = c_state;
if (GET_DC(c)->flags & 1) { o = st_authorized;}
switch (o) {
case st_reqpq_sent:
process_respq_answer (c, Response/* + 8*/, Response_len/* - 12*/);
setsockopt (c->fd, IPPROTO_TCP, TCP_QUICKACK, (int[]){0}, 4);
return 0;
case st_reqdh_sent:
process_dh_answer (c, Response/* + 8*/, Response_len/* - 12*/);
setsockopt (c->fd, IPPROTO_TCP, TCP_QUICKACK, (int[]){0}, 4);
return 0;
case st_client_dh_sent:
process_auth_complete (c, Response/* + 8*/, Response_len/* - 12*/);
setsockopt (c->fd, IPPROTO_TCP, TCP_QUICKACK, (int[]){0}, 4);
return 0;
case st_authorized:
process_rpc_message (c, (void *)(Response/* + 8*/), Response_len/* - 12*/);
setsockopt (c->fd, IPPROTO_TCP, TCP_QUICKACK, (int[]){0}, 4);
return 0;
default:
fprintf (stderr, "fatal: cannot receive answer in state %d\n", c_state);
exit (2);
}
return 0;
}
int tc_close (struct connection *c, int who) {
if (verbosity) {
fprintf (stderr, "outbound http connection #%d : closing by %d\n", c->fd, who);
}
return 0;
}
int tc_becomes_ready (struct connection *c) {
if (verbosity) {
fprintf (stderr, "outbound connection #%d becomes ready\n", c->fd);
}
char byte = 0xef;
assert (write_out (c, &byte, 1) == 1);
flush_out (c);
setsockopt (c->fd, IPPROTO_TCP, TCP_QUICKACK, (int[]){0}, 4);
int o = c_state;
if (GET_DC(c)->flags & 1) { o = st_authorized; }
switch (o) {
case st_init:
send_req_pq_packet (c);
break;
case st_authorized:
auth_work_start (c);
break;
default:
fprintf (stderr, "c_state = %d\n", c_state);
assert (0);
}
return 0;
}
int rpc_becomes_ready (struct connection *c) {
return tc_becomes_ready (c);
}
int rpc_close (struct connection *c) {
return tc_close (c, 0);
}
int auth_is_success (void) {
return auth_success;
}
void on_start (void) {
prng_seed (0, 0);
if (rsa_load_public_key (rsa_public_key_name) < 0) {
perror ("rsa_load_public_key");
exit (1);
}
if (verbosity) {
fprintf (stderr, "public key '%s' loaded successfully\n", rsa_public_key_name);
}
pk_fingerprint = compute_rsa_key_fingerprint (pubKey);
}
int auth_ok (void) {
return auth_success;
}
void dc_authorize (struct dc *DC) {
c_state = 0;
auth_success = 0;
if (!DC->sessions[0]) {
dc_create_session (DC);
}
if (verbosity) {
fprintf (stderr, "Starting authorization for DC #%d: %s:%d\n", DC->id, DC->ip, DC->port);
}
net_loop (0, auth_ok);
}
#ifndef __MTPROTO_CLIENT_H__
#define __MTPROTO_CLIENT_H__
#include "net.h"
void on_start (void);
long long encrypt_send_message (struct connection *c, int *msg, int msg_ints, int useful);
void dc_authorize (struct dc *DC);
#endif
#define _FILE_OFFSET_BITS 64
#include <assert.h>
#include <string.h>
#include <stdio.h>
#include <stdlib.h>
#include <unistd.h>
#include <fcntl.h>
#include <sys/types.h>
#include <aio.h>
#include <netdb.h>
#include <openssl/bn.h>
#include <openssl/rand.h>
#include <openssl/pem.h>
#include <openssl/aes.h>
#include <openssl/sha.h>
#include "mtproto-common.h"
long long rsa_encrypted_chunks, rsa_decrypted_chunks;
BN_CTX *BN_ctx;
int verbosity;
int get_random_bytes (void *buf, int n) {
int r = 0, h = open ("/dev/random", O_RDONLY | O_NONBLOCK);
if (h >= 0) {
r = read (h, buf, n);
if (r > 0) {
if (verbosity >= 3) {
fprintf (stderr, "added %d bytes of real entropy to secure random numbers seed\n", r);
}
}
close (h);
}
if (r < n) {
h = open ("/dev/urandom", O_RDONLY);
if (h < 0) {
return r;
}
int s = read (h, buf + r, n - r);
close (h);
if (s < 0) {
return r;
}
r += s;
}
if (r >= (int)sizeof (long)) {
*(long *)buf ^= lrand48 ();
srand48 (*(long *)buf);
}
return r;
}
void prng_seed (const char *password_filename, int password_length) {
unsigned char *a = calloc (64 + password_length, 1);
assert (a != NULL);
long long r = rdtsc ();
struct timespec T;
assert (clock_gettime(CLOCK_REALTIME, &T) >= 0);
memcpy (a, &T.tv_sec, 4);
memcpy (a+4, &T.tv_nsec, 4);
memcpy (a+8, &r, 8);
unsigned short p = getpid ();
memcpy (a + 16, &p, 2);
int s = get_random_bytes (a + 18, 32) + 18;
if (password_filename) {
int fd = open (password_filename, O_RDONLY);
if (fd < 0) {
fprintf (stderr, "Warning: fail to open password file - \"%s\", %m.\n", password_filename);
} else {
int l = read (fd, a + s, password_length);
if (l < 0) {
fprintf (stderr, "Warning: fail to read password file - \"%s\", %m.\n", password_filename);
} else {
if (verbosity > 0) {
fprintf (stderr, "read %d bytes from password file.\n", l);
}
s += l;
}
close (fd);
}
}
RAND_seed (a, s);
BN_ctx = BN_CTX_new ();
memset (a, 0, s);
free (a);
}
int serialize_bignum (BIGNUM *b, char *buffer, int maxlen) {
int itslen = BN_num_bytes (b);
int reqlen;
if (itslen < 254) {
reqlen = itslen + 1;
} else {
reqlen = itslen + 4;
}
int newlen = (reqlen + 3) & -4;
int pad = newlen - reqlen;
reqlen = newlen;
if (reqlen > maxlen) {
return -reqlen;
}
if (itslen < 254) {
*buffer++ = itslen;
} else {
*(int *)buffer = (itslen << 8) + 0xfe;
buffer += 4;
}
int l = BN_bn2bin (b, (unsigned char *)buffer);
assert (l == itslen);
buffer += l;
while (pad --> 0) {
*buffer++ = 0;
}
return reqlen;
}
long long compute_rsa_key_fingerprint (RSA *key) {
static char tempbuff[4096];
static unsigned char sha[20];
assert (key->n && key->e);
int l1 = serialize_bignum (key->n, tempbuff, 4096);
assert (l1 > 0);
int l2 = serialize_bignum (key->e, tempbuff + l1, 4096 - l1);
assert (l2 > 0 && l1 + l2 <= 4096);
SHA1 ((unsigned char *)tempbuff, l1 + l2, sha);
return *(long long *)(sha + 12);
}
void out_cstring (const char *str, long len) {
assert (len >= 0 && len < (1 << 24));
assert ((char *) packet_ptr + len + 8 < (char *) (packet_buffer + PACKET_BUFFER_SIZE));
char *dest = (char *) packet_ptr;
if (len < 254) {
*dest++ = len;
} else {
*packet_ptr = (len << 8) + 0xfe;
dest += 4;
}
memcpy (dest, str, len);
dest += len;
while ((long) dest & 3) {
*dest++ = 0;
}
packet_ptr = (int *) dest;
}
void out_cstring_careful (const char *str, long len) {
assert (len >= 0 && len < (1 << 24));
assert ((char *) packet_ptr + len + 8 < (char *) (packet_buffer + PACKET_BUFFER_SIZE));
char *dest = (char *) packet_ptr;
if (len < 254) {
dest++;
if (dest != str) {
memmove (dest, str, len);
}
dest[-1] = len;
} else {
dest += 4;
if (dest != str) {
memmove (dest, str, len);
}
*packet_ptr = (len << 8) + 0xfe;
}
dest += len;
while ((long) dest & 3) {
*dest++ = 0;
}
packet_ptr = (int *) dest;
}
void out_data (const char *data, long len) {
assert (len >= 0 && len < (1 << 24) && !(len & 3));
assert ((char *) packet_ptr + len + 8 < (char *) (packet_buffer + PACKET_BUFFER_SIZE));
memcpy (packet_ptr, data, len);
packet_ptr += len >> 2;
}
int *in_ptr, *in_end;
int fetch_bignum (BIGNUM *x) {
int l = prefetch_strlen ();
if (l < 0) {
return l;
}
char *str = fetch_str (l);
assert (BN_bin2bn ((unsigned char *) str, l, x) == x);
return l;
}
int pad_rsa_encrypt (char *from, int from_len, char *to, int size, BIGNUM *N, BIGNUM *E) {
int pad = (255000 - from_len - 32) % 255 + 32;
int chunks = (from_len + pad) / 255;
int bits = BN_num_bits (N);
assert (bits >= 2041 && bits <= 2048);
assert (from_len > 0 && from_len <= 2550);
assert (size >= chunks * 256);
assert (RAND_pseudo_bytes ((unsigned char *) from + from_len, pad) >= 0);
int i;
BIGNUM x, y;
BN_init (&x);
BN_init (&y);
rsa_encrypted_chunks += chunks;
for (i = 0; i < chunks; i++) {
BN_bin2bn ((unsigned char *) from, 255, &x);
assert (BN_mod_exp (&y, &x, E, N, BN_ctx) == 1);
unsigned l = 256 - BN_num_bytes (&y);
assert (l <= 256);
memset (to, 0, l);
BN_bn2bin (&y, (unsigned char *) to + l);
to += 256;
}
BN_free (&x);
BN_free (&y);
return chunks * 256;
}
int pad_rsa_decrypt (char *from, int from_len, char *to, int size, BIGNUM *N, BIGNUM *D) {
if (from_len < 0 || from_len > 0x1000 || (from_len & 0xff)) {
return -1;
}
int chunks = (from_len >> 8);
int bits = BN_num_bits (N);
assert (bits >= 2041 && bits <= 2048);
assert (size >= chunks * 255);
int i;
BIGNUM x, y;
BN_init (&x);
BN_init (&y);
for (i = 0; i < chunks; i++) {
++rsa_decrypted_chunks;
BN_bin2bn ((unsigned char *) from, 256, &x);
assert (BN_mod_exp (&y, &x, D, N, BN_ctx) == 1);
int l = BN_num_bytes (&y);
if (l > 255) {
BN_free (&x);
BN_free (&y);
return -1;
}
assert (l >= 0 && l <= 255);
memset (to, 0, 255 - l);
BN_bn2bin (&y, (unsigned char *) to + 255 - l);
to += 255;
}
BN_free (&x);
BN_free (&y);
return chunks * 255;
}
unsigned char aes_key_raw[32], aes_iv[32];
AES_KEY aes_key;
void init_aes_unauth (const char server_nonce[16], const char hidden_client_nonce[32], int encrypt) {
static unsigned char buffer[64], hash[20];
memcpy (buffer, hidden_client_nonce, 32);
memcpy (buffer + 32, server_nonce, 16);
SHA1 (buffer, 48, aes_key_raw);
memcpy (buffer + 32, hidden_client_nonce, 32);
SHA1 (buffer, 64, aes_iv + 8);
memcpy (buffer, server_nonce, 16);
memcpy (buffer + 16, hidden_client_nonce, 32);
SHA1 (buffer, 48, hash);
memcpy (aes_key_raw + 20, hash, 12);
memcpy (aes_iv, hash + 12, 8);
memcpy (aes_iv + 28, hidden_client_nonce, 4);
if (encrypt == AES_ENCRYPT) {
AES_set_encrypt_key (aes_key_raw, 32*8, &aes_key);
} else {
AES_set_decrypt_key (aes_key_raw, 32*8, &aes_key);
}
}
void init_aes_auth (char auth_key[192], char msg_key[16], int encrypt) {
static unsigned char buffer[48], hash[20];
// sha1_a = SHA1 (msg_key + substr (auth_key, 0, 32));
// sha1_b = SHA1 (substr (auth_key, 32, 16) + msg_key + substr (auth_key, 48, 16));
// sha1_с = SHA1 (substr (auth_key, 64, 32) + msg_key);
// sha1_d = SHA1 (msg_key + substr (auth_key, 96, 32));
// aes_key = substr (sha1_a, 0, 8) + substr (sha1_b, 8, 12) + substr (sha1_c, 4, 12);
// aes_iv = substr (sha1_a, 8, 12) + substr (sha1_b, 0, 8) + substr (sha1_c, 16, 4) + substr (sha1_d, 0, 8);
memcpy (buffer, msg_key, 16);
memcpy (buffer + 16, auth_key, 32);
SHA1 (buffer, 48, hash);
memcpy (aes_key_raw, hash, 8);
memcpy (aes_iv, hash + 8, 12);
memcpy (buffer, auth_key + 32, 16);
memcpy (buffer + 16, msg_key, 16);
memcpy (buffer + 32, auth_key + 48, 16);
SHA1 (buffer, 48, hash);
memcpy (aes_key_raw + 8, hash + 8, 12);
memcpy (aes_iv + 12, hash, 8);
memcpy (buffer, auth_key + 64, 32);
memcpy (buffer + 32, msg_key, 16);
SHA1 (buffer, 48, hash);
memcpy (aes_key_raw + 20, hash + 4, 12);
memcpy (aes_iv + 20, hash + 16, 4);
memcpy (buffer, msg_key, 16);
memcpy (buffer + 16, auth_key + 96, 32);
SHA1 (buffer, 48, hash);
memcpy (aes_iv + 24, hash, 8);
if (encrypt == AES_ENCRYPT) {
AES_set_encrypt_key (aes_key_raw, 32*8, &aes_key);
} else {
AES_set_decrypt_key (aes_key_raw, 32*8, &aes_key);
}
}
int pad_aes_encrypt (char *from, int from_len, char *to, int size) {
int padded_size = (from_len + 15) & -16;
assert (from_len > 0 && padded_size <= size);
if (from_len < padded_size) {
assert (RAND_pseudo_bytes ((unsigned char *) from + from_len, padded_size - from_len) >= 0);
}
AES_ige_encrypt ((unsigned char *) from, (unsigned char *) to, padded_size, &aes_key, aes_iv, AES_ENCRYPT);
return padded_size;
}
int pad_aes_decrypt (char *from, int from_len, char *to, int size) {
if (from_len <= 0 || from_len > size || (from_len & 15)) {
return -1;
}
AES_ige_encrypt ((unsigned char *) from, (unsigned char *) to, from_len, &aes_key, aes_iv, AES_DECRYPT);
return from_len;
}
#ifndef __MTPROTO_COMMON_H__
#define __MTPROTO_COMMON_H__
#include <string.h>
#include <openssl/rsa.h>
#include <openssl/bn.h>
#include <openssl/aes.h>
#include <stdio.h>
/* DH key exchange protocol data structures */
#define CODE_req_pq 0x60469778
#define CODE_resPQ 0x05162463
#define CODE_req_DH_params 0xd712e4be
#define CODE_p_q_inner_data 0x83c95aec
#define CODE_server_DH_inner_data 0xb5890dba
#define CODE_server_DH_params_fail 0x79cb045d
#define CODE_server_DH_params_ok 0xd0e8075c
#define CODE_set_client_DH_params 0xf5045f1f
#define CODE_client_DH_inner_data 0x6643b654
#define CODE_dh_gen_ok 0x3bcbf734
#define CODE_dh_gen_retry 0x46dc1fb9
#define CODE_dh_gen_fail 0xa69dae02
/* generic data structures */
#define CODE_vector_long 0xc734a64e
#define CODE_vector_int 0xa03855ae
#define CODE_vector_Object 0xa351ae8e
#define CODE_vector 0x1cb5c415
/* service messages */
#define CODE_rpc_result 0xf35c6d01
#define CODE_rpc_error 0x2144ca19
#define CODE_msg_container 0x73f1f8dc
#define CODE_msg_copy 0xe06046b2
#define CODE_http_wait 0x9299359f
#define CODE_msgs_ack 0x62d6b459
#define CODE_bad_msg_notification 0xa7eff811
#define CODE_bad_server_salt 0xedab447b
#define CODE_msgs_state_req 0xda69fb52
#define CODE_msgs_state_info 0x04deb57d
#define CODE_msgs_all_info 0x8cc0d131
#define CODE_new_session_created 0x9ec20908
#define CODE_msg_resend_req 0x7d861a08
#define CODE_ping 0x7abe77ec
#define CODE_pong 0x347773c5
#define CODE_destroy_session 0xe7512126
#define CODE_destroy_session_ok 0xe22045fc
#define CODE_destroy_session_none 0x62d350c9
#define CODE_destroy_sessions 0x9a6face8
#define CODE_destroy_sessions_res 0xa8164668
#define CODE_get_future_salts 0xb921bd04
#define CODE_future_salt 0x0949d9dc
#define CODE_future_salts 0xae500895
#define CODE_rpc_drop_answer 0x58e4a740
#define CODE_rpc_answer_unknown 0x5e2ad36e
#define CODE_rpc_answer_dropped_running 0xcd78e586
#define CODE_rpc_answer_dropped 0xa43ad8b7
#define CODE_msg_detailed_info 0x276d3ec6
#define CODE_msg_new_detailed_info 0x809db6df
#define CODE_ping_delay_disconnect 0xf3427b8c
/* sample rpc query/response structures */
#define CODE_getUser 0xb0f732d5
#define CODE_getUsers 0x2d84d5f5
#define CODE_user 0xd23c81a3
#define CODE_no_user 0xc67599d1
#define CODE_msgs_random 0x12345678
#define CODE_random_msg 0x87654321
#define RPC_INVOKE_REQ 0x2374df3d
#define RPC_INVOKE_KPHP_REQ 0x99a37fda
#define RPC_REQ_RUNNING 0x346d5efa
#define RPC_REQ_ERROR 0x7ae432f5
#define RPC_REQ_RESULT 0x63aeda4e
#define RPC_READY 0x6a34cac7
#define RPC_STOP_READY 0x59d86654
#define RPC_SEND_SESSION_MSG 0x1ed5a3cc
#define RPC_RESPONSE_INDIRECT 0x2194f56e
/* RPC for workers */
#define CODE_send_session_msg 0x81bb412c
#define CODE_sendMsgOk 0x29841ee2
#define CODE_sendMsgNoSession 0x2b2b9e78
#define CODE_sendMsgFailed 0x4b0cbd57
#define CODE_get_auth_sessions 0x611f7845
#define CODE_authKeyNone 0x8a8bc1f3
#define CODE_authKeySessions 0x6b7f026c
#define CODE_add_session_box 0xe707e295
#define CODE_set_session_box 0x193d4231
#define CODE_replace_session_box 0xcb101b49
#define CODE_replace_session_box_cas 0xb2bbfa78
#define CODE_delete_session_box 0x01b78d81
#define CODE_delete_session_box_cas 0xb3fdc3c5
#define CODE_session_box_no_session 0x43f46c33
#define CODE_session_box_created 0xe1dd5d40
#define CODE_session_box_replaced 0xbd9cb6b2
#define CODE_session_box_deleted 0xaf8fd05e
#define CODE_session_box_not_found 0xb3560a7f
#define CODE_session_box_found 0x560fe356
#define CODE_session_box_changed 0x014b31b8
#define CODE_get_session_box 0x8793a924
#define CODE_get_session_box_cond 0x7888fab6
#define CODE_session_box_session_absent 0x9e234062
#define CODE_session_box_absent 0xa1a106eb
#define CODE_session_box 0x7956cd97
#define CODE_session_box_large 0xb568d189
#define CODE_get_sessions_activity 0x059dc5f6
#define CODE_sessions_activities 0x60ce5b1d
#define CODE_get_session_activity 0x96dbac11
#define CODE_session_activity 0xe175e8e0
/* RPC for front/proxy */
#define RPC_FRONT 0x27a456f3
#define RPC_FRONT_ACK 0x624abd23
#define RPC_FRONT_ERR 0x71dda175
#define RPC_PROXY_REQ 0x36cef1ee
#define RPC_PROXY_ANS 0x4403da0d
#define RPC_CLOSE_CONN 0x1fcf425d
#define RPC_CLOSE_EXT 0x5eb634a2
#define RPC_SIMPLE_ACK 0x3bac409b
#define CODE_auth_send_code 0xd16ff372
#define CODE_auth_sent_code 0x2215bcbd
#define CODE_help_get_config 0xc4f9186b
#define CODE_config 0x232d5905
#define CODE_dc_option 0x2ec2a43c
#define CODE_bool_false 0xbc799737
#define CODE_bool_true 0x997275b5
#define CODE_user_self 0x720535ec
#define CODE_auth_authorization 0xf6b673a4
#define CODE_user_profile_photo_empty 0x4f11bae1
#define CODE_user_profile_photo 0x990d1493
#define CODE_user_status_empty 0x9d05049
#define CODE_user_status_online 0xedb93949
#define CODE_user_status_offline 0x8c703f
#define CODE_sign_in 0xbcd51581
#define CODE_file_location 0x53d69076
#define CODE_file_location_unavailable 0x7c596b46
#define CODE_contacts_get_contacts 0x22c6aa08
#define CODE_contacts_contacts 0x6f8b8cb2
#define CODE_contact 0xf911c994
#define CODE_user_empty 0x200250ba
#define CODE_user_contact 0xf2fb8319
#define CODE_user_request 0x22e8ceb0
#define CODE_user_foreign 0x5214c89d
#define CODE_user_deleted 0xb29ad7cc
#define CODE_gzip_packed 0x3072cfa1
/* not really a limit, for struct encrypted_message only */
// #define MAX_MESSAGE_INTS 16384
#define MAX_MESSAGE_INTS 1048576
#define MAX_PROTO_MESSAGE_INTS 1048576
#pragma pack(push,4)
struct encrypted_message {
// unencrypted header
long long auth_key_id;
char msg_key[16];
// encrypted part, starts with encrypted header
long long server_salt;
long long session_id;
// long long auth_key_id2; // removed
// first message follows
long long msg_id;
int seq_no;
int msg_len; // divisible by 4
int message[MAX_MESSAGE_INTS];
};
struct worker_descr {
int addr;
int port;
int pid;
int start_time;
int id;
};
struct rpc_ready_packet {
int len;
int seq_num;
int type;
struct worker_descr worker;
int worker_ready_cnt;
int crc32;
};
struct front_descr {
int addr;
int port;
int pid;
int start_time;
int id;
};
struct rpc_front_packet {
int len;
int seq_num;
int type;
struct front_descr front;
long long hash_mult;
int rem, mod;
int crc32;
};
struct middle_descr {
int addr;
int port;
int pid;
int start_time;
int id;
};
struct rpc_front_ack {
int len;
int seq_num;
int type;
struct middle_descr middle;
int crc32;
};
struct rpc_front_err {
int len;
int seq_num;
int type;
int errcode;
struct middle_descr middle;
long long hash_mult;
int rem, mod;
int crc32;
};
struct rpc_proxy_req {
int len;
int seq_num;
int type;
int flags;
long long ext_conn_id;
unsigned char remote_ipv6[16];
int remote_port;
unsigned char our_ipv6[16];
int our_port;
int data[];
};
#define PROXY_HDR(__x) ((struct rpc_proxy_req *)((__x) - offsetof(struct rpc_proxy_req, data)))
struct rpc_proxy_ans {
int len;
int seq_num;
int type;
int flags; // +16 = small error packet, +8 = flush immediately
long long ext_conn_id;
int data[];
};
struct rpc_close_conn {
int len;
int seq_num;
int type;
long long ext_conn_id;
int crc32;
};
struct rpc_close_ext {
int len;
int seq_num;
int type;
long long ext_conn_id;
int crc32;
};
struct rpc_simple_ack {
int len;
int seq_num;
int type;
long long ext_conn_id;
int confirm_key;
int crc32;
};
#pragma pack(pop)
BN_CTX *BN_ctx;
void prng_seed (const char *password_filename, int password_length);
int serialize_bignum (BIGNUM *b, char *buffer, int maxlen);
long long compute_rsa_key_fingerprint (RSA *key);
#define PACKET_BUFFER_SIZE (16384 * 100) // temp fix
int packet_buffer[PACKET_BUFFER_SIZE], *packet_ptr;
static inline void out_ints (int *what, int len) {
assert (packet_ptr + len <= packet_buffer + PACKET_BUFFER_SIZE);
memcpy (packet_ptr, what, len * 4);
packet_ptr += len;
}
static inline void out_int (int x) {
assert (packet_ptr + 1 <= packet_buffer + PACKET_BUFFER_SIZE);
*packet_ptr++ = x;
}
static inline void out_long (long long x) {
assert (packet_ptr + 2 <= packet_buffer + PACKET_BUFFER_SIZE);
*(long long *)packet_ptr = x;
packet_ptr += 2;
}
static inline void clear_packet (void) {
packet_ptr = packet_buffer;
}
void out_cstring (const char *str, long len);
void out_cstring_careful (const char *str, long len);
void out_data (const char *data, long len);
static inline void out_string (const char *str) {
out_cstring (str, strlen (str));
}
static inline void out_bignum (BIGNUM *n) {
int l = serialize_bignum (n, (char *)packet_ptr, (PACKET_BUFFER_SIZE - (packet_ptr - packet_buffer)) * 4);
assert (l > 0);
packet_ptr += l >> 2;
}
extern int *in_ptr, *in_end;
static inline int prefetch_strlen (void) {
if (in_ptr >= in_end) {
return -1;
}
unsigned l = *in_ptr;
if ((l & 0xff) < 0xfe) {
l &= 0xff;
return (in_end >= in_ptr + (l >> 2) + 1) ? (int)l : -1;
} else if ((l & 0xff) == 0xfe) {
l >>= 8;
return (l >= 254 && in_end >= in_ptr + ((l + 7) >> 2)) ? (int)l : -1;
} else {
return -1;
}
}
static inline char *fetch_str (int len) {
if (len < 254) {
char *str = (char *) in_ptr + 1;
in_ptr += 1 + (len >> 2);
return str;
} else {
char *str = (char *) in_ptr + 4;
in_ptr += (len + 7) >> 2;
return str;
}
}
static inline char *fetch_str_dup (void) {
int l = prefetch_strlen ();
return strndup (fetch_str (l), l);
}
static __inline__ unsigned long long rdtsc(void) {
unsigned hi, lo;
__asm__ __volatile__ ("rdtsc" : "=a"(lo), "=d"(hi));
return ( (unsigned long long)lo)|( ((unsigned long long)hi)<<32 );
}
static inline long have_prefetch_ints (void) {
return in_end - in_ptr;
}
int fetch_bignum (BIGNUM *x);
static inline int fetch_int (void) {
return *(in_ptr ++);
}
static inline int prefetch_int (void) {
return *(in_ptr);
}
static inline long long fetch_long (void) {
long long r = *(long long *)in_ptr;
in_ptr += 2;
return r;
}
int get_random_bytes (void *buf, int n);
int pad_rsa_encrypt (char *from, int from_len, char *to, int size, BIGNUM *N, BIGNUM *E);
int pad_rsa_decrypt (char *from, int from_len, char *to, int size, BIGNUM *N, BIGNUM *D);
extern long long rsa_encrypted_chunks, rsa_decrypted_chunks;
extern unsigned char aes_key_raw[32], aes_iv[32];
extern AES_KEY aes_key;
void init_aes_unauth (const char server_nonce[16], const char hidden_client_nonce[32], int encrypt);
void init_aes_auth (char auth_key[192], char msg_key[16], int encrypt);
int pad_aes_encrypt (char *from, int from_len, char *to, int size);
int pad_aes_decrypt (char *from, int from_len, char *to, int size);
static inline void hexdump_in (void) {
int *ptr = in_ptr;
while (ptr < in_end) { fprintf (stderr, " %08x", *(ptr ++)); }
fprintf (stderr, "\n");
}
#endif
#define _GNU_SOURCE
#include <string.h>
#include <stdlib.h>
#include <assert.h>
#include <netdb.h>
#include <netinet/tcp.h>
#include <sys/fcntl.h>
#include <errno.h>
#include <stdio.h>
#include <unistd.h>
#include <poll.h>
#include <openssl/rand.h>
#include <arpa/inet.h>
#include "net.h"
#include "include.h"
#include "mtproto-client.h"
#include "mtproto-common.h"
#include "tree.h"
DEFINE_TREE(int,int,int_cmp,0)
int verbosity;
extern struct connection_methods auth_methods;
struct connection_buffer *new_connection_buffer (int size) {
struct connection_buffer *b = malloc (sizeof (*b));
memset (b, 0, sizeof (*b));
b->start = malloc (size);
b->end = b->start + size;
b->rptr = b->wptr = b->start;
return b;
}
void delete_connection_buffer (struct connection_buffer *b) {
free (b->start);
free (b);
}
int write_out (struct connection *c, const void *data, int len) {
if (!len) { return 0; }
assert (len > 0);
int x = 0;
if (!c->out_head) {
struct connection_buffer *b = new_connection_buffer (1 << 20);
c->out_head = c->out_tail = b;
}
while (len) {
if (c->out_tail->end - c->out_tail->wptr >= len) {
memcpy (c->out_tail->wptr, data, len);
c->out_tail->wptr += len;
c->out_bytes += len;
return x + len;
} else {
int y = c->out_tail->end - c->out_tail->wptr;
assert (y < len);
memcpy (c->out_tail->wptr, data, y);
x += y;
len -= y;
data += y;
struct connection_buffer *b = new_connection_buffer (1 << 20);
c->out_tail->next = b;
b->next = 0;
c->out_tail = b;
c->out_bytes += y;
}
}
return x;
}
int read_in (struct connection *c, void *data, int len) {
if (!len) { return 0; }
assert (len > 0);
if (len > c->in_bytes) {
len = c->in_bytes;
}
int x = 0;
while (len) {
int y = c->in_head->wptr - c->in_head->rptr;
if (y > len) {
memcpy (data, c->in_head->rptr, len);
c->in_head->rptr += len;
c->in_bytes -= len;
return x + len;
} else {
memcpy (data, c->in_head->rptr, y);
c->in_bytes -= y;
x += y;
data += y;
len -= y;
void *old = c->in_head;
c->in_head = c->in_head->next;
if (!c->in_head) {
c->in_tail = 0;
}
delete_connection_buffer (old);
}
}
return x;
}
int read_in_lookup (struct connection *c, void *data, int len) {
if (!len) { return 0; }
assert (len > 0);
if (len > c->in_bytes) {
len = c->in_bytes;
}
int x = 0;
struct connection_buffer *b = c->in_head;
while (len) {
int y = b->wptr - b->rptr;
if (y > len) {
memcpy (data, b->rptr, len);
return x + len;
} else {
memcpy (data, b->rptr, y);
x += y;
b = b->next;
}
}
return x;
}
void flush_out (struct connection *c UU) {
}
#define MAX_CONNECTIONS 100
struct connection *Connections[MAX_CONNECTIONS];
int max_connection_fd;
struct connection *create_connection (const char *host, int port, struct session *session, struct connection_methods *methods) {
struct connection *c = malloc (sizeof (*c));
memset (c, 0, sizeof (*c));
struct hostent *h;
if (!(h = gethostbyname (host)) || h->h_addrtype != AF_INET || h->h_length != 4 || !h->h_addr_list || !h->h_addr) {
assert (0);
}
int fd;
assert ((fd = socket (AF_INET, SOCK_STREAM, 0)) != -1);
assert (fd >= 0 && fd < MAX_CONNECTIONS);
if (fd > max_connection_fd) {
max_connection_fd = fd;
}
int flags = -1;
setsockopt (fd, SOL_SOCKET, SO_REUSEADDR, &flags, sizeof (flags));
setsockopt (fd, SOL_SOCKET, SO_KEEPALIVE, &flags, sizeof (flags));
setsockopt (fd, IPPROTO_TCP, TCP_NODELAY, &flags, sizeof (flags));
struct sockaddr_in addr;
addr.sin_family = AF_INET;
addr.sin_port = htons (port);
addr.sin_addr.s_addr = inet_addr (host);
fcntl (fd, F_SETFL, O_NONBLOCK);
if (connect (fd, (struct sockaddr *) &addr, sizeof (addr)) == -1) {
if (errno != EINPROGRESS) {
fprintf (stderr, "Can not connect to %s:%d %m\n", host, port);
close (fd);
free (c);
return 0;
}
}
struct pollfd s;
s.fd = fd;
s.events = POLLOUT | POLLERR | POLLRDHUP | POLLHUP;
if (poll (&s, 1, 10000) <= 0 || !(s.revents & POLLOUT)) {
perror ("poll");
close (fd);
free (c);
return 0;
}
c->session = session;
c->fd = fd;
c->ip = htonl (*(int *)h->h_addr);
c->flags = 0;
c->state = conn_ready;
c->methods = methods;
assert (!Connections[fd]);
Connections[fd] = c;
if (verbosity) {
fprintf (stderr, "connect to %s:%d successful\n", host, port);
}
if (c->methods->ready) {
c->methods->ready (c);
}
return c;
}
void fail_connection (struct connection *c) {
struct connection_buffer *b = c->out_head;
while (b) {
struct connection_buffer *d = b;
b = b->next;
delete_connection_buffer (d);
}
b = c->in_head;
while (b) {
struct connection_buffer *d = b;
b = b->next;
delete_connection_buffer (d);
}
c->out_head = c->out_tail = c->in_head = c->in_tail = 0;
c->state = conn_failed;
c->out_bytes = c->in_bytes = 0;
}
void try_write (struct connection *c) {
if (verbosity) {
fprintf (stderr, "try write: fd = %d\n", c->fd);
}
int x = 0;
while (c->out_head) {
int r = write (c->fd, c->out_head->rptr, c->out_head->wptr - c->out_head->rptr);
if (r >= 0) {
x += r;
c->out_head->rptr += r;
if (c->out_head->rptr != c->out_head->wptr) {
break;
}
struct connection_buffer *b = c->out_head;
c->out_head = b->next;
if (!c->out_head) {
c->out_tail = 0;
}
delete_connection_buffer (b);
} else {
if (errno != EAGAIN && errno != EWOULDBLOCK) {
fail_connection (c);
return;
} else {
break;
}
}
}
if (verbosity) {
fprintf (stderr, "Sent %d bytes to %d\n", x, c->fd);
}
c->out_bytes -= x;
}
void hexdump (struct connection_buffer *b) {
int pos = 0;
int rem = 8;
while (b) {
unsigned char *c = b->rptr;
while (c != b->wptr) {
if (rem == 8) {
if (pos) { printf ("\n"); }
printf ("%04d", pos);
}
printf (" %02x", (int)*c);
rem --;
pos ++;
if (!rem) {
rem = 8;
}
c ++;
}
b = b->next;
}
printf ("\n");
}
void try_rpc_read (struct connection *c) {
assert (c->in_head);
if (verbosity >= 4) {
hexdump (c->in_head);
}
while (1) {
if (c->in_bytes < 1) { return; }
unsigned len = 0;
unsigned t = 0;
assert (read_in_lookup (c, &len, 1) == 1);
if (len >= 1 && len <= 0x7e) {
if (c->in_bytes < (int)(4 * len)) { return; }
} else {
if (c->in_bytes < 4) { return; }
assert (read_in_lookup (c, &len, 4) == 4);
len = (len >> 8);
if (c->in_bytes < (int)(4 * len)) { return; }
len = 0x7f;
}
if (len >= 1 && len <= 0x7e) {
assert (read_in (c, &t, 1) == 1);
assert (t == len);
assert (len >= 1);
} else {
assert (len == 0x7f);
assert (read_in (c, &len, 4) == 4);
len = (len >> 8);
assert (len >= 1);
}
len *= 4;
int op;
assert (read_in_lookup (c, &op, 4) == 4);
c->methods->execute (c, op, len);
}
}
void try_read (struct connection *c) {
if (verbosity) {
fprintf (stderr, "try read: fd = %d\n", c->fd);
}
if (!c->in_tail) {
c->in_head = c->in_tail = new_connection_buffer (1 << 20);
}
int x = 0;
while (1) {
int r = read (c->fd, c->in_tail->wptr, c->in_tail->end - c->in_tail->wptr);
if (r >= 0) {
c->in_tail->wptr += r;
x += r;
if (c->in_tail->wptr != c->in_tail->end) {
break;
}
struct connection_buffer *b = new_connection_buffer (1 << 20);
c->in_tail->next = b;
c->in_tail = b;
} else {
if (errno != EAGAIN && errno != EWOULDBLOCK) {
fail_connection (c);
return;
} else {
break;
}
}
}
if (verbosity) {
fprintf (stderr, "Received %d bytes from %d\n", x, c->fd);
}
c->in_bytes += x;
if (x) {
try_rpc_read (c);
}
}
int connections_make_poll_array (struct pollfd *fds, int max) {
int _max = max;
int i;
for (i = 0; i <= max_connection_fd; i++) if (Connections[i] && Connections[i]->state != conn_failed) {
assert (max > 0);
struct connection *c = Connections[i];
fds[0].fd = c->fd;
fds[0].events = POLLERR | POLLHUP | POLLRDHUP | POLLIN;
if (c->out_bytes || c->state == conn_connecting) {
fds[0].events |= POLLOUT;
}
fds ++;
max --;
}
if (verbosity >= 3) {
fprintf (stderr, "%d connections in poll\n", _max - max);
}
return _max - max;
}
void connections_poll_result (struct pollfd *fds, int max) {
if (verbosity >= 2) {
fprintf (stderr, "connections_poll_result: max = %d\n", max);
}
int i;
for (i = 0; i < max; i++) {
struct connection *c = Connections[fds[i].fd];
if (fds[i].revents & POLLIN) {
try_read (c);
}
if (fds[i].revents & (POLLHUP | POLLERR | POLLRDHUP)) {
if (verbosity) {
fprintf (stderr, "fail connection\n");
}
fail_connection (c);
} else if (fds[i].revents & POLLOUT) {
if (c->state == conn_connecting) {
c->state = conn_ready;
}
if (c->out_bytes) {
try_write (c);
}
}
}
}
int send_all_acks (struct session *S) {
clear_packet ();
out_int (tree_count_int (S->ack_tree));
while (S->ack_tree) {
int x = tree_get_min_int (S->ack_tree);
out_int (x);
S->ack_tree = tree_delete_int (S->ack_tree, x);
}
encrypt_send_message (S->c, packet_buffer, packet_ptr - packet_buffer, 0);
return 0;
}
void insert_seqno (struct session *S, int seqno) {
if (!S->ack_tree) {
S->ev.alarm = (void *)send_all_acks;
S->ev.self = (void *)S;
S->ev.timeout = get_double_time () + ACK_TIMEOUT;
insert_event_timer (&S->ev);
}
if (!tree_lookup_int (S->ack_tree, seqno)) {
S->ack_tree = tree_insert_int (S->ack_tree, seqno, lrand48 ());
}
}
extern struct dc *DC_list[];
struct dc *alloc_dc (int id, char *ip, int port) {
assert (!DC_list[id]);
struct dc *DC = malloc (sizeof (*DC));
memset (DC, 0, sizeof (*DC));
DC->id = id;
DC->ip = ip;
DC->port = port;
DC_list[id] = DC;
return DC;
}
void dc_create_session (struct dc *DC) {
struct session *S = malloc (sizeof (*S));
memset (S, 0, sizeof (*S));
assert (RAND_pseudo_bytes ((unsigned char *) &S->session_id, 8) >= 0);
S->dc = DC;
S->c = create_connection (DC->ip, DC->port, S, &auth_methods);
assert (!DC->sessions[0]);
DC->sessions[0] = S;
}
#ifndef __NET_H__
#define __NET_H__
#include <poll.h>
struct dc;
#include "queries.h"
#define TG_SERVER "173.240.5.1"
//#define TG_SERVER "95.142.192.66"
#define TG_APP_HASH "3bc14c6455ef1595ec86a125762c3aad"
#define TG_APP_ID 51
#define ACK_TIMEOUT 60
#define MAX_DC_ID 10
enum dc_state{
st_init,
st_reqpq_sent,
st_reqdh_sent,
st_client_dh_sent,
st_authorized,
st_error
} ;
struct connection;
struct connection_methods {
int (*ready) (struct connection *c);
int (*close) (struct connection *c);
int (*execute) (struct connection *c, int op, int len);
};
#define MAX_DC_SESSIONS 3
struct session {
struct dc *dc;
long long session_id;
int seq_no;
struct connection *c;
struct tree_int *ack_tree;
struct event_timer ev;
};
struct dc {
int id;
int port;
int flags;
char *ip;
char *user;
struct session *sessions[MAX_DC_SESSIONS];
char auth_key[256];
long long auth_key_id;
long long server_salt;
int server_time_delta;
double server_time_udelta;
};
#define DC_SERIALIZED_MAGIC 0x64582faa
struct dc_serialized {
int magic;
int port;
char ip[64];
char user[64];
char auth_key[256];
long long auth_key_id, server_salt;
int authorized;
};
struct connection_buffer {
void *start;
void *end;
void *rptr;
void *wptr;
struct connection_buffer *next;
};
enum conn_state {
conn_none,
conn_connecting,
conn_ready,
conn_failed,
conn_stopped
};
struct connection {
int fd;
int ip;
int port;
int flags;
enum conn_state state;
int ipv6[4];
struct connection_buffer *in_head;
struct connection_buffer *in_tail;
struct connection_buffer *out_head;
struct connection_buffer *out_tail;
int in_bytes;
int out_bytes;
int packet_num;
int out_packet_num;
struct connection_methods *methods;
struct session *session;
void *extra;
};
extern struct connection *Connections[];
int write_out (struct connection *c, const void *data, int len);
void flush_out (struct connection *c);
int read_in (struct connection *c, void *data, int len);
void create_all_outbound_connections (void);
struct connection *create_connection (const char *host, int port, struct session *session, struct connection_methods *methods);
int connections_make_poll_array (struct pollfd *fds, int max);
void connections_poll_result (struct pollfd *fds, int max);
void dc_create_session (struct dc *DC);
void insert_seqno (struct session *S, int seqno);
struct dc *alloc_dc (int id, char *ip, int port);
#define GET_DC(c) (c->session->dc)
#endif
#include <string.h>
#include <memory.h>
#include <stdlib.h>
#include <zlib.h>
#include "include.h"
#include "mtproto-client.h"
#include "queries.h"
#include "tree.h"
#include "mtproto-common.h"
#include "telegram.h"
#include "loop.h"
#include "structures.h"
#include "interface.h"
int verbosity;
#define QUERY_TIMEOUT 0.3
#define memcmp8(a,b) memcmp ((a), (b), 8)
DEFINE_TREE (query, struct query *, memcmp8, 0) ;
struct tree_query *queries_tree;
double get_double_time (void) {
struct timespec tv;
clock_gettime (CLOCK_REALTIME, &tv);
return tv.tv_sec + 1e-9 * tv.tv_nsec;
}
struct query *query_get (long long id) {
return tree_lookup_query (queries_tree, (void *)&id);
}
int alarm_query (struct query *q) {
assert (q);
return 0;
}
struct query *send_query (struct dc *DC, int ints, void *data, struct query_methods *methods) {
assert (DC);
assert (DC->auth_key_id);
if (!DC->sessions[0]) {
dc_create_session (DC);
}
if (verbosity) {
fprintf (stderr, "Sending query of size %d to DC (%s:%d)\n", 4 * ints, DC->ip, DC->port);
}
struct query *q = malloc (sizeof (*q));
q->data_len = ints;
q->data = malloc (4 * ints);
memcpy (q->data, data, 4 * ints);
q->msg_id = encrypt_send_message (DC->sessions[0]->c, data, ints, 1);
if (verbosity) {
fprintf (stderr, "Msg_id is %lld\n", q->msg_id);
}
q->methods = methods;
if (queries_tree) {
fprintf (stderr, "%lld %lld\n", q->msg_id, queries_tree->x->msg_id);
}
queries_tree = tree_insert_query (queries_tree, q, lrand48 ());
q->ev.alarm = (void *)alarm_query;
q->ev.timeout = get_double_time () + QUERY_TIMEOUT;
q->ev.self = (void *)q;
insert_event_timer (&q->ev);
return q;
}
void query_ack (long long id) {
struct query *q = query_get (id);
if (q) { q->flags |= QUERY_ACK_RECEIVED; }
}
void query_error (long long id) {
assert (fetch_int () == CODE_rpc_error);
int error_code = fetch_int ();
int error_len = prefetch_strlen ();
char *error = fetch_str (error_len);
if (verbosity) {
fprintf (stderr, "error for query #%lld: #%d :%.*s\n", id, error_code, error_len, error);
}
struct query *q = query_get (id);
if (!q) {
if (verbosity) {
fprintf (stderr, "No such query\n");
}
} else {
remove_event_timer (&q->ev);
queries_tree = tree_delete_query (queries_tree, q);
if (q->methods && q->methods->on_error) {
q->methods->on_error (q, error_code, error_len, error);
}
free (q->data);
free (q);
}
}
#define MAX_PACKED_SIZE (1 << 20)
static int packed_buffer[MAX_PACKED_SIZE / 4];
void query_result (long long id UU) {
if (verbosity) {
fprintf (stderr, "result for query #%lld\n", id);
}
if (verbosity >= 4) {
fprintf (stderr, "result: ");
hexdump_in ();
}
int op = prefetch_int ();
int *end = 0;
int *eend = 0;
if (op == CODE_gzip_packed) {
fetch_int ();
int l = prefetch_strlen ();
char *s = fetch_str (l);
size_t dl = MAX_PACKED_SIZE;
z_stream strm = {0};
assert (inflateInit2 (&strm, 16 + MAX_WBITS) == Z_OK);
strm.avail_in = l;
strm.next_in = (void *)s;
strm.avail_out = MAX_PACKED_SIZE;
strm.next_out = (void *)packed_buffer;
int err = inflate (&strm, Z_FINISH);
if (verbosity) {
fprintf (stderr, "inflate error = %d\n", err);
fprintf (stderr, "inflated %d bytes\n", (int)strm.total_out);
}
end = in_ptr;
eend = in_end;
assert (dl % 4 == 0);
in_ptr = packed_buffer;
in_end = in_ptr + strm.total_out / 4;
if (verbosity >= 4) {
fprintf (stderr, "Unzipped data: ");
hexdump_in ();
}
}
struct query *q = query_get (id);
if (!q) {
if (verbosity) {
fprintf (stderr, "No such query\n");
}
} else {
remove_event_timer (&q->ev);
queries_tree = tree_delete_query (queries_tree, q);
if (q->methods && q->methods->on_answer) {
q->methods->on_answer (q);
}
free (q->data);
free (q);
}
if (end) {
in_ptr = end;
in_end = eend;
}
}
#define event_timer_cmp(a,b) ((a)->timeout > (b)->timeout ? 1 : ((a)->timeout < (b)->timeout ? -1 : (memcmp (a, b, sizeof (struct event_timer)))))
DEFINE_TREE (timer, struct event_timer *, event_timer_cmp, 0)
struct tree_timer *timer_tree;
void insert_event_timer (struct event_timer *ev) {
return;
fprintf (stderr, "INSERT: %lf %p %p\n", ev->timeout, ev->self, ev->alarm);
tree_check_timer (timer_tree);
timer_tree = tree_insert_timer (timer_tree, ev, lrand48 ());
tree_check_timer (timer_tree);
}
void remove_event_timer (struct event_timer *ev) {
return;
fprintf (stderr, "REMOVE: %lf %p %p\n", ev->timeout, ev->self, ev->alarm);
tree_check_timer (timer_tree);
timer_tree = tree_delete_timer (timer_tree, ev);
tree_check_timer (timer_tree);
}
double next_timer_in (void) {
if (!timer_tree) { return 1e100; }
return tree_get_min_timer (timer_tree)->timeout;
}
void work_timers (void) {
double t = get_double_time ();
while (timer_tree) {
struct event_timer *ev = tree_get_min_timer (timer_tree);
assert (ev);
if (ev->timeout > t) { break; }
remove_event_timer (ev);
ev->alarm (ev->self);
}
}
int max_chat_size;
int want_dc_num;
extern struct dc *DC_list[];
extern struct dc *DC_working;
int help_get_config_on_answer (struct query *q UU) {
assert (fetch_int () == CODE_config);
fetch_int ();
unsigned test_mode = fetch_int ();
assert (test_mode == CODE_bool_true || test_mode == CODE_bool_false);
assert (test_mode == CODE_bool_false);
int this_dc = fetch_int ();
if (verbosity) {
fprintf (stderr, "this_dc = %d\n", this_dc);
}
assert (fetch_int () == CODE_vector);
int n = fetch_int ();
assert (n <= 10);
int i;
for (i = 0; i < n; i++) {
assert (fetch_int () == CODE_dc_option);
int id = fetch_int ();
int l1 = prefetch_strlen ();
char *name = fetch_str (l1);
int l2 = prefetch_strlen ();
char *ip = fetch_str (l2);
int port = fetch_int ();
if (verbosity) {
fprintf (stderr, "id = %d, name = %.*s ip = %.*s port = %d\n", id, l1, name, l2, ip, port);
}
if (!DC_list[id]) {
alloc_dc (id, strndup (ip, l2), port);
}
}
max_chat_size = fetch_int ();
if (verbosity >= 2) {
fprintf (stderr, "chat_size = %d\n", max_chat_size);
}
return 0;
}
struct query_methods help_get_config_methods = {
.on_answer = help_get_config_on_answer
};
char *phone_code_hash;
int send_code_on_answer (struct query *q UU) {
assert (fetch_int () == CODE_auth_sent_code);
assert (fetch_int () == (int)CODE_bool_true);
int l = prefetch_strlen ();
char *s = fetch_str (l);
if (phone_code_hash) {
free (phone_code_hash);
}
phone_code_hash = strndup (s, l);
want_dc_num = -1;
return 0;
}
int send_code_on_error (struct query *q UU, int error_code, int l, char *error) {
int s = strlen ("PHONE_MIGRATE_");
if (l >= s && !memcmp (error, "PHONE_MIGRATE_", s)) {
int i = error[s] - '0';
want_dc_num = i;
} else {
fprintf (stderr, "error_code = %d, error = %.*s\n", error_code, l, error);
assert (0);
}
return 0;
}
struct query_methods send_code_methods = {
.on_answer = send_code_on_answer,
.on_error = send_code_on_error
};
int code_is_sent (void) {
return want_dc_num;
}
int config_got (void) {
return DC_list[want_dc_num] != 0;
}
char *suser;
extern int dc_working_num;
void do_send_code (const char *user) {
suser = strdup (user);
want_dc_num = 0;
clear_packet ();
out_int (CODE_auth_send_code);
out_string (user);
out_int (0);
out_int (TG_APP_ID);
out_string (TG_APP_HASH);
send_query (DC_working, packet_ptr - packet_buffer, packet_buffer, &send_code_methods);
net_loop (0, code_is_sent);
if (want_dc_num == -1) { return; }
if (DC_list[want_dc_num]) {
DC_working = DC_list[want_dc_num];
if (!DC_working->auth_key_id) {
dc_authorize (DC_working);
}
if (!DC_working->sessions[0]) {
dc_create_session (DC_working);
}
dc_working_num = want_dc_num;
} else {
clear_packet ();
out_int (CODE_help_get_config);
send_query (DC_working, packet_ptr - packet_buffer, packet_buffer, &help_get_config_methods);
net_loop (0, config_got);
DC_working = DC_list[want_dc_num];
if (!DC_working->auth_key_id) {
dc_authorize (DC_working);
}
if (!DC_working->sessions[0]) {
dc_create_session (DC_working);
}
dc_working_num = want_dc_num;
}
want_dc_num = 0;
clear_packet ();
out_int (CODE_auth_send_code);
out_string (user);
out_int (0);
out_int (TG_APP_ID);
out_string (TG_APP_HASH);
send_query (DC_working, packet_ptr - packet_buffer, packet_buffer, &send_code_methods);
net_loop (0, code_is_sent);
assert (want_dc_num == -1);
}
int sign_in_ok;
int sign_in_is_ok (void) {
return sign_in_ok;
}
struct user User;
int sign_in_on_answer (struct query *q UU) {
assert (fetch_int () == (int)CODE_auth_authorization);
int expires = fetch_int ();
fetch_user (&User);
sign_in_ok = 1;
if (verbosity) {
fprintf (stderr, "authorized successfully: name = '%s %s', phone = '%s', expires = %d\n", User.first_name, User.last_name, User.phone, (int)(expires - get_double_time ()));
}
return 0;
}
int sign_in_on_error (struct query *q UU, int error_code, int l, char *error) {
fprintf (stderr, "error_code = %d, error = %.*s\n", error_code, l, error);
sign_in_ok = -1;
return 0;
}
struct query_methods sign_in_methods = {
.on_answer = sign_in_on_answer,
.on_error = sign_in_on_error
};
int do_send_code_result (const char *code) {
clear_packet ();
out_int (CODE_sign_in);
out_string (suser);
out_string (phone_code_hash);
out_string (code);
send_query (DC_working, packet_ptr - packet_buffer, packet_buffer, &sign_in_methods);
sign_in_ok = 0;
net_loop (0, sign_in_is_ok);
return sign_in_ok;
}
int get_contacts_on_answer (struct query *q UU) {
assert (fetch_int () == (int)CODE_contacts_contacts);
assert (fetch_int () == CODE_vector);
int n = fetch_int ();
int i;
for (i = 0; i < n; i++) {
assert (fetch_int () == (int)CODE_contact);
fetch_int (); // id
fetch_int (); // mutual
}
assert (fetch_int () == CODE_vector);
n = fetch_int ();
for (i = 0; i < n; i++) {
struct user User;
fetch_user (&User);
rprintf ("User: id = %d, first_name = %s, last_name = %s\n", User.id, User.first_name, User.last_name);
}
return 0;
}
struct query_methods get_contacts_methods = {
.on_answer = get_contacts_on_answer,
};
void do_update_contact_list (void) {
clear_packet ();
out_int (CODE_contacts_get_contacts);
out_string ("");
send_query (DC_working, packet_ptr - packet_buffer, packet_buffer, &get_contacts_methods);
}
#include "net.h"
#ifndef __QUERIES_H__
#define __QUERIES_H__
#define QUERY_ACK_RECEIVED 1
struct query;
struct query_methods {
int (*on_answer)(struct query *q);
int (*on_error)(struct query *q, int error_code, int len, char *error);
int (*on_timeout)(struct query *q);
};
struct event_timer {
double timeout;
int (*alarm)(void *self);
void *self;
};
struct query {
long long msg_id;
int data_len;
int flags;
void *data;
struct query_methods *methods;
struct event_timer ev;
};
struct query *send_query (struct dc *DC, int len, void *data, struct query_methods *methods);
void query_ack (long long id);
void query_error (long long id);
void query_result (long long id);
void insert_event_timer (struct event_timer *ev);
void remove_event_timer (struct event_timer *ev);
double next_timer_in (void);
void work_timers (void);
extern struct query_methods help_get_config_methods;
void do_send_code (const char *user);
int do_send_code_result (const char *code);
double get_double_time (void);
void do_update_contact_list (void);
#endif
#include <assert.h>
#include "structures.h"
#include "mtproto-common.h"
void fetch_file_location (struct file_location *loc) {
int x = fetch_int ();
if (x == CODE_file_location_unavailable) {
loc->dc = -1;
loc->volume = fetch_long ();
loc->local_id = fetch_int ();
loc->secret = fetch_long ();
} else {
assert (x == CODE_file_location);
loc->dc = fetch_int ();;
loc->volume = fetch_long ();
loc->local_id = fetch_int ();
loc->secret = fetch_long ();
}
}
void fetch_user_status (struct user_status *S) {
int x = fetch_int ();
switch (x) {
case CODE_user_status_empty:
S->online = 0;
break;
case CODE_user_status_online:
S->online = 1;
S->when = fetch_int ();
break;
case CODE_user_status_offline:
S->online = -1;
S->when = fetch_int ();
break;
default:
assert (0);
}
}
void fetch_user (struct user *U) {
memset (U, 0, sizeof (*U));
unsigned x = fetch_int ();
assert (x == CODE_user_empty || x == CODE_user_self || x == CODE_user_contact || x == CODE_user_request || x == CODE_user_foreign || x == CODE_user_deleted);
U->id = fetch_int ();
if (x == CODE_user_empty) {
U->flags = 1;
return;
}
U->first_name = fetch_str_dup ();
U->last_name = fetch_str_dup ();
if (x == CODE_user_deleted) {
U->flags = 2;
return;
}
if (x == CODE_user_self) {
U->flags = 4;
} else {
U->access_hash = fetch_long ();
}
if (x == CODE_user_foreign) {
U->flags |= 8;
} else {
U->phone = fetch_str_dup ();
}
unsigned y = fetch_int ();
if (y == CODE_user_profile_photo_empty) {
U->photo_small.dc = -2;
U->photo_big.dc = -2;
} else {
assert (y == CODE_user_profile_photo);
fetch_file_location (&U->photo_small);
fetch_file_location (&U->photo_big);
}
fetch_user_status (&U->status);
if (x == CODE_user_self) {
assert (fetch_int () == (int)CODE_bool_false);
}
if (x == CODE_user_contact) {
U->flags |= 16;
}
}
#ifndef __STRUCTURES_H__
#define __STRUCTURES_H__
struct file_location {
int dc;
long long volume;
int local_id;
long long secret;
};
struct user_status {
int online;
int when;
};
struct user {
int id;
int flags;
char *first_name;
char *last_name;
char *phone;
long long access_hash;
struct file_location photo_big;
struct file_location photo_small;
struct user_status status;
};
void fetch_file_location (struct file_location *loc);
void fetch_user (struct user *U);
#endif
#ifndef __TREE_H__
#define __TREE_H__
#include <stdio.h>
#include <memory.h>
#include <malloc.h>
#include <assert.h>
#define DEFINE_TREE(X_NAME, X_TYPE, X_CMP, X_UNSET) \
struct tree_ ## X_NAME { \
struct tree_ ## X_NAME *left, *right;\
X_TYPE x;\
int y;\
};\
\
struct tree_ ## X_NAME *new_tree_node_ ## X_NAME (X_TYPE x, int y) {\
struct tree_ ## X_NAME *T = malloc (sizeof (*T));\
T->x = x;\
T->y = y;\
T->left = T->right = 0;\
return T;\
}\
\
void delete_tree_node_ ## X_NAME (struct tree_ ## X_NAME *T) {\
free (T);\
}\
\
void tree_split_ ## X_NAME (struct tree_ ## X_NAME *T, X_TYPE x, struct tree_ ## X_NAME **L, struct tree_ ## X_NAME **R) {\
if (!T) {\
*L = *R = 0;\
} else {\
int c = X_CMP (x, T->x);\
if (c < 0) {\
tree_split_ ## X_NAME (T->left, x, L, &T->left);\
*R = T;\
} else {\
tree_split_ ## X_NAME (T->right, x, &T->right, R);\
*L = T;\
}\
}\
}\
\
struct tree_ ## X_NAME *tree_insert_ ## X_NAME (struct tree_ ## X_NAME *T, X_TYPE x, int y) {\
if (!T) {\
return new_tree_node_ ## X_NAME (x, y);\
} else {\
if (y > T->y) {\
struct tree_ ## X_NAME *N = new_tree_node_ ## X_NAME (x, y);\
tree_split_ ## X_NAME (T, x, &N->left, &N->right);\
return N;\
} else {\
int c = X_CMP (x, T->x);\
assert (c);\
return tree_insert_ ## X_NAME (c < 0 ? T->left : T->right, x, y);\
}\
}\
}\
\
struct tree_ ## X_NAME *tree_merge_ ## X_NAME (struct tree_ ## X_NAME *L, struct tree_ ## X_NAME *R) {\
if (!L || !R) {\
return L ? L : R;\
} else {\
if (L->y > R->y) {\
L->right = tree_merge_ ## X_NAME (L->right, R);\
return L;\
} else {\
R->left = tree_merge_ ## X_NAME (L, R->left);\
return R;\
}\
}\
}\
\
struct tree_ ## X_NAME *tree_delete_ ## X_NAME (struct tree_ ## X_NAME *T, X_TYPE x) {\
assert (T);\
int c = X_CMP (x, T->x);\
if (!c) {\
struct tree_ ## X_NAME *N = tree_merge_ ## X_NAME (T->left, T->right);\
delete_tree_node_ ## X_NAME (T);\
return N;\
} else {\
return tree_delete_ ## X_NAME (c < 0 ? T->left : T->right, x);\
}\
}\
\
X_TYPE tree_get_min_ ## X_NAME (struct tree_ ## X_NAME *T) {\
if (!T) { return X_UNSET; } \
while (T->left) { T = T->left; }\
return T->x; \
} \
\
X_TYPE tree_lookup_ ## X_NAME (struct tree_ ## X_NAME *T, X_TYPE x) {\
int c;\
while (T && (c = X_CMP (x, T->x))) {\
T = (c < 0 ? T->left : T->right);\
}\
return T ? T->x : X_UNSET;\
}\
\
int tree_count_ ## X_NAME (struct tree_ ## X_NAME *T) { \
if (!T) { return 0; }\
return 1 + tree_count_ ## X_NAME (T->left) + tree_count_ ## X_NAME (T->right); \
}\
void tree_check_ ## X_NAME (struct tree_ ## X_NAME *T) { \
if (!T) { return; }\
if (T->left) { \
assert (T->left->y <= T->y);\
assert (X_CMP (T->left->x, T->x) < 0); \
}\
if (T->right) { \
assert (T->right->y <= T->y);\
assert (X_CMP (T->right->x, T->x) > 0); \
}\
}\
#define int_cmp(a,b) ((a) - (b))
#endif
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment