#!/usr/bin/python
import argparse
from collections import namedtuple
from datetime import datetime
import requests
import os
import sys
import time

from textual.app import App, ComposeResult
from textual.widgets import DataTable, Label

# RapidAPI key; must be set in the environment. When it is unset, the headers
# below are built with the literal string "None".
RAPIDAPI_KEY = os.environ.get('RAPIDAPI_KEY')

# The file where the application's state is persisted.
STATE_FILE = "state.txt"

# Description of an API endpoint: request URL, request headers, and the
# minimum number of seconds between two consecutive updates.
Endpoint = namedtuple('Endpoint', ['url', 'headers', 'update_delay'])

# API endpoints, keyed by the kind of data they serve.
ENDPOINTS = {
    'stock': Endpoint(
        url="https://realstonks.p.rapidapi.com/",
        headers={
            "X-RapidAPI-Key": f"{RAPIDAPI_KEY}",
            "X-RapidAPI-Host": "realstonks.p.rapidapi.com",
        },
        update_delay=5 * 60),  # 5 minutes
    'currency': Endpoint(
        url="https://exchange-rate-api1.p.rapidapi.com/convert",
        headers={
            "X-RapidAPI-Key": f"{RAPIDAPI_KEY}",
            # The host header has to name the API that the URL points at.
            "X-RapidAPI-Host": "exchange-rate-api1.p.rapidapi.com",
        },
        update_delay=60 * 60),  # 1 hour
}

# Application state: the list of watched stocks and the list of watched
# exchange rates.
State = namedtuple('State', ['stocks', 'exchanges'])

# Stock quote.
Stock = namedtuple('Stock',
                   ['sticker', 'price', 'change_point', 'change_percent'])

# Exchange rate between a source and a target currency.
Exchange = namedtuple('Exchange', ['source', 'target', 'rate'])


def get_stock(stickers: list[str], timeout: float = 10.0) -> list[Stock]:
    """Query the stock prices for the given stickers.

    The result may not have prices for all the input stickers if some of the
    queries fail. This function attempts to get as many prices as possible such
    that failure in a query does not preclude other stocks from being queried.

    Args:
        stickers: Stickers to query; one HTTP request is issued per sticker.
        timeout: Seconds to wait on each request before giving up on that
            sticker.

    Returns:
        One Stock per successful query, in the order of `stickers`. Failed
        queries are printed and omitted from the result.
    """
    # This API does not allow querying multiple stickers in a single request.
    # Free tier: 100,000 requests/month.
    #
    # Example response:
    # {
    #   "price": 466.4,
    #   "change_point": 7.4,
    #   "change_percentage": 1.61,
    #   "total_vol": "11.29M"
    # }
    endpoint = ENDPOINTS['stock']
    stocks = []
    for sticker in stickers:
        # Deliberately broad: whatever goes wrong with one sticker (network
        # error, timeout, HTTP error status, unexpected payload), the
        # remaining stickers must still be queried.
        try:
            response = requests.get(f"{endpoint.url}{sticker}",
                                    headers=endpoint.headers,
                                    timeout=timeout)
            # Report HTTP errors as such instead of as a missing 'price' key.
            response.raise_for_status()
            quote = response.json()
            stocks.append(
                Stock(sticker, float(quote['price']),
                      float(quote['change_point']),
                      float(quote['change_percentage'])))
        except Exception as e:
            print(e)
    return stocks


def get_exchange_rate(source: str, target: str,
                      timeout: float = 10.0) -> float:
    """Get the exchange rate between two currencies. Return 0 on failure.

    Args:
        source: Currency sent as the "base" query parameter.
        target: Currency sent as the "target" query parameter.
        timeout: Seconds to wait on the request before giving up.

    Returns:
        The rate reported by the API, or 0.0 if the request or the parsing of
        its response fails for any reason (the error is printed).
    """
    # TODO: document the free-tier quota of this API.
    #
    # Example response:
    # {
    #   "code": "0",
    #   "msg": "success",
    #   "convert_result": {
    #     "base": "USD",
    #     "target": "EUR",
    #     "rate": 0.9063
    #   },
    #   "time_update": {
    #     "time_unix": 1690556940,
    #     "time_utc": "2023-07-28T08:09:00Z",
    #     "time_zone": "America/Los_Angeles"
    #   }
    # }
    endpoint = ENDPOINTS['currency']
    query = {"base": source, "target": target}
    # Deliberately broad: callers rely on the 0.0 sentinel rather than on
    # exceptions, whatever the cause of the failure.
    try:
        response = requests.get(endpoint.url, headers=endpoint.headers,
                                params=query, timeout=timeout)
        # Report HTTP errors as such instead of as a missing key.
        response.raise_for_status()
        return float(response.json()['convert_result']['rate'])
    except Exception as e:
        print(e)
    return 0.0


