/* handlers.C
 * John Viega
 *
 * Jan 28-29 2000
 */

#include "lex.H"
#include "handlers.H"
#include "resultsdb.H"
#include "config.H"
#include "dict.H"
#include "strpool.H"

/* Record a finding unless its severity falls below the configured cutoff.
 * When input scanning is enabled the cutoff is ignored and every finding
 * is recorded.
 */
void ConditionalAdd(char *source, int line, Severity s,
		    VulnInfo *v, int explanation=0)
{
  // Too mild for the user's cutoff, and not in input-scanning mode: drop it.
  if(s < GetSeverityCutoff() && !GetInputScanning())
    return;
  AddResult(source, line, s, v, explanation);
}

/* Report a non-constant format string at S_MOST_RISKY.
 *
 * v                       - VulnInfo of the function being called.
 * t                       - token to take the line number from (may be NULL,
 *                           in which case line -1 is reported).
 * source                  - file name for the report.
 * construct_new_vulninfo  - nonzero: build a fresh format-string VulnInfo
 *                           (text copied from the "printf" database entry when
 *                           present, otherwise built-in text) carrying v's id;
 *                           zero: report v itself.  The flag is also passed as
 *                           the explanation argument of ConditionalAdd.
 */
static void ReportFormatAttack(VulnInfo *v, Token *t, char *source, 
			       int construct_new_vulninfo=1)
{
  VulnInfo *report = v;

  if(construct_new_vulninfo)
    {
      VulnInfo *proto = GetVulnInfo("printf");
      if(proto)
	{
	  report = new VulnInfo(proto->desc, proto->solution, S_MOST_RISKY,
				0 /* ignored */, v->id, 0);
	}
      else
	{
	  // Fallback for an old or incomplete database that lacks "printf".
	  report = new VulnInfo(AddStringToPool("Format strings should be constant."), 
				AddStringToPool("Don't use variable format strings."), 
				S_MOST_RISKY,
				0, v->id, 0);
	}
    }

  int line = t ? t->GetLineNo() : -1;
  ConditionalAdd(source, line, S_MOST_RISKY, report, 
		 construct_new_vulninfo /* If 1, not the default explanation. */);
}

/* Advance i past the end of the current call argument.
 *
 * Scanning starts at token i.  Parentheses are tracked so commas inside
 * nested calls are skipped.  On return, i indexes the token just after the
 * top-level ',' or the unmatched ')' that ended the argument.  If the token
 * stream runs out first, i is set to -1.
 */
void FindNextArgument(TokenContainer *tc, int &i)
{
  int depth = 0;
  for(Token *cur; (cur = tc->GetToken(i++)); )
    {
      if(cur->GetTokenType() != OPERATOR)
	continue;

      char *op = ((OperatorTok *)cur)->GetOperatorName();
      if(!strcmp(op, "("))
	{
	  ++depth;
	}
      else if(!strcmp(op, ")"))
	{
	  // An unmatched ')' closes the argument list itself.
	  if(depth == 0)
	    return;
	  --depth;
	}
      else if(depth == 0 && !strcmp(op, ","))
	{
	  return;
	}
    }
  i = -1;
}

/* Report a call using the severity stored in its own VulnInfo.
 *
 * Handlers receive i indexing the token AFTER the identifier.  The line
 * reported is taken from the identifier itself (index i-1), because the
 * following token could sit on a later line.  If that token is missing,
 * line -1 is reported.  The token at i-1 is expected to be an IDENTIFIER.
 */
void DefaultHandler(VulnInfo *v, TokenContainer *tc, int i, char *source)
{
  int line = -1;
  IdTok *ident = (IdTok *)tc->GetToken(i-1);
  if(ident)
    line = ident->GetLineNo();

  ConditionalAdd(source, line, v->severity, v);
}

/* Handle functions of the form f(arg1, arg2), where arg2 being a string
 * constant more or less renders the call harmless from a breakin pov.
 *
 * i indexes the token right after the function identifier.  If arg2 is a
 * string literal the call is reported at S_NO_RISK; otherwise (or if the
 * argument list can't be found) the default report is made.
 * TODO: Should actually check that the first token is a left paren!
 */
