| | from typing import Annotated, Literal |
| | from typing_extensions import TypedDict |
| |
|
| | from langgraph.graph import StateGraph, MessagesState, START, END |
| | from langchain_core.messages import AIMessage, HumanMessage, SystemMessage |
| | from langchain_core.output_parsers import JsonOutputParser |
| | from langchain_community.document_transformers import BeautifulSoupTransformer, beautiful_soup_transformer |
| |
|
| | from langgraph.types import Command |
| |
|
| | from langchain_groq import ChatGroq |
| |
|
| | import operator |
| | import pprint |
| | import os |
| | import requests |
| | import html2text |
| |
|
# Groq API key read from the environment; None when GROQ_API_KEY is not set.
API_KEY = os.getenv("GROQ_API_KEY")
# Sentinel string that agent_builder compares against the model's stripped reply.
OUT_RES = "<|FINISHED|>"

# Shared html2text converter used by run_api, configured with the
# ignore_links and ignore_images options switched on.
HTML_TRANSFORMER = html2text.HTML2Text()
HTML_TRANSFORMER.ignore_links = True
HTML_TRANSFORMER.ignore_images = True

# Shared BeautifulSoup transformer; run_api calls its extract_tags method.
BS_TRANSFORMER = BeautifulSoupTransformer()
| |
|
| |
|
def local_message_add(dict1, dict2):
    """Merge a local-messages update into the accumulated local messages.

    Both arguments map a key to a list. For every key in ``dict2`` the list
    is appended to the list already stored under that key, or stored as-is
    when the key is new.

    Args:
        dict1: Accumulated mapping (may be empty or None).
        dict2: Update mapping; all of its keys are merged, and an empty or
            None update leaves the accumulated mapping unchanged.

    Returns:
        A new dict holding the merged result; neither argument is mutated.
    """
    merged = dict(dict1) if dict1 else {}
    for key, new_items in (dict2 or {}).items():
        if key in merged:
            # Concatenate into a fresh list so the previous list is untouched.
            merged[key] = merged[key] + new_items
        else:
            merged[key] = new_items
    return merged
| |
|
def variable_state_update(dict1, dict2):
    """Overlay ``dict2`` onto ``dict1`` in place and return ``dict1``.

    Keys present in both mappings take the value from ``dict2``.
    """
    dict1 |= dict2
    return dict1
| |
|
class GeneralStates(TypedDict):
    """State schema handed to ``StateGraph``.

    Fields annotated with a callable carry the function used to combine the
    current value with an incoming update.
    """

    # Combined with "+", i.e. the update is concatenated onto the current value.
    messages: Annotated[list[dict[str, str]], operator.add]
    # No combining function is attached to this field.
    checkpoints: dict[str, list]
    # Combined with local_message_add.
    local_messages: Annotated[dict, local_message_add]
    # Combined with variable_state_update.
    variables: Annotated[dict, variable_state_update]
| |
|
| |
|
def format_sequence(seq, nested=False):
    """Render a container as text; return any other value unchanged.

    Dicts are delegated to ``format_dict``; lists, tuples, sets and
    frozensets are delegated to ``format_list_like``. The ``nested`` flag is
    forwarded to the delegate.
    """
    if isinstance(seq, dict):
        return format_dict(seq, nested=nested)
    if isinstance(seq, (list, tuple, set, frozenset)):
        return format_list_like(seq, nested=nested)
    return seq
| | |
| | |
| |
|
def format_dict(d, nested=False):
    """Render a dict as ``key: value`` entries joined by ``",\\n"``.

    Container values are rendered through ``format_sequence`` with
    ``nested=True``. At the top level (``nested`` false) each entry is
    prefixed with its 1-based position.
    """
    container_types = (list, tuple, set, frozenset, dict)
    lines = []
    for position, (key, value) in enumerate(d.items(), start=1):
        if isinstance(value, container_types):
            value = format_sequence(value, nested=True)
        prefix = "" if nested else f"{position}. "
        lines.append(f"{prefix}{key}: {value}")
    return ",\n".join(lines)