def update_stocks(state: State) -> State:
    """Return a copy of `state` with freshly queried stock quotes.

    Stocks whose query failed keep their previous quote. The stocks stay in
    the same order as in `state.stocks`, so a partial failure does not shuffle
    the rows that are displayed and saved. Exchanges are passed through
    unchanged.
    """
    stickers = [stock.sticker for stock in state.stocks]
    # get_stock may not return a quote for every sticker; index what it did
    # return so that each stock can be looked up by sticker.
    latest = {stock.sticker: stock for stock in get_stock(stickers)}
    stocks = [latest.get(stock.sticker, stock) for stock in state.stocks]
    return State(stocks, state.exchanges)


def update_exchanges(state: State) -> State:
    """Return a copy of `state` with freshly queried exchange rates.

    An exchange whose query yields a rate of 0 (the failure value of
    get_exchange_rate) is carried over as is. Stocks are passed through
    unchanged.
    """

    def refresh(current):
        # One request per currency pair.
        new_rate = get_exchange_rate(current.source, current.target)
        if new_rate == 0:
            return current
        return Exchange(current.source, current.target, new_rate)

    refreshed = [refresh(current) for current in state.exchanges]
    return State(state.stocks, refreshed)


def format_delta(delta: float, percent: bool = False) -> str:
    """Format a change with an explicit sign, e.g. "+1.5" or "-2.0%".

    Zero is shown with a "+" sign. A "%" suffix is appended when `percent`
    is true.
    """
    if delta >= 0:
        prefix = "+"
    else:
        prefix = "-"
    suffix = "%" if percent else ""
    magnitude = abs(delta)
    return "{}{}{}".format(prefix, magnitude, suffix)


def format_exchange_name(exchange: Exchange) -> str:
    """Return the display name of an exchange as "<source>/<target>"."""
    return "{}/{}".format(exchange.source, exchange.target)


def load_state(filepath: str) -> State:
    """Load the application state from a text file.

    Each non-blank line holds whitespace-separated values and describes
    either an exchange rate or a stock quote:

        <source>/<target> <rate>
        <sticker> <price> <change_point> <change_percent>

    A line is taken to be an exchange when its first value contains a '/'.
    Blank lines are skipped.

    Raises:
        OSError: If the file cannot be opened (e.g. it does not exist).
        IndexError, ValueError: If a line has too few values or a value that
            is not a number.
    """
    stocks = []
    exchanges = []

    with open(filepath, 'r') as file:
        for line in file:
            # split() with no argument also drops the trailing newline and
            # tolerates repeated spaces between values.
            values = line.split()
            if not values:
                continue
            key = values[0]
            if '/' in key:
                source, target = key.split('/')
                rate = float(values[1])
                exchanges.append(Exchange(source, target, rate))
            else:
                sticker = key
                price = float(values[1])
                change_point = float(values[2])
                change_percent = float(values[3])
                stocks.append(
                    Stock(sticker, price, change_point, change_percent))

    return State(stocks, exchanges)


def save_state(state: State, filepath: str):
    """Write the application state to a text file, replacing its content.

    Stocks come first, one per line, with their fields separated by single
    spaces. Exchanges follow, one per line, as "<source>/<target> <rate>".
    """
    with open(filepath, 'w') as out:
        for stock in state.stocks:
            fields = ' '.join(map(str, stock))
            out.write(fields + "\n")

        for exchange in state.exchanges:
            pair = format_exchange_name(exchange)
            out.write(f"{pair} {exchange.rate}\n")


class Updater:
    """An update function together with its refresh period.

    Attributes are plain and public:
      update: callable invoked with the current state; its return value is
        used as the new state.
      delay: minimum time between two invocations of `update`.
      last_update_time: timestamp of the last invocation, initially 0.
    """

    def __init__(self, update, delay):
        # Starting from 0 makes the updater due as soon as it is checked
        # against any timestamp that is at least `delay`.
        self.update, self.delay, self.last_update_time = update, delay, 0