void StrcpyHandler(VulnInfo *v, TokenContainer *tc, int i, char *source)
{
  int original_i = i;
  Token *tok = tc->GetToken(i++);
  if(!tok) 
    {
    default_handler:
      DefaultHandler(v, tc, original_i, source);
      return;
    }
  FindNextArgument(tc, i);
  if(i<0){goto default_handler;}
  tok = tc->GetToken(i);
  // The token stream may end right after the comma; don't dereference NULL.
  if(!tok || tok->GetTokenType() != STRING){goto default_handler;}

  // You can still overflow the buffer, but you really won't have to
  // worry about a breakin, etc. except in very rare cases.
  // Report on the identifier's line (as DefaultHandler does), in case the
  // argument list starts on a later line.
  Token *id = tc->GetToken(original_i-1);
  ConditionalAdd(source, id?(id->GetLineNo()):-1, S_NO_RISK, v);
}

/* Scan the format string if we can find it (it's the 2nd arg).  If there's
 * no %s, then we're not too worried, unless there aren't quotes, in which
 * case we're very worried.
 *
 * i indexes the token right after the function identifier.
 * - Non-literal format: reported via ReportFormatAttack (S_MOST_RISKY).
 * - Literal containing "%s", or argument not found: default report.
 * - Otherwise: S_LOW_RISK.
 */
void SprintfHandler(VulnInfo *v, TokenContainer *tc, int i, char *source)
{
  int original_i = i;
  Token *tok = tc->GetToken(i++);
  if(!tok) 
    {
    default_handler:
      DefaultHandler(v, tc, original_i, source);
      return;
    }
  FindNextArgument(tc, i);
  if(i<0){goto default_handler;}
  tok = tc->GetToken(i);
  // The token stream may end right after the comma; don't dereference NULL.
  if(!tok){goto default_handler;}
  if(tok->GetTokenType() != STRING){
    ReportFormatAttack(v, tok, source);
    return;
  }
  char *s = ((StringTok*)tok)->GetContents();
  while((s=strchr(s,'%')))
    {
      if(*(++s) == 's')
	goto default_handler;
    }
  // We can't make it NO risk, because the s might not follow the 
  // %, but it usually does.  And when it doesn't, people are often
  // using the precision modifiers, which also helps avoid the problem.
  // Report on the identifier's line, as DefaultHandler does.
  Token *id = tc->GetToken(original_i-1);
  ConditionalAdd(source, id?(id->GetLineNo()):-1, S_LOW_RISK, v);
}

/* Scan the format string if we can find it (it's the 3rd arg on Linux at least)
 * If there's no %s, then we're not too worried (tho we weren't worried 
 * before either).
 *
 * i indexes the token right after the function identifier.
 * - Non-literal format: reported via ReportFormatAttack (S_MOST_RISKY).
 * - Literal containing "%s", or argument not found: default report.
 * - Otherwise: S_LOW_RISK.
 */
void SnprintfHandler(VulnInfo *v, TokenContainer *tc, int i, char *source)
{
  int original_i = i;
  Token *tok = tc->GetToken(i++);
  if(!tok) 
    {
    default_handler:
      DefaultHandler(v, tc, original_i, source);
      return;
    }
  FindNextArgument(tc, i);
  if(i<0){goto default_handler;}
  FindNextArgument(tc, i);
  if(i<0){goto default_handler;}
  tok = tc->GetToken(i);
  // The token stream may end right after the comma; don't dereference NULL.
  if(!tok){goto default_handler;}
  if(tok->GetTokenType() != STRING){
    ReportFormatAttack(v, tok, source);
    return;
  }
  char *s = ((StringTok*)tok)->GetContents();
  while((s=strchr(s,'%')))
    {
      if(*(++s) == 's')
	  goto default_handler;
    }
  // We can't make it NO risk, because the s might not follow the 
  // %, but it usually does.  And when it doesn't, people are often
  // using the precision modifiers, which also helps avoid the problem.
  // Report on the identifier's line, as DefaultHandler does.
  Token *id = tc->GetToken(original_i-1);
  ConditionalAdd(source, id?(id->GetLineNo()):-1, S_LOW_RISK, v);
}

/* scanf-style call: the format string is the 1st argument.
 *
 * i indexes the token right after the function identifier.  If the format
 * is a string literal with no "%s" conversion, the call is reported at
 * S_NO_RISK.  Otherwise (non-literal format, "%s" present, or nothing to
 * inspect) the default report is made.
 */
void ScanfHandler(VulnInfo *v, TokenContainer *tc, int i, char *source)
{
  int original_i = i;
  Token *tok = tc->GetToken(i++);
  if(!tok) 
    {
    default_handler:
      DefaultHandler(v, tc, original_i, source);
      return;
    }
  tok = tc->GetToken(i);
  if(!tok || (tok->GetTokenType() != STRING)){goto default_handler;}
  char *s = ((StringTok*)tok)->GetContents();
  while((s=strchr(s,'%')))
    {
      if(*(++s) == 's')
	goto default_handler;
    }
  // Report on the identifier's line, as DefaultHandler does, in case the
  // argument list starts on a later line.
  Token *id = tc->GetToken(original_i-1);
  ConditionalAdd(source, id?(id->GetLineNo()):-1, S_NO_RISK, v);
}