| |
|
def format_list_like(seq, nested=False):
    """Render an iterable as entries joined by ``",\\n"``.

    Container items are rendered through ``format_sequence`` with
    ``nested=True``. At the top level (``nested`` false) each entry is
    prefixed with its 1-based position; nested entries are plain ``str()``.
    """
    container_types = (list, tuple, set, frozenset, dict)

    def render(item):
        if isinstance(item, container_types):
            return format_sequence(item, nested=True)
        return item

    if nested:
        return ",\n".join(str(render(item)) for item in seq)
    return ",\n".join(
        f"{position}. {render(item)}" for position, item in enumerate(seq, start=1)
    )
| |
|
| |
|
def format_dict_api(input_dict, combined):
    """Return a copy of ``input_dict`` with string values formatted.

    String values are passed through ``str.format(**combined)``, nested
    dicts are processed recursively, and every other value is copied as-is.
    """

    def render(value):
        if isinstance(value, dict):
            return format_dict_api(value, combined)
        if isinstance(value, str):
            return value.format(**combined)
        return value

    return {key: render(value) for key, value in input_dict.items()}
| |
|
| |
|
def run_api(api_endpoints, variables, response, input_message, chain_id):
    """Call each configured HTTP endpoint and collect the results.

    Args:
        api_endpoints: Iterable of endpoint dicts. Keys read here: ``name``,
            ``method``, ``url``, ``input_variables``, ``response_type`` and
            the optional ``headers``, ``params``, ``request_body``,
            ``html_to_markdown``, ``html_tags_to_extract`` and ``timeout``.
        variables: Mapping of known variables; updated in place with the
            results.
        response: Agent output. When truthy the results are labelled
            ``output`` and the response is exposed to the endpoints (merged
            in when it is a dict, otherwise as ``output_message``); when
            falsy the results are labelled ``input``.
        input_message: Exposed to the endpoints as ``input_message``.
        chain_id: Agent id embedded in the result keys.

    Returns:
        Dict mapping ``{type}_{name}_{chain_id}_success`` to the endpoint
        result and ``{type}_{name}_{chain_id}_error`` to the exception raised
        for that endpoint. Empty when no endpoints are configured.
    """
    if not api_endpoints:
        return {}

    combined = variables.copy()
    if response:
        api_endpoint_type = "output"
        if isinstance(response, dict):
            combined = combined | response
        else:
            combined["output_message"] = response
    else:
        api_endpoint_type = "input"
    combined["input_message"] = input_message

    successes = {}
    failures = {}
    for endpoint in api_endpoints:
        try:
            input_var = {inp: combined[inp] for inp in endpoint["input_variables"]}

            # These keys are optional: a missing key behaves like an empty value.
            headers = endpoint.get("headers")
            params = endpoint.get("params")
            request_body = endpoint.get("request_body")

            res = requests.request(
                endpoint["method"],
                endpoint["url"],
                headers=format_dict_api(headers, input_var) if headers else None,
                params=format_dict_api(params, input_var) if params else None,
                json=format_dict_api(request_body, input_var) if request_body else None,
                # No timeout unless the endpoint configures one.
                timeout=endpoint.get("timeout"),
            )

            if endpoint["response_type"] == "json":
                res = res.json()
            else:
                res = res.text
                # HTML post-processing only makes sense for text bodies. The
                # doctype declaration is matched case-insensitively and may be
                # preceded by whitespace.
                if res.lstrip()[:15].lower() == "<!doctype html>":
                    if endpoint.get("html_to_markdown"):
                        res = HTML_TRANSFORMER.handle(res)
                    elif endpoint.get("html_tags_to_extract"):
                        res = BS_TRANSFORMER.extract_tags(
                            res, tags=endpoint["html_tags_to_extract"]
                        )
            successes[f"{api_endpoint_type}_{endpoint['name']}_{chain_id}_success"] = res
        except Exception as e:
            # Best effort by design: a failing endpoint is recorded, not raised.
            # NOTE(review): the exception object itself is stored in the
            # variables; confirm downstream consumers can handle that.
            failures[f"{api_endpoint_type}_{endpoint['name']}_{chain_id}_error"] = e

    # Successes first, then errors, matching the historical key order.
    api_dict = {**successes, **failures}
    variables.update(api_dict)
    return api_dict