def update_stub(msg: str, state: State) -> State:
    """Stand-in update function: print `msg` and return `state` untouched."""
    print(msg)
    unchanged = state
    return unchanged


def make_updaters(use_stubs: bool) -> list[Updater]:
    """Build the list of updaters driving the application.

    With `use_stubs` set, the updaters only print a message (every 1 and 5
    time units respectively) and leave the state alone. Otherwise they query
    the stock and currency endpoints at the delays configured in ENDPOINTS.
    """
    if not use_stubs:
        stock_delay = ENDPOINTS['stock'].update_delay
        currency_delay = ENDPOINTS['currency'].update_delay
        return [
            Updater(update_stocks, stock_delay),
            Updater(update_exchanges, currency_delay),
        ]

    return [
        Updater(lambda s: update_stub("Update stocks", s), 1),
        Updater(lambda s: update_stub("Update exchange", s), 5),
    ]


def update_state(t: float, updaters: list[Updater], state: State) -> State:
    """Run every updater that is due at time `t` and return the new state.

    An updater is due when at least `delay` has elapsed since its
    `last_update_time`. Due updaters are run in list order, each one receiving
    the state produced by the previous one, and are stamped with `t`.
    """
    current = state
    for entry in updaters:
        elapsed = t - entry.last_update_time
        is_due = elapsed >= entry.delay
        if is_due:
            current = entry.update(current)
            entry.last_update_time = t
    return current


class MarketApp(App):
    """Textual application showing stock quotes and exchange rates in a table.

    The state is loaded from STATE_FILE on mount, refreshed through the given
    updaters at the pace of the most frequent one, and written back to
    STATE_FILE after every refresh.
    """

    TITLE = "Market Watch"
    CSS = """
    #footer {
        dock: bottom;
    }
    """

    def __init__(self, updaters: list[Updater]):
        """Create the app.

        Args:
            updaters: Updaters run periodically by update(); must not be
                empty, since the refresh interval is the smallest delay.
        """
        super().__init__()
        # state is set in on_mount; table and footer are set in compose.
        self.state = None
        self.table = None
        self.footer = None
        self.updaters = updaters
        self.min_update_delay = min(updater.delay for updater in updaters)

    def compose(self) -> ComposeResult:
        """Create the quotes table and the footer label."""
        table = DataTable()
        table.show_cursor = False
        self.table = table
        yield table

        footer = Label(id="footer")
        self.footer = footer
        yield footer

    def on_mount(self):
        """Load the saved state, refresh it once, and schedule refreshes."""
        self.state = load_state(STATE_FILE)
        # update() renders the table itself, so no separate render is needed.
        self.update()
        self.set_interval(self.min_update_delay, self.update)

    # NOTE(review): `render` appears to coincide with the name of a Textual
    # rendering hook on App -- confirm the framework never calls it, or rename
    # this method together with its callers.
    def render(self):
        """Rebuild the table from the current state and stamp the footer."""
        assert self.state is not None
        assert self.table is not None
        assert self.footer is not None

        # The timestamp is the time of this rendering.
        update_time = datetime.now()
        self.footer.update(f"Last update: {update_time}")

        #  Stock/ex | Price | Change
        #   xyz     |  xxx  |  xxx
        #  usd/eur  |  xxx  | <empty>
        table = self.table
        table.clear(columns=True)
        table.add_columns("Stock", "Price($)", "Change($)", "%")
        for stock in self.state.stocks:
            table.add_row(stock.sticker, stock.price,
                          format_delta(stock.change_point),
                          format_delta(stock.change_percent, percent=True))
        for exchange in self.state.exchanges:
            table.add_row(format_exchange_name(exchange), exchange.rate, "", "")

    def update(self) -> None:
        """Run the updaters that are due, then render and save the state."""
        # NOTE(review): the updaters are called synchronously here, so this
        # method does not return until their requests complete -- presumably
        # that stalls the UI meanwhile; confirm.
        t = time.time()
        self.state = update_state(t, self.updaters, self.state)
        self.render()
        save_state(self.state, STATE_FILE)


def main():
    """Parse the command line and run the application until it exits.

    The --stub flag swaps the network-backed updaters for stubs that only
    print a message.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--stub", action='store_true',
                     help="Use stub update functions")
    options = cli.parse_args()

    MarketApp(make_updaters(options.stub)).run()


# Run the application only when executed as a script, not when imported.
if __name__ == '__main__':
    exit_status = main()
    sys.exit(exit_status)