// Same scan as the Sprintf handler (the format is the 2nd arg), but if we
// don't see a %s, I'm willing to classify it as no risk.  Might not be the
// best idea...
// A non-literal format, or an argument list we can't walk, gets the default
// report.
void SscanfHandler(VulnInfo *v, TokenContainer *tc, int i, char *source)
{
  int original_i = i;
  Token *tok = tc->GetToken(i++);
  if(!tok) 
    {
    default_handler:
      DefaultHandler(v, tc, original_i, source);
      return;
    }
  FindNextArgument(tc, i);
  if(i<0){goto default_handler;}
  tok = tc->GetToken(i);
  // The token stream may end right after the comma; don't dereference NULL.
  if(!tok || tok->GetTokenType() != STRING){goto default_handler;}
  char *s = ((StringTok*)tok)->GetContents();
  while((s=strchr(s,'%')))
    {
      if(*(++s) == 's')
	goto default_handler;
    }

  // Report on the identifier's line, as DefaultHandler does.
  Token *id = tc->GetToken(original_i-1);
  ConditionalAdd(source, id?(id->GetLineNo()):-1, S_NO_RISK, v);
}

/* fprintf-style call: fprintf(stream, format, ...).
 *
 * i indexes the token right after the function identifier.  The format is
 * the 2nd argument.  A non-literal format is reported as a format-string
 * risk using this function's own VulnInfo.  Every other case gets the
 * default report.
 */
void FprintfHandler(VulnInfo *v, TokenContainer *tc, int i, char *source) {
  int original_i = i;
  Token *tok = tc->GetToken(i++);
  if(!tok) {
  default_handler:
    DefaultHandler(v, tc, original_i, source);
    return;
  }
  // Skip only the stream argument: the format is the 2nd argument.
  // (Skipping two arguments inspected the first data argument instead.)
  FindNextArgument(tc, i);
  if(i<0) goto default_handler;
  tok = tc->GetToken(i);
  // The token stream may end right after the comma; don't dereference NULL.
  if(!tok) goto default_handler;
  if(tok->GetTokenType() != STRING) {
    ReportFormatAttack(v, tok, source, 0);
    return;
  }
  goto default_handler;
}

/* printf-style call: printf(format, ...).
 *
 * i indexes the token right after the function identifier (normally '('),
 * so the format is the very next token.  A non-literal format is reported
 * as a format-string risk using this function's own VulnInfo.  Every other
 * case gets the default report.
 */
void PrintfHandler(VulnInfo *v, TokenContainer *tc, int i, char *source) {
  int original_i = i;
  Token *tok = tc->GetToken(i++);
  if(!tok) {
  default_handler:
    DefaultHandler(v, tc, original_i, source);
    return;
  }
  // The format is the 1st argument; don't skip past it.  Skipping it made
  // printf("literal") inspect the token after ')' and flag it.
  tok = tc->GetToken(i);
  if(!tok) goto default_handler;
  if(tok->GetTokenType() != STRING) {
    ReportFormatAttack(v, tok, source, 0);
    return;
  }
  goto default_handler;
}

/* syslog-style call: syslog(priority, format, ...).
 *
 * i indexes the token right after the function identifier.  The format is
 * the 2nd argument.  A non-literal format is reported through
 * ReportFormatAttack with a freshly built format-string VulnInfo.  Every
 * other case gets the default report.
 */
void SyslogHandler(VulnInfo *v, TokenContainer *tc, int i, char *source) {
  int original_i = i;
  Token *tok = tc->GetToken(i++);
  if(!tok) {
  default_handler:
    DefaultHandler(v, tc, original_i, source);
    return;
  }
  // Skip only the priority argument: the format is the 2nd argument.
  FindNextArgument(tc, i);
  if(i<0) goto default_handler;
  tok = tc->GetToken(i);
  // The token stream may end right after the comma; don't dereference NULL.
  if(!tok) goto default_handler;
  if(tok->GetTokenType() != STRING) {
    ReportFormatAttack(v, tok, source);
    return;
  }
  goto default_handler;
}

// Call sites of TOCTOU-relevant functions, bucketed by the concatenated
// token text of each call's first argument.  Allocated in InitHandlers(),
// filled by Generic_TOCTOU_Handler(), consumed by RunTOCTOUScan().
static Dictionary<TTBucket> *toctou_info;