| |
|
| |
|
def agent_builder(states: GeneralStates, chain: dict, row: int, depth: int):
    """Run one agent node and return its state update or routing command.

    Processing order:
    1. If a checkpoint is recorded for this agent id, jump to its target.
    2. Run the agent's input API endpoints.
    3. Route to the first child whose input condition equals the latest
       message (stripped).
    4. For template agents, use the formatted prompt as the reply without
       calling a model.
    5. Otherwise call the chat model, either once or once per index of the
       variables listed in ``chain["loop_input_variables"]``.

    Args:
        states: Current graph state.
        chain: Agent definition; reads ``id``, ``agent``, ``child`` and
            ``loop_input_variables``.
        row: Not used by this function.
        depth: Not used by this function.

    Returns:
        A ``Command`` when routing to another node, otherwise a dict update.
    """
    model_config = chain.get("agent")
    print("[MODEL CONFIG]", model_config)
    # An agent without a "child" entry is treated as having no children
    # instead of failing when the children are iterated.
    child = chain.get("child") or []
    checkpoints = states.get("checkpoints", {})

    print("[STATES]", states)

    # A recorded checkpoint short-circuits this agent entirely.
    if chain["id"] in checkpoints:
        return Command(goto=checkpoints[chain["id"]])

    api_dict = {"variables": {}}

    variables = states.get("variables", {})
    variables["input_message"] = states["messages"][-1].content

    api_res = run_api(
        model_config["input_api_endpoints"],
        variables,
        None,
        states["messages"][-1].content,
        chain["id"],
    )
    api_dict["variables"].update(api_res)

    # Input-side routing: compare the latest message with each child's condition.
    for c in child:
        if c["condition_from"] == "input" and states["messages"][-1].content.strip() == c["condition"]:
            redirect_agent_message = AIMessage(f"Switch to Agent {c['id']}")

            local_message = states["local_messages"].get(chain["id"])
            if local_message:
                update_dict = {
                    "local_messages": {
                        chain["id"]: [redirect_agent_message],
                        c["id"]: [states["messages"][-1]],
                    },
                }
            else:
                update_dict = {}

            if c.get("checkpoint"):
                update_dict["checkpoints"] = {chain["id"]: c["id"]}

            return Command(goto=c["id"], update=update_dict | api_dict)

    # Prefer this agent's own history; fall back to the global history.
    messages = states["local_messages"].get(chain["id"])

    if messages:
        messages.append(states["messages"][-1])
    else:
        messages = states["messages"]

    input_var = model_config.get("input_variables")
    output_variables = model_config.get("output_variables")

    if model_config.get("is_template"):
        # Template agents answer with the (formatted) prompt itself.
        response = model_config.get("prompt")
        if input_var:
            response = response.format(**{var: variables[var] for var in input_var})
        response = AIMessage(response)

        api_res = run_api(
            model_config["output_api_endpoints"],
            variables,
            response.content,
            messages[-1].content,
            chain["id"],
        )
        api_dict["variables"].update(api_res)

        if output_variables:
            out = {out_var: response.content for out_var in output_variables}
            if "messages" not in output_variables:
                api_dict["variables"].update(out)
                return api_dict
            else:
                out.pop("messages")
                api_dict["variables"].update(out)

        return {"messages": [response], "local_messages": {chain["id"]: [response]}} | api_dict

    def run_agent(i, loop_input_variables, variables):
        """Call the model once; ``i`` is the loop index, or -1 outside a loop."""
        if input_var:
            print("[AGENT ID]", chain['id'])
            print("[INPUT VARIABLES]", input_var)
            print("[VARIABLES]", variables)
            user_input = "\n".join([str(variables[var]) for var in input_var])

            if i == -1:
                prompt = model_config.get("prompt").format(**{var: variables[var] for var in input_var})
            else:
                # Loop variables contribute their i-th element; others are used whole.
                prompt = model_config.get("prompt").format(
                    **{var: variables[var][i] if var in loop_input_variables else variables[var] for var in input_var}
                )
        else:
            user_input = messages[-1].content
            prompt = model_config.get("prompt") + "\n\n" + messages[-1].content

        model = ChatGroq(
            model="llama-3.3-70b-versatile",
            temperature=model_config.get("creativity"),
            max_tokens=None,
            timeout=None,
            max_retries=2,
            api_key=API_KEY
        )

        routes = model_config.get("routes")
        output_collector = model_config.get("output_collector")

        if routes:
            add_prompt = f"YOU MUST GENERATE OUTPUT STRICTLY one of the following list : [{', '.join(routes)}]\n\n"
            if model_config.get("routes_description"):
                add_prompt += "HERE IS THE CONDITIONS FOR EACH OUTPUT:\n"
                add_prompt += "\n".join([f"{x}: {y}" for x, y in zip(routes, model_config.get("routes_description"))])
                add_prompt += "\n\n"

            prompt = add_prompt + prompt
        elif output_collector:
            add_prompt = f"YOU MUST GENERATE OUTPUT STRICTLY IN THE FOLLOWING JSON FORMAT, REMEMBER TO ADD {{}} BEFORE AND AFTER JSON CODE:\n"
            add_prompt += "\n".join(output_collector)
            add_prompt += "\n\n"

            prompt = prompt + "\n\n" + add_prompt

            response = (model | JsonOutputParser()).invoke(messages[:-1] + [HumanMessage(content=prompt)])

            if output_variables:
                # Build a filtered copy: deleting keys while iterating the
                # dict raises RuntimeError.
                response = {k: v for k, v in response.items() if k in output_variables}

            api_res = run_api(
                model_config["output_api_endpoints"],
                variables,
                response,
                messages[-1].content,
                chain["id"],
            )

            api_dict["variables"].update(api_res)

            return {"variables": response | api_dict["variables"]}

        response = model.invoke(messages[:-1] + [HumanMessage(content=prompt)])

        # Output-side routing: compare the model reply with each child's condition.
        for c in child:
            if c["condition_from"] == "output" and response.content.strip() == c["condition"]:
                redirect_agent_message = AIMessage(f"Switch to Agent {c['id']}")

                local_message = states["local_messages"].get(chain["id"])
                if local_message:
                    update_dict = {
                        "local_messages": {
                            chain["id"]: [redirect_agent_message],
                            c["id"]: [HumanMessage(user_input)],
                        },
                    }
                else:
                    update_dict = {}

                if c.get("checkpoint"):
                    update_dict["checkpoints"] = {chain["id"]: c["id"]}

                api_res = run_api(
                    model_config["output_api_endpoints"],
                    variables,
                    response.content,
                    messages[-1].content,
                    chain["id"],
                )
                api_dict["variables"].update(api_res)

                if output_variables:
                    out = {out_var: response.content for out_var in output_variables}
                    if "messages" not in output_variables:
                        api_dict["variables"].update(out)
                        return api_dict
                    else:
                        api_dict["messages"] = out.pop("messages")
                        api_dict["variables"].update(out)

                return Command(goto=c["id"], update=update_dict | api_dict)
            elif response.content.strip() == OUT_RES:
                # The model signalled completion: emit no message.
                api_res = run_api(
                    model_config["output_api_endpoints"],
                    variables,
                    None,
                    messages[-1].content,
                    chain["id"],
                )
                api_dict["variables"].update(api_res)

                return {} | api_dict

        api_res = run_api(
            model_config["output_api_endpoints"],
            variables,
            response.content,
            messages[-1].content,
            chain["id"],
        )
        api_dict["variables"].update(api_res)

        if output_variables:
            out = {out_var: response.content for out_var in output_variables}
            if "messages" not in output_variables:
                api_dict["variables"].update(out)
                return api_dict
            else:
                out.pop("messages")
                api_dict["variables"].update(out)

        return {"messages": [response], "local_messages": {chain["id"]: [response]}} | api_dict

    if not chain["loop_input_variables"]:
        return run_agent(-1, [], variables)
    else:
        # Iterate up to the length of the shortest loop variable.
        max_loop = min([len(states["variables"].get(x)) for x in chain["loop_input_variables"]])

        updates = {"variables": {}}

        for i in range(max_loop):
            out_variables = run_agent(i, chain["loop_input_variables"], variables)

            if isinstance(out_variables, dict):
                if not out_variables.get("variables"):
                    continue
                # Accumulate each variable across iterations as a flat list.
                for k, v in out_variables["variables"].items():
                    if k not in updates["variables"]:
                        updates["variables"][k] = []
                    if isinstance(v, list):
                        updates["variables"][k] += v
                    else:
                        updates["variables"][k].append(v)
            else:
                # NOTE(review): a non-dict result (a Command) replaces the
                # accumulated updates while the loop keeps running; confirm
                # this is intended.
                updates = out_variables
        return updates