/* Allocate an uninitialized array of x chars with new[].  The caller owns
 * the result and must release it with delete[]. */
char *GrabSomeMem(int x)
{
  char *block = new char[x];
  return block;
}

/* Record a TOCTOU-relevant call site for the deferred race scan in
 * RunTOCTOUScan(); nothing is reported here unless the call can't be parsed.
 *
 * v      - VulnInfo of the called function.
 * tc, i  - token stream; i indexes the token right after the identifier.
 *          If that token isn't '(' or the first argument is empty or
 *          unterminated, the call falls back to DefaultHandler.
 * source - file name stored with the site.
 * which  - bucket slot the site is filed under (0..2; the visible
 *          wrappers use 0 and 1).
 *
 * The token text of the first argument is concatenated (no separator) into
 * a dictionary key.  Sites with the same key share one TTBucket and are
 * appended to the singly linked list calls[which]..ends[which], with the
 * count kept in num[which].
 */
void Generic_TOCTOU_Handler(VulnInfo *v, TokenContainer *tc, int i, 
			    char *source, int which)
{
  int original_i = i;
  int line_no;
  // assert which in [0,1,2]
  // The identifier is the token before i; its line is what gets recorded.
  Token *tok = tc->GetToken(i-1);
  line_no = tok->GetLineNo();
  tok = tc->GetToken(i++);
  // If the next thing isn't a left paren
  if(!tok || (tok->GetTokenType() != OPERATOR) || 
     strcmp(((OperatorTok *)tok)->GetOperatorName(), "("))
    {
    default_handler:
      DefaultHandler(v, tc, original_i, source);
      return;
    }
  // After FindNextArgument, j indexes one past the ',' or ')' that ends the
  // first argument; stepping back leaves j on that delimiter, so [i, j) are
  // the argument's tokens.
  int j = i;
  FindNextArgument(tc, j);
  j--;
  char *arg_repr = 0;
  int  arg_size = 0;
  // NOTE(review): if FindNextArgument failed, j is -2 here (not -1), so the
  // j==-1 test never fires; i>=j still routes that case to the default.
  if((i>=j) || (j==-1)) goto default_handler;
  for(;i<j;i++)
    {
      tok = tc->GetToken(i);
      char *varname = 0;
      varname = tok->GetValue();
      int x = strlen(varname);
      // Buffer is sized with spare room (one extra byte per token so far).
      char *tmp = new char[arg_size+x+2];
      arg_size += x+1;
      if(arg_repr)
	{
	  // Append this token's text directly to the key built so far.
	  sprintf(tmp, "%s%s", arg_repr, varname); // ITS4: ignore sprintf
	  delete[] arg_repr;
	  arg_repr = tmp;
	}
      else
	{
	  sprintf(tmp, "%s", varname); // ITS4: ignore sprintf
	  arg_repr = tmp;
	}
      // GetValue() may hand back a fresh allocation; free it when it does.
      if(tok->AllocedValue()) delete[] varname;
    }
  TTSite *t = new TTSite(v, source, line_no);
  short error = 0;
  TTBucket *b = toctou_info->GetItem(arg_repr, error);
  if(!b)
    {
      // New key: arg_repr is not freed on this path, so presumably the
      // dictionary takes ownership of it -- confirm against Dictionary.
      b = new TTBucket();
      if(!b) OutOfMemory();
      toctou_info->SetItem(arg_repr, b);
    }
  else
    {
      // Key already present: our copy of the key is redundant.
      delete[] arg_repr;
    }
  // Append the site to the tail of slot `which`.
  t->next = 0;
  if(!b->calls[which])
    {
      b->calls[which] = b->ends[which] = t;
      b->num[which] = 1;
    }
  else
    {
      b->ends[which]->next = t;
      b->ends[which] = t;
      b->num[which]++;
    }
}
    
/* Defer a call for race analysis, filing it under bucket slot 0. */
void TOCTOU_A_Handler(VulnInfo *v, TokenContainer *tc, int i, char *src)
{
  const int slot = 0;
  Generic_TOCTOU_Handler(v, tc, i, src, slot);
}

/* Defer a call for race analysis, filing it under bucket slot 1. */
void TOCTOU_B_Handler(VulnInfo *v, TokenContainer *tc, int i, char *src)
{
  const int slot = 1;
  Generic_TOCTOU_Handler(v, tc, i, src, slot);
}

/* Category-C calls are not recorded for the race scan; they are reported
 * immediately, exactly like any call without a special handler. */
void TOCTOU_C_Handler(VulnInfo *v, TokenContainer *tc, int i, char *src)
{
  DefaultHandler(v, tc, i, src);
}

void RunTOCTOUScan()
{
  int numkeys = toctou_info->GetNumKeys();
  int n;
  char **varnames = new char* [numkeys];
  if(!varnames)
    OutOfMemory();
  toctou_info->GetKeys(varnames);
  for(int i = 0; i<numkeys; i++)
    {
      short error;

      TTBucket *b = toctou_info->GetItem(varnames[i], error);
      TTSite *first_t = b->calls[0];
      // This being true means there is no increase in severity
      // but we use first_t to report all the stuff we didn't report eariler
      if(!first_t)first_t = b->calls[1];

      // assert b != NULL
      if((b->num[0] > 1) || (b->num[0] && b->num[1]))
	{
	  const char *fnamefmt = (GetMSVSFormat() ? "%s(%d) " : "%s:%d: ");
	  TTSite *cur = b->calls[0];
	  n = strlen(fnamefmt)+strlen(cur->source_file)+
	                         3*sizeof(cur->line);
	  char *buf1  = new char[n];
	  if(!buf1) OutOfMemory();
	  /* ITS4: ignore */
	  sprintf(buf1, fnamefmt, cur->source_file, cur->line);
	  const char *fmt = (GetMSVSFormat() ? 
		  "Potential race condition on: %s\n  Points of concern are:\n    %s: %s" :
	      "Potential race condition on: %s\nPoints of concern are:\n%s%s");
	  char *funcname = GetNameById(cur->vuln->id);
	  n = strlen(fmt)+2*strlen(buf1)+
				  strlen(varnames[i])+strlen(funcname);
	  char *buf2   = new char[n];
	  if(!buf2) OutOfMemory();
	  /* ITS4: ignore */
	  sprintf(buf2, fmt, varnames[i], buf1, funcname);
	  delete[] buf1;
	  fmt = (GetMSVSFormat() ? "%s\n    %s(%d) : %s\n  Advice" :
	      "%s\n%s:%d: %s");
	  cur = cur->next;
	  for(int j=1;j<b->num[0];j++)
	    {
	      funcname = GetNameById(cur->vuln->id);
	  n = strlen(fmt)+strlen(cur->source_file)+
	                      strlen(funcname)+3*sizeof(cur->line)+
     	                      strlen(buf2);
	  buf1 = new char[n];
	      if(!buf1)
		OutOfMemory();
	      /* ITS4: ignore */
	      sprintf(buf1, fmt, buf2, cur->source_file, cur->line,
		      funcname);
	      delete[] buf2;
	      buf2 = buf1;
	      cur = cur->next;
	    }
	  cur = b->calls[1];
	  for(int k=0;k<b->num[1];k++)
	    {
	      funcname = GetNameById(cur->vuln->id);
	      n = strlen(fmt)+strlen(cur->source_file)+
	                      strlen(funcname)+3*sizeof(cur->line)+
			      strlen(buf2);
	      buf1 = new char[n]; 
	      if(!buf1)
		OutOfMemory();
	      /* ITS4: ignore */
	      sprintf(buf1, fmt, buf2, cur->source_file, cur->line,
		      funcname);
	      delete[] buf2;
	      buf2 = buf1;
	      cur = cur->next;
	    }
	  VulnInfo *new_v = new VulnInfo(AddStringToPool(buf2), 
					 first_t->vuln->solution, 
					 S_VERY_RISKY, 0, 
					 first_t->vuln->id, 0);
	  AddResult(first_t->source_file,first_t->line, S_VERY_RISKY, new_v);
	  delete[] buf2;
	}
      else  // give the standard warning for anything that is there.
	{
	  TTSite *cur = first_t;
	  while(cur)
	    {
	      AddResult(cur->source_file, cur->line, 
			cur->vuln->severity, cur->vuln);
	      cur = cur->next;
	    }
	}
    }
  delete[] varnames;
}

/* Run the analyses that need the whole input first.  At present this is
 * only the TOCTOU race scan. */
void DoPostProcessing()
{
  // Deferred TOCTOU sites are reported here.
  RunTOCTOUScan();
  return;
}

/* Allocate the table that collects TOCTOU call sites (initial size
 * argument 5).  OutOfMemory() is invoked if allocation yields NULL. */
void InitHandlers()
{
  toctou_info = new Dictionary<TTBucket>(5);
  if(toctou_info == 0)
    {
      OutOfMemory();
    }
}