| |
|
| |
|
def route(states, routes):
    """Return the stripped content of the last message if it is in ``routes``.

    Any other content yields ``END``.
    """
    candidate = states["messages"][-1].content.strip()
    return candidate if candidate in routes else END
| |
|
def build_chain(chains, checkpointer, parent_name=None, depth=0):
    """Build and compile a ``StateGraph`` from a nested chain definition.

    The chain tree is walked iteratively with an explicit stack. Each chain
    entry becomes a node that runs ``agent_builder``. Children with a truthy
    ``condition`` are reached through a conditional edge decided by
    ``route``; children without one get a direct edge. Entries without
    children are connected to ``END``. Every top-level chain is connected
    from ``START``.

    Args:
        chains: List of top-level chain dicts; each has an ``id`` and may
            have a ``child`` list of further chain dicts.
        checkpointer: Passed to ``builder.compile``.
        parent_name: Parent id recorded for the top-level chains.
        depth: Depth assigned to the top-level chains.

    Returns:
        The compiled graph.
    """
    print("[BUILD CHAIN] START....")

    # Each stack entry: (sibling list, parent id, depth, index into siblings).
    stack = [(chains, parent_name, depth, 0)]

    builder = StateGraph(GeneralStates)

    while stack:
        current_chains, current_parent, current_depth, i = stack.pop()
        print("STACK", i)

        if i >= len(current_chains):
            continue

        c = current_chains[i]
        c_id = c["id"]

        try:
            print("ADDED NODE!")
            # Defaults bind the current values; a plain closure would see
            # only the values from the last loop iteration.
            builder.add_node(
                c_id,
                lambda states, c=c, i=i, depth=current_depth: agent_builder(states, c, i, depth)
            )
        except ValueError as e:
            # Reported and skipped so the rest of the graph is still built.
            print("[ERROR]", e)

        # Queue the next sibling.
        if i + 1 < len(current_chains):
            stack.append((current_chains, current_parent, current_depth, i + 1))

        if c.get("child"):
            # Queue the children one level deeper.
            stack.append((
                c["child"],
                c_id,
                current_depth + 1,
                0
            ))

            condition_ids = []

            for x in c["child"]:
                if x["condition"]:
                    condition_ids.append(x["id"])
                else:
                    builder.add_edge(c_id, x["id"])

            if condition_ids:
                # Bind this node's routes now: condition_ids is reassigned on
                # every iteration, so a late-binding closure would make every
                # router use the last node's list.
                builder.add_conditional_edges(
                    c_id,
                    lambda states, routes=tuple(condition_ids): route(states, routes),
                    path_map=condition_ids + [END]
                )
        else:
            builder.add_edge(
                c_id,
                END
            )

    print("SET STARTING POINTS")

    for start_point in chains:
        builder.add_edge(START, start_point["id"])
    print("[NODES]", builder.nodes)
    print("[EDGES]", builder.edges)
    graph = builder.compile(checkpointer=checkpointer)
    return graph
| |
